[Pytorch] セマンティックセグメンテーションにおけるDice係数の実装

 今回はDice係数を用いたLossの実装をします。

クラスに偏りがありすぎると学習がうまくいかないことも


 セマンティックセグメンテーションにおいて、Cross EntropyなどのLoss関数だと、ほとんど背景の画像などクラスに偏りがある場合に学習がうまくいかないことがあります。

 ほとんど背景のデータであれば、全て背景と推論すればLossが少なくなってしまうため学習がうまくいかないと思います。

Dice係数を用いたLossの利点・実装


 Dice係数を用いると、検出できた面積の割合でLossを算出するため、背景の割合が多いなどのクラスの偏りが大きいデータでも学習がうまくいくことがあります。

class DiceLoss(torch.nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred: torch.Tensor, teacher: torch.Tensor, smooth=1.0):
        """
        :param pred:
        :param teacher:
        :param smooth:
        :return:
        """
        pred = F.normalize(pred - pred.min(), 1)
        pred, teacher = teacher.float(), pred.float()

        # batch size, classes, width and height
        intersection = (pred * teacher).sum((-1, -2))
        pred = pred.contiguous().view(pred.shape[0], pred.shape[1], -1)
        teacher = teacher.contiguous().view(teacher.shape[0], teacher.shape[1], -1)
        pred_sum = pred.sum((-1,))
        teacher_sum = teacher.sum((-1,))
        dice_by_classes = (2. * intersection + smooth) / (pred_sum + teacher_sum + smooth)
        return (1. - dice_by_classes).mean((-1,)).mean((-1,))

 ただし、最後にクラスごとに平均をとっているため、クラス数が多いと正常に学習が進まないことがあります。

 出現しているクラスのみで平均をとる方法については後日書きます。

コメントを残す

メールアドレスが公開されることはありません。