PyTorchのラッピングライブラリであるPyTorch-Lightning、めちゃくちゃ便利ですよね。
https://pytorch-lightning.readthedocs.io/en/latest/
かなり痒い所に手が届く感じですが、一瞬lr_schedulerどうやるかわからなかったので書きます。
目次
lr_shcedulerの導入
optimizerはLightningModule#configure_optimizersでoptimizerと一緒に返せば大丈夫です。Listで返す必要があります。
import torch
import pytorch_lightning as pl
class TestModule(pl.LightningModule):
...
def configure_optimizers():
optimizer = torch.optim.AdamW(lr=self.lr, params=self.model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return [optimizer, ], [scheduler, ]
Epoch数が必要なlr_schedulerの場合
CossineAnnearningLRなど、エポック数が必要な手法があります。いちいちコンストラクタにepoch数渡すのは馬鹿馬鹿しいですね。。。
LightningModuleにはTrainerが結び付くのでそこからepoch数を取得すればよいです。self.trainer.max_epochsで取得できます。
import torch
import pytorch_lightning as pl
class TestModule(pl.LightningModule):
...
def configure_optimizers():
optimizer = torch.optim.AdamW(lr=self.lr, params=self.model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
return [optimizer, ], [scheduler, ]
補足: ModelCheckpointでエラーが起こるんだけど…
lr_schedulerとは違う話ですが、PyTorch-LightningのModelCheckpoint実行時にNo argument的なエラーが出て保存できないことがありました。
どうもコンストラクタでself.save_hyperparameters()を必ず実行する必要があるみたいです。
import torch
import pytorch_lightning as pl
class TestModule(pl.LightningModule):
...
def __init__(self, パラメータ):
...
self.save_hyperparameters()