たれぱんのびぼーろく

わたしの備忘録、生物学とプログラミングが多いかも

GANの学習実装

実はハマりどころだらけだったりする。

  • Gを先にBackwardするとDの.gradが貯まる
    • => "D先に学習" or "Gの学習時はD.requires_grad = false" or "D学習前にNetD.zero_grad()"
  • D学習時にG(z)を渡すとGにまでBackpropして無駄 & Gの.gradが貯まる
    • => "D(G(z).detach())" or "Dの学習時はG.requires_grad = false" or "G学習前にNetG.zero_grad()"
  • G/D学習の2回で両方G(z) forwardして無駄 (2 forward pass)
    • => "fake = G(z)で保持、Dの学習にD(fake.detach())、Gの学習はD(fake)"
      • G forward時のrequires_gradが必要なのでrequires_gradいじる他との衝突に注意
      • G先に学習すると fake = G_old(z)を使ってDが学習する => fake再利用なら学習はD先G後

標準実装#1: DCGAN - PyTorch official (tutorial, repository)

標準実装#1pl: Vanilla GAN - PyTorch Lightning

fake保持

素直といえば素直だけど、抽象化はしづらい処理.
PyTorch-Lightningとかだと素直には使えない(backward周りすべて自前処理が必要).

削減量

10%くらい計算量が減る (9 step => 8 step).
Gが重いと効果が大きくなりそう

# D training
## forward
G(z) => fake
D(fake)
D(real)
## backward
D(fake)
D(real)
# G training
## forward
G(z) => fake (ここだけ省略できる)
D(fake)
## backward
D(fake)
G(z)

ざっくりとだけ理解しているもの

separate mini-batch: バッチ単位の処理はReal内、Fake内でおこなうようにする
例: (loss(D(real)) + loss(D(fake))).backward()すると両方のbatchのlossがBatchNormに流れ込み、バッチ内統計がreal/fakeごちゃ混ぜになる.

Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator’s optimizer.