実はハマりどころだらけだったりする。
- Gを先にBackwardするとDの.gradが貯まる
- => "D先に学習" or "Gの学習時はD.requires_grad = false" or "D学習前にNetD.zero_grad()"
- D学習時にG(z)を渡すとGにまでBackpropして無駄 & Gの.gradが貯まる
- => "D(G(z).detach())" or "Dの学習時はG.requires_grad = false" or "G学習前にNetG.zero_grad()"
- G/D学習の2回で両方G(z) forwardして無駄 (2 forward pass)
- => "fake = G(z)で保持、Dの学習にD(fake.detach())、Gの学習はD(fake)"
- G forward時のrequires_gradが必要なのでrequires_gradいじる他との衝突に注意
- G先に学習すると fake = G_old(z)を使ってDが学習する => fake再利用なら学習はD先G後
- => "fake = G(z)で保持、Dの学習にD(fake.detach())、Gの学習はD(fake)"
標準実装#1: DCGAN - PyTorch official (tutorial, repository)
標準実装#1pl: Vanilla GAN - PyTorch Lightning
fake保持
素直といえば素直だけど、抽象化はしづらい処理.
PyTorch-Lightningとかだと素直には使えない(backward周りすべて自前処理が必要).
削減量
10%くらい計算量が減る (9 step => 8 step).
Gが重いと効果が大きくなりそう
# D training ## forward G(z) => fake D(fake) D(real) ## backward D(fake) D(real) # G training ## forward G(z) => fake (ここだけ省略できる) D(fake) ## backward D(fake) G(z)
ざっくりとだけ理解しているもの
separate mini-batch: バッチ単位の処理はReal内、Fake内でおこなうようにする
例: (loss(D(real)) + loss(D(fake))).backward()すると両方のbatchのlossがBatchNormに流れ込み、バッチ内統計がreal/fakeごちゃ混ぜになる.
Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator’s optimizer.