今回は学習率をエポックが進むごとに変化させるルール、学習率スケジューリングについて解説、検証します。
※検証結果については後日書きます
lossの推移に近いルール、Cosine関数に従うルールの2つについて書いていきます。
目次
ShelfNetの学習率スケジューリング
ShelfNetというセマンティックセグメンテーションモデルの論文で書かれている手法。
https://arxiv.org/abs/1811.11254
数式
数式はこうなっています。powerを上げると減衰が早くなります。論文上ではpowerは0.9でした。
エポックごとの推移
エポックごとの学習率の推移はこんな感じ。元が1e-2で300エポックまで算出。
powerが大きいほど減衰率が上がる。0.5だと逆になる。論文通り0.9だとほぼ直線。
1以上においては、速めに減少して最後の方は緩やかになる。一般的なlossのグラフと似たような感じだ。学習率もlossと同じように推移していくと収束しやすいのだろうか。
実装
PyTorch/Pythonの実装はこんな感じです。
import torch
class LearningRateScheduler:
def __init__(self, base_lr: float, max_epoch: int, power=0.9):
self._max_epoch = max_epoch
self._power = power
self._base_lr = base_lr
def __call__(self, epoch: int):
return (1 - max(epoch - 1, 1) / self._max_epoch) ** self._power * self._base_lr
ちなみに使い方はこんな感じ。
lr_scheduler_func = LearningRateScheduler(1e-2, 300)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(<Adam, SGDなどのoptimizer>,
lr_lambda=lr_scheduler_func)
# 学習中
lr_scheduler.step(<エポック数>)
Cosine decayの学習率スケジューリング
次にCosine decayルール。これはEfficientDetで用いられている。EfficientDetだと0.16まで上昇してそこから下がるようにするっぽい。
調べた感じ論文とかなさそうで、これが近い感じだろうか。
https://medium.com/@scorrea92/cosine-learning-rate-decay-e8b50aa455b
数式
数式はこんな感じ。warmup_epochまでは上昇し、そこからはCosine関数に従って学習率を落としていく感じです。
エポックごとの推移
エポックごとの推移はこんな感じ。まあ普通にCosineですね。
warmup期間がなければ青線のように最初からCosine関数に従う。
実装
実装はこんな感じ。
import torch
class CosineDecayScheduler:
def __init__(self, max_epochs: int, warmup_lr_limit=0.16, warmup_epochs=5):
self._max_epochs = max_epochs
self._warmup_lr_limit = warmup_lr_limit
self._warmup_epochs = warmup_epochs
def __call__(self, epoch: int):
epoch = max(epoch, 1)
if epoch <= self._warmup_epochs:
return self._warmup_lr_limit * epoch / self._warmup_epochs
epoch -= 1
rad = math.pi * epoch / self._max_epochs
weight = (math.cos(rad) + 1.) / 2
return self._warmup_lr_limit * weight
lr_scheduler_func = CosineDecayScheduler(300)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(<Adam, SGDなどのoptimizer>,
lr_lambda=lr_scheduler_func)
# 学習中
lr_scheduler.step(<エポック数>)
Learning rate schedulerの検証
検証終わったら書きます。
精度向上に寄与するFocalLossという損失関数についても記事上げているのでこちらも見てみてください!