[PyTorch] 過学習を防ぐ手法 – Label smoothingの実装

 今回はLabel smoothingをPyTorchで実装する方法について。

Label smoothingとは


 Onehot表現の教師データにノイズを加えて過学習防止、性能向上をはかる手法です。

 αがパラメータであり、大きいほどノイズが大きくなり元の教師データと遠くなります。大きすぎると当然逆効果です。

 教師データがonehotのような極端な値ではなくなるため、過学習防止につながります。

 以下の資料、元の論文いわく多少性能向上 or 変化なしらしいです。

実装


 基本は以下のように教師データを加工するだけです。

n_classes = 10
alpha = 0.3
teacher = Onehotのtensor

noise_val = alpha / n_classes
teacher = teacher * (1 - alpha) + noise_val

 問題はこれをどこに実装するか。

 Augmentationとして実装すると、Onehotじゃない教師データを使う処理ができなくなってしまいます。

 なので自前のLoss関数を作ってそこで教師データを加工するようにします。

 今回はFocalLossの中に実装するようにします。分類タスクのFocalLossにLabel smoothingを導入してみます。

class ClassificationFocalLossWithLabelSmoothing(nn.Module):
    def __init__(self, n_classes: int, alpha=0.3, gamma=2, weights: List[float] = None):
        """
        :param alpha: parameter of Label Smoothing.
        :param n_classes:
        :param gamma: 簡単なサンプルの重み. 大きいほど簡単なサンプルを重視しない.
        :param weights: weights by classes,
        :param logits:
        """
        super().__init__()
        self._alpha = alpha
        self._noise_val = alpha / n_classes
        self._n_classes = n_classes
        self.gamma = gamma
        self.class_weight_tensor = try_cuda(torch.tensor(weights).view(-1, )) if weights else None

    def forward(self, pred: torch.Tensor, teacher: torch.Tensor) -> float:
        """
        :param pred: batch_size, n_classes
        :param teacher: batch_size,
        :return:
        """
        if teacher.ndim == 1:  # 1次元ならonehotの2次元tensorにする
            teacher = torch.eye(self._n_classes)[teacher]
        # Label smoothing.
        teacher = teacher * (1 - self._alpha) + self._noise_val

        ce_loss = F.binary_cross_entropy_with_logits(pred, teacher, reduce=False)
        pt = torch.exp(-ce_loss)

        if self.class_weight_tensor is not None:
            class_weight_tensor = self.class_weight_tensor.expand(pred.shape[0],
                                                                  self.class_weight_tensor.shape[0], )
            focal_loss = (1. - pt) ** self.gamma * (ce_loss * class_weight_tensor)
        else:
            focal_loss = (1. - pt) ** self.gamma * ce_loss

        return torch.mean(focal_loss)


loss_func = ClassificationFocalLossWithLabelSmoothing(3)
pred, teacher = (batch size, classes)形式のtensor
loss = loss_func(pred, teacher)

 FocalLossについての記事もありますのでよかったら見てみてください!

コメントを残す

メールアドレスが公開されることはありません。