[PyTorch-Lightning] Image Segmentationのmean IoUのMetrics実装

 PyTorch-LightningのMetricsとして使える形式のmean IoUを実装したので上げておきます。

 まずはPyTorchのインストール。こちらのサイトから。

https://pytorch.org/

 次にPyTorch-Lightningのインストール。

pip install pytorch_lightning

 次にMetrics実装。

 本当はconfusion_matrixなど使った方がいいかもですが、segmentationはピクセルが多いのでconfusion_matrix算出にかなり時間がかかります。

 なので、クラスごとに面積(union)、かぶっている部分(overlap)を保持しておいて、最後に除算をします。

import torch
import pytorch_lightning as pl


class SegmentationIoU(pl.metrics.Metric):
    def __init__(self, n_classes: int, by_classes: bool = False, without_background_class: bool = False):
        super().__init__()
        self._n_classes = n_classes
        self._without_background_class = without_background_class
        self._by_classes = by_classes
        self.add_state("overlap_by_classes", default=torch.tensor([0 for _ in range(n_classes)]), dist_reduce_fx="sum")
        self.add_state("union_by_classes", default=torch.tensor([0 for _ in range(n_classes)]), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        preds, targets = preds.reshape([preds.shape[0], -1]), targets.reshape([targets.shape[0], -1])
        for label_val in range(self._n_classes):
            targets_indices, preds_indices = (targets == label_val), (preds == label_val)
            overlap_indices = targets_indices & preds_indices
            overlap = torch.count_nonzero(overlap_indices)
            union = torch.count_nonzero(targets_indices) + torch.count_nonzero(preds_indices) - overlap
            self.overlap_by_classes[label_val] += overlap
            self.union_by_classes[label_val] += union

    def compute(self):
        iou_by_classes = self.overlap_by_classes / self.union_by_classes
        if self._without_background_class:
            iou_by_classes = iou_by_classes[1:]
        if self._by_classes:
            return iou_by_classes
        return torch.mean(iou_by_classes)

 without_background_classをtrueにすると背景クラス(0)を無視した値を返します。

 by_classesにすると、出力を1つの値のfloatTensorではなく、クラスごとのfloatTensor(shape=[n_classes, ])になります。

 テストコード下に置いておきます。

import torch

from mean_IoU import SegmentationIoU

n_classes = 3
a = torch.tensor([[[1, 0, 1, 0], [2, 2, 0, 0], [0, 0, 1, 2]]])
b = torch.tensor([[[0, 0, 1, 1], [2, 2, 0, 0], [0, 0, 1, 1]]])

metric = SegmentationIoU(n_classes=n_classes)
metric(a, b)
metric_value = metric.compute()

print(metric_value.item())  # Float value


metric = SegmentationIoU(n_classes=n_classes, by_classes=True)
metric(a, b)
metric_value = metric.compute()

print(metric_value)  # Float tensor value (shape is n_claases)

コメントを残す

メールアドレスが公開されることはありません。