[PyTorch] 学習率スケジューリング – Cosine decay rule

 今回は学習率をエポックが進むごとに変化させるルール、学習率スケジューリングについて解説、検証します。

※検証結果については後日書きます

 lossの推移に近いルール、Cosine関数に従うルールの2つについて書いていきます。

ShelfNetの学習率スケジューリング


 ShelfNetというセマンティックセグメンテーションモデルの論文で書かれている手法。

https://arxiv.org/abs/1811.11254

数式

 数式はこうなっています。powerを上げると減衰が早くなります。論文上ではpowerは0.9でした。

エポックごとの推移

 エポックごとの学習率の推移はこんな感じ。元が1e-2で300エポックまで算出。

 powerが大きいほど減衰率が上がる。0.5だと逆になる。論文通り0.9だとほぼ直線。

 1以上においては、速めに減少して最後の方は緩やかになる。一般的なlossのグラフと似たような感じだ。学習率もlossと同じように推移していくと収束しやすいのだろうか。

実装

 PyTorch/Pythonの実装はこんな感じです。

import torch

class LearningRateScheduler:
    def __init__(self, base_lr: float, max_epoch: int, power=0.9):
        self._max_epoch = max_epoch
        self._power = power
        self._base_lr = base_lr

    def __call__(self, epoch: int):
        return (1 - max(epoch - 1, 1) / self._max_epoch) ** self._power * self._base_lr

 ちなみに使い方はこんな感じ。

lr_scheduler_func = LearningRateScheduler(1e-2, 300)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(<Adam, SGDなどのoptimizer>,
                                  lr_lambda=lr_scheduler_func)

# 学習中
lr_scheduler.step(<エポック数>)

Cosine decayの学習率スケジューリング


 次にCosine decayルール。これはEfficientDetで用いられている。EfficientDetだと0.16まで上昇してそこから下がるようにするっぽい。

 調べた感じ論文とかなさそうで、これが近い感じだろうか。

https://medium.com/@scorrea92/cosine-learning-rate-decay-e8b50aa455b

数式

 数式はこんな感じ。warmup_epochまでは上昇し、そこからはCosine関数に従って学習率を落としていく感じです。

エポックごとの推移

 エポックごとの推移はこんな感じ。まあ普通にCosineですね。

 warmup期間がなければ青線のように最初からCosine関数に従う。

実装

 実装はこんな感じ。

import torch

class CosineDecayScheduler:
    def __init__(self, max_epochs: int, warmup_lr_limit=0.16, warmup_epochs=5):
        self._max_epochs = max_epochs
        self._warmup_lr_limit = warmup_lr_limit
        self._warmup_epochs = warmup_epochs

    def __call__(self, epoch: int):
        epoch = max(epoch, 1)
        if epoch <= self._warmup_epochs:
            return self._warmup_lr_limit * epoch / self._warmup_epochs
        epoch -= 1
        rad = math.pi * epoch / self._max_epochs
        weight = (math.cos(rad) + 1.) / 2
        return self._warmup_lr_limit * weight


lr_scheduler_func = CosineDecayScheduler(300)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(<Adam, SGDなどのoptimizer>,
                                  lr_lambda=lr_scheduler_func)

# 学習中
lr_scheduler.step(<エポック数>)

Learning rate schedulerの検証


 検証終わったら書きます。

 精度向上に寄与するFocalLossという損失関数についても記事上げているのでこちらも見てみてください!

コメントを残す

メールアドレスが公開されることはありません。