PyTorch Lightning provides the following profiler classes:

| Class | device | unit | internal |
| --- | --- | --- | --- |
| SimpleProfiler | | Lightning's action | |
| PyTorchProfiler | | | PyTorch's Autograd Profiler |
| AdvancedProfiler | | | Python's cProfile |
| XLAProfiler | TPU | | |
| (PassThroughProfiler) | any | | (default) Just pass-through, do nothing |
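A profiler is selected through the `profiler` argument of `Trainer`, either as a string shorthand or as an instance. A minimal sketch (the `dirpath`/`filename` values are arbitrary; the import path assumes lightning >= 2.0):

```python
import lightning.pytorch as pl
from lightning.pytorch.profilers import SimpleProfiler

# String shorthand: "simple", "advanced", "pytorch", or "xla".
trainer = pl.Trainer(profiler="simple")

# Or pass an instance, e.g. to control where the report is written.
trainer = pl.Trainer(profiler=SimpleProfiler(dirpath=".", filename="perf_log"))
```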
SimpleProfiler and DataLoader

The elapsed time of data loading is recorded under the action named `[_TrainingEpochLoop].train_dataloader_next`. Because a DataLoader with workers prefetches batches, this action should ideally be close to zero; a nonzero reading means the data supply is not keeping up with training.
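A quick way to see this: run the same fit once with `num_workers=0` and once with workers enabled, then compare the `[_TrainingEpochLoop].train_dataloader_next` row in the SimpleProfiler report. A minimal sketch with a placeholder model and dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

dataset = TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1))

if __name__ == "__main__":  # guard needed for worker processes on spawn platforms
    for num_workers in (0, 4):
        # With workers, batches are fetched in the background, so the
        # train_dataloader_next action should shrink toward zero.
        loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)
        trainer = pl.Trainer(max_epochs=1, profiler="simple",
                             logger=False, enable_checkpointing=False)
        trainer.fit(ToyModel(), loader)
```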
This machinery is tucked away deep inside the Trainer, so the implementation takes some effort to follow. The call relationships are as follows; the profiler hooks are attached as part of the preprocessing in `Trainer.fit`.
```
<Trainer>.__init__(...):
    self.fit_loop = _FitLoop(...)

<Trainer>.fit(...):
    call._call_and_handle_interrupt(self._fit_impl, ...)

<Trainer>._fit_impl(...):
    self._run(model, ...)

<Trainer>._run(model, ...):
    results = self._run_stage()

<Trainer>._run_stage():
    self.fit_loop.run()

<_FitLoop>.__init__(...):
    self.epoch_loop = _TrainingEpochLoop(trainer)

<_FitLoop>.run():
    while not self.done:
        self.advance()

<_FitLoop>.advance():
    self.epoch_loop.run(self._data_fetcher)

<_TrainingEpochLoop>.run(data_fetcher):
    self.on_run_start(data_fetcher)
    while not self.done:
        self.advance(data_fetcher)

<_TrainingEpochLoop>.on_run_start(data_fetcher):
    data_fetcher._start_profiler = self._on_before_fetch
    data_fetcher._stop_profiler = self._on_after_fetch

<_TrainingEpochLoop>._on_before_fetch():
    self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")

<_TrainingEpochLoop>._on_after_fetch():
    self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

<_TrainingEpochLoop>.advance(data_fetcher):
    next(data_fetcher)

<_DataFetcher>.__next__():
    self._start_profiler()
    data = next(self.iterator)
    self._stop_profiler()
    return data

<_PrefetchDataFetcher>.__next__():
    batch = self.batches.pop(0)
    self._fetch_next_batch(self.iterator)
    return batch

<_PrefetchDataFetcher>._fetch_next_batch(iterator):
    self._start_profiler()
    batch = next(iterator)
    self._stop_profiler()
    self.batches.append(batch)
```
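In other words, the profiler only times the blocking `next(iterator)` call itself; a batch that was already prefetched is popped from `self.batches` without being timed again. A standalone sketch of the same pattern (class names here are illustrative, not Lightning's; iterator-exhaustion handling is omitted):

```python
import time
from collections import defaultdict

class TinyProfiler:
    # Stand-in for trainer.profiler: accumulates elapsed time per action.
    def __init__(self):
        self.totals = defaultdict(float)
        self._t0 = {}

    def start(self, action):
        self._t0[action] = time.monotonic()

    def stop(self, action):
        self.totals[action] += time.monotonic() - self._t0.pop(action)

class PrefetchFetcher:
    def __init__(self, iterator, profiler):
        self.iterator = iterator
        self.profiler = profiler
        self.batches = []
        self._fetch_next_batch()  # prefetch the first batch up front

    def _fetch_next_batch(self):
        # Only the blocking wait for the next batch is timed.
        self.profiler.start("train_dataloader_next")
        try:
            self.batches.append(next(self.iterator))
        finally:
            self.profiler.stop("train_dataloader_next")

    def __next__(self):
        batch = self.batches.pop(0)  # already fetched: returned untimed
        self._fetch_next_batch()     # start fetching (and timing) the next one
        return batch

prof = TinyProfiler()
fetcher = PrefetchFetcher(iter(range(3)), prof)
print(next(fetcher), next(fetcher))  # 0 1
print(dict(prof.totals))             # near-zero total: the iterator is instant
```

Note that the fetcher itself hides no cost: `next(self.iterator)` still blocks for the full fetch time. It is the DataLoader's background workers that keep that call fast, which is why the timed wait collapses toward zero when they keep up.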
Regarding the initialization of `_data_fetcher`:
```
<_FitLoop>.__init__(...):
    self._combined_loader = None
    self._data_fetcher = None

<_FitLoop>.run():
    self.setup_data()
    self.reset()
    self.on_run_start()
    while not self.done:
        self.on_advance_start()
        self.advance()
        self.on_advance_end()
    self.on_run_end()

<_FitLoop>.on_run_start():
    self._data_fetcher = _select_data_fetcher(trainer)

<_FitLoop>.advance():
    combined_loader = self._combined_loader
    self._data_fetcher.setup(combined_loader)
    self.epoch_loop.run(self._data_fetcher)
```
```python
def _select_data_fetcher(trainer) -> _DataFetcher:
    lightning_module = trainer.lightning_module
    if trainer.testing:
        step_fx_name = "test_step"
    elif trainer.training:
        step_fx_name = "training_step"
    elif trainer.validating or trainer.sanity_checking:
        step_fx_name = "validation_step"
    elif trainer.predicting:
        step_fx_name = "predict_step"
    else:
        raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}")
    step_fx = getattr(lightning_module, step_fx_name)
    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return _DataLoaderIterDataFetcher()
    return _PrefetchDataFetcher()
```
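So the fetcher choice depends only on the step signature: declaring `dataloader_iter` explicitly opts into `_DataLoaderIterDataFetcher`, where the loop hands you the raw iterator instead of prefetched batches. A minimal sketch (note: whether `next(dataloader_iter)` yields the batch alone or a `(batch, batch_idx, dataloader_idx)` tuple has changed across Lightning versions):

```python
import torch
import lightning.pytorch as pl

class ManualFetchModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    # The explicit `dataloader_iter` parameter is exactly what
    # is_param_in_hook_signature(...) above detects, so this model gets
    # a _DataLoaderIterDataFetcher instead of a _PrefetchDataFetcher.
    def training_step(self, dataloader_iter):
        batch = next(dataloader_iter)  # on some versions: (batch, batch_idx, dataloader_idx)
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```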