引数で初期化され内部状態を持つインスタンスがあったとき、それをいかに保存し、のちに復元するか.
問題整理
class A: def __init__(self, arg1): parts = SubA(arg1) self.var = 1 def add(self): self.var += 5 def save(self): pass def from_snapshot(self): pass # Initialize a = A(2) # Update a.add() # Save snapshot = a.save() # Restore a_2 = A.from_snapshot(snapshot)
見えてくる課題/制約は
すなわちhyperparameter的なものと内部state的なものの2つをsnapshotとして扱う必要がある.
解決案
すべてinit引数にする
内部状態をinit引数にし、内部状態の初期化/復元を同じに扱う.
# Restore snapshot = {arg1: 2, var: 6} a_2 = A(**snapshot)
問題点: 煩雑
内部状態が増減するたびに引数のリライトが必要
透過的に扱えるはずの内部状態を明示的引数として外部に晒している(複雑になるうえ、不正な初期化のリスクも生む)
snapshotを明示的な型として扱わない限り型の支援が減る(dict展開は型支援があいまいになりがち)
static関数から復元する
Cls.from_snapshot関数を用意してsnapshotを突っ込み、そこで初期化と内部状態復元をおこなう.
課題/制約から
- snapshotから引数hparamsと内部stateを取り出し
- hparamsを引数としたinit
- 内部stateの上書き
- 復元済みインスタンスの返却
のステップにわけると制約を分離できることがわかる.
なので
class A: def from_snapshot(ckpt): hparams, state = ckpt a_2 = A(**hparams) a_2.var = state.var return a_2
で綺麗に復元できる.
stateがたくさんある場合、この素朴な実装だと内部状態が増えるたびにrestore関数へも追記が必要.
なのでstate_dict
的な変数の下に保存対象となる内部状態を吊るす形が綺麗.
例
PyTorch-Lightning
PyTorchの制約でhparamsとstateの両方が必ず発生する.
なのでこの方式を内部で使ってる.
実装としてはinit引数の解析でselfを取り除くとかこまごましたテクニックがあるのでいい参考になる.
static関数の命名
hparamsから再構築するという意味ではreconstruct
がよく合っていて、内部状態を載せなおすという意味ではload
がよく合っている.
両方やっている、かつstatic関数なので、Cls.from_snapshot
がちょうどいい命名だと思う.