[PyTorch-Lightning] lr_shedulerはどうする?

 PyTorchのラッピングライブラリであるPyTorch-Lightning、めちゃくちゃ便利ですよね。

https://pytorch-lightning.readthedocs.io/en/latest/

 かなり痒い所に手が届く感じですが、一瞬lr_schedulerどうやるかわからなかったので書きます。

lr_shcedulerの導入


 optimizerはLightningModule#configure_optimizersでoptimizerと一緒に返せば大丈夫です。Listで返す必要があります。

import torch
import pytorch_lightning as pl

class TestModule(pl.LightningModule):
  ...
  def configure_optimizers():
    optimizer = torch.optim.AdamW(lr=self.lr, params=self.model.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    return [optimizer, ], [scheduler, ]

Epoch数が必要なlr_schedulerの場合


 CossineAnnearningLRなど、エポック数が必要な手法があります。いちいちコンストラクタにepoch数渡すのは馬鹿馬鹿しいですね。。。

 LightningModuleにはTrainerが結び付くのでそこからepoch数を取得すればよいです。self.trainer.max_epochsで取得できます。

import torch
import pytorch_lightning as pl

class TestModule(pl.LightningModule):
  ...
  def configure_optimizers():
    optimizer = torch.optim.AdamW(lr=self.lr, params=self.model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)   
    return [optimizer, ], [scheduler, ]

補足: ModelCheckpointでエラーが起こるんだけど…

 lr_schedulerとは違う話ですが、PyTorch-LightningのModelCheckpoint実行時にNo argument的なエラーが出て保存できないことがありました。

 どうもコンストラクタでself.save_hyperparameters()を必ず実行する必要があるみたいです。

import torch
import pytorch_lightning as pl

class TestModule(pl.LightningModule):
  ...
  def __init__(self, パラメータ):
    ...
    self.save_hyperparameters()

コメントを残す

メールアドレスが公開されることはありません。