[DeepLearning][PyTorch] 精度向上につながる手法 – FocalLossの実装

 今回は有名なモデルでもよく使われるFocalLossという損失関数の実装について書いていきます。

FocalLossとは


 ざっくり言うと、簡単なサンプルは損失値を小さくするようにする損失関数。

 EfficientDetでも適用されている精度向上に繋がる手法です。

 仕組みの詳細は以下のページにあります。

https://qiita.com/agatan/items/53fe8d21f2147b0ac982

Semantic segmentationのFocalLoss実装


 数式そのままPyTorchで実現した感じです。ついでにloss値をクラスごとに重みづけできるようにしてあります。

 セグメンテーションは(batch_size, class, width, height)なので、pred/teacherは4次元である必要があります。

 weightsはクラスごとの重みです。

 gammaは簡単なサンプルをどの程度重視するかです。gammaが大きいほど簡単なサンプルを重視しないようにします。

from torch import nn

class SegmentationFocalLoss(nn.Module):
    def __init__(self, gamma=2, weights: List[float] = None, logits=True):
        """
        :param gamma: 簡単なサンプルの重み. 大きいほど簡単なサンプルを重視しない.
        :param weights: weights by classes,
        :param logits:
        """
        super().__init__()
        self.gamma = gamma
        self.class_weight_tensor = try_cuda(torch.tensor(weights).view(-1, 1, 1)) if weights else None
        self.logits = logits
        if not logits and weights is not None:
            RuntimeWarning("重みを適用するにはlogitsをTrueにしてください.")

    def forward(self, pred: torch.Tensor, teacher: torch.Tensor) -> float:
        """
        :param pred: batch_size, n_classes, height, width
        :param teacher: batch_size, n_classes, height, width
        :return:
        """
        if self.logits:
            ce_loss = F.binary_cross_entropy_with_logits(pred, teacher, reduce=False)
            pt = torch.exp(-ce_loss)
            if self.class_weight_tensor:
                class_weight_tensor = self.class_weight_tensor.expand(pred.shape[0],
                                                                      self.class_weight_tensor.shape[0],
                                                                      self.class_weight_tensor.shape[1],
                                                                      self.class_weight_tensor.shape[2])
                focal_loss = (1. - pt) ** self.gamma * (ce_loss * class_weight_tensor)
            else:
                focal_loss = (1. - pt) ** self.gamma * ce_loss
        else:
            ce_loss = F.cross_entropy(pred, teacher.argmax(1), reduce=False)
            pt = torch.exp(-ce_loss)
            focal_loss = (1. - pt) ** self.gamma * ce_loss
        return torch.mean(focal_loss)

分類(Classification)タスクのFocalLoss


 分類タスクに適用する場合の実装も載せておきます。

 predは(batch_size, n_classes)、teacherは(batch_size, )もしくは(batch_size, n_classes)のshapeである必要があります。

from torch import nn

class ClassificationFocalLoss(nn.Module):
    def __init__(self, n_classes: int, gamma=2, weights: List[float] = None, logits=True):
        """
        :param n_classes: 
        :param gamma: 簡単なサンプルの重み. 大きいほど簡単なサンプルを重視しない.
        :param weights: weights by classes,
        :param logits:
        """
        super().__init__()
        self._n_classes = n_classes
        self.gamma = gamma
        self.class_weight_tensor = try_cuda(torch.tensor(weights).view(-1, )) if weights else None
        self.logits = logits
        if not logits and weights is not None:
            RuntimeWarning("重みを適用するにはlogitsをTrueにしてください.")

    def forward(self, pred: torch.Tensor, teacher: torch.Tensor) -> float:
        """
        :param pred: batch_size, n_classes
        :param teacher: batch_size, 
        :return: 
        """
        if self.logits:
            if teacher.ndim == 1:  # 1次元ならonehotの2次元tensorにする
                teacher = torch.eye(self._n_classes)[teacher]
            ce_loss = F.binary_cross_entropy_with_logits(pred, teacher, reduce=False)
            pt = torch.exp(-ce_loss)

            if self.class_weight_tensor:
                class_weight_tensor = self.class_weight_tensor.expand(pred.shape[0],
                                                                      self.class_weight_tensor.shape[0], )
                focal_loss = (1. - pt) ** self.gamma * (ce_loss * class_weight_tensor)
            else:
                focal_loss = (1. - pt) ** self.gamma * ce_loss
        else:
            if teacher.ndim == 2:  # onehotの2次元tensorなら1次元にする
                teacher = teacher.argmax(1)
            ce_loss = F.cross_entropy(pred, teacher, reduce=False)
            pt = torch.exp(-ce_loss)
            focal_loss = (1. - pt) ** self.gamma * ce_loss
        return torch.mean(focal_loss)

 学習を早めるための学習率スケジューリング手法について記事書いているのでそちらも見てみてください!

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です