たれぱんのびぼーろく

わたしの備忘録、生物学とプログラミングが多いかも

インスタンス状態の保存と復元

引数で初期化され内部状態を持つインスタンスがあったとき、それをいかに保存し、のちに復元するか.

問題整理

class A:
  def __init__(self, arg1):
    parts = SubA(arg1)
    self.var = 1
  def add(self):
    self.var += 5
  def save(self):
    pass
  def from_snapshot(self):
    pass

# Initialize
a = A(2)
# Update
a.add()
# Save
snapshot = a.save()
# Restore
a_2 = A.from_snapshot(snapshot)

見えてくる課題/制約は

  • arg1 はコンストラクタに与えなくてはならない(後から変更できない)
  • self.var はコンストラクタ引数にない

すなわちhyperparameter的なものと内部state的なものの2つをsnapshotとして扱う必要がある.

解決案

すべてinit引数にする

内部状態をinit引数にし、内部状態の初期化/復元を同じに扱う.

# Restore
snapshot = {arg1: 2, var: 6}
a_2 = A(**snapshot)

問題点: 煩雑
内部状態が増減するたびに引数のリライトが必要
透過的に扱えるはずの内部状態を明示的引数として外部に晒している(複雑になるうえ、不正な初期化のリスクも生む)
snapshotを明示的な型として扱わない限り型の支援が減る(dict展開は型支援があいまいになりがち)

static関数から復元する

Cls.from_snapshot関数を用意してsnapshotを突っ込み、そこで初期化と内部状態復元をおこなう.

課題/制約から

  1. snapshotから引数hparamsと内部stateを取り出し
  2. hparamsを引数としたinit
  3. 内部stateの上書き
  4. 復元済みインスタンスの返却

のステップにわけると制約を分離できることがわかる.
なので

class A:
  def from_snapshot(ckpt):
    hparams, state = ckpt
    a_2 = A(**hparams)
    a_2.var = state.var
    return a_2

で綺麗に復元できる.

stateがたくさんある場合、この素朴な実装だと内部状態が増えるたびにrestore関数へも追記が必要.
なのでstate_dict的な変数の下に保存対象となる内部状態を吊るす形が綺麗.

PyTorch-Lightning

PyTorchの制約でhparamsとstateの両方が必ず発生する.
なのでこの方式を内部で使ってる.
実装としてはinit引数の解析でselfを取り除くとかこまごましたテクニックがあるのでいい参考になる.

static関数の命名

hparamsから再構築するという意味ではreconstructがよく合っていて、内部状態を載せなおすという意味ではloadがよく合っている.
両方やっている、かつstatic関数なので、Cls.from_snapshotがちょうどいい命名だと思う.