たれぱんのびぼーろく

わたしの備忘録、生物学とプログラミングが多いかも

PyTorchのプロファイラ

PyTorch Lightning:

Class device unit internal
SimpleProfiler L's action
PyTorchProfiler PyTorch’s Autograd Profiler
AdvancedProfiler
XLAProfiler TPU
(PassThroughProfiler) any (default) Just pass-through, do nothing

SimpleProfilerとDataLoader

データロードの経過時間は _TrainingEpochLoop.train_dataloader_next アクションとして記録される。
worker有りのDataLoaderはprefetchをするため、理想的にはこのアクションがほぼ0になる。
ここが非ゼロということはデータ供給が追い付いていないことを意味する。

Trainerの奥深くにしまい込まれており、実装の理解は面倒。
call関係は以下になっており、Trainer.fitの前処理でprofilerがフックされている。

<Trainer>.__init__(...):
    self.fit_loop = _FitLoop(...)

<Trainer>.fit(...):
    call._call_and_handle_interrupt(self._fit_impl, ...)

<Trainer>._fit_impl(...):
    self._run(model, ...)

<Trainer>._run(model, ...):
    results = self._run_stage()

<Trainer>._run_stage():
    self.fit_loop.run()


<_FitLoop>.__init__(...):
    self.epoch_loop = _TrainingEpochLoop(trainer)

<_FitLoop>.run():
    while not self.done:
        self.advance()

<_FitLoop>.advance():
    self.epoch_loop.run(self._data_fetcher)


<_TrainingEpochLoop>.run(data_fetcher):
    self.on_run_start(data_fetcher)
    while not self.done:
        self.advance(data_fetcher)

<_TrainingEpochLoop>.on_run_start(data_fetcher):
    data_fetcher._start_profiler = self._on_before_fetch
    data_fetcher._stop_profiler  = self._on_after_fetch
<_TrainingEpochLoop>._on_before_fetch():
    self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")
<_TrainingEpochLoop>._on_after_fetch():
    self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

<_TrainingEpochLoop>.advance(data_fetcher):
    next(data_fetcher)


<_DataFetcher>.__next__():
    self._start_profiler()
    data = next(self.iterator)
    self._stop_profiler()
    return data

<_PrefetchDataFetcher>.__next__():
    batch = self.batches.pop(0)
    self._fetch_next_batch(self.iterator)
    return batch

<_PrefetchDataFetcher>._fetch_next_batch(iterator):
    self._start_profiler()
    batch = next(iterator)
    self._stop_profiler()
    self.batches.append(batch)

_data_fetcherの初期化に関して:

<_FitLoop>.__init__(...):
    self._combined_loader = None
    self._data_fetcher = None

<_FitLoop>.run():
    self.setup_data()
    self.reset()
    self.on_run_start()
    while not self.done:
        self.on_advance_start()
        self.advance()
        self.on_advance_end()
    self.on_run_end()

<_FitLoop>.on_run_start():
    self._data_fetcher = _select_data_fetcher(trainer)

<_FitLoop>.advance():
    combined_loader = self._combined_loader
    self._data_fetcher.setup(combined_loader)
    self.epoch_loop.run(self._data_fetcher)


def _select_data_fetcher(trainer) -> _DataFetcher:
    lightning_module = trainer.lightning_module
    if trainer.testing:
        step_fx_name = "test_step"
    elif trainer.training:
        step_fx_name = "training_step"
    elif trainer.validating or trainer.sanity_checking:
        step_fx_name = "validation_step"
    elif trainer.predicting:
        step_fx_name = "predict_step"
    else:
        raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}")
    step_fx = getattr(lightning_module, step_fx_name)
    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return _DataLoaderIterDataFetcher()
    return _PrefetchDataFetcher()