PyTorch-LightningのMetricsとして使える形式のmean IoUを実装したので上げておきます。
まずはPyTorchのインストール。こちらのサイトから。
次にPyTorch-Lightningのインストール。
pip install pytorch_lightning
次にMetrics実装。
本当はconfusion_matrixなど使った方がいいかもですが、segmentationはピクセルが多いのでconfusion_matrix算出にかなり時間がかかります。
なので、クラスごとに面積(union)、かぶっている部分(overlap)を保持しておいて、最後に除算をします。
import torch
import pytorch_lightning as pl
class SegmentationIoU(pl.metrics.Metric):
def __init__(self, n_classes: int, by_classes: bool = False, without_background_class: bool = False):
super().__init__()
self._n_classes = n_classes
self._without_background_class = without_background_class
self._by_classes = by_classes
self.add_state("overlap_by_classes", default=torch.tensor([0 for _ in range(n_classes)]), dist_reduce_fx="sum")
self.add_state("union_by_classes", default=torch.tensor([0 for _ in range(n_classes)]), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, targets: torch.Tensor):
preds, targets = preds.reshape([preds.shape[0], -1]), targets.reshape([targets.shape[0], -1])
for label_val in range(self._n_classes):
targets_indices, preds_indices = (targets == label_val), (preds == label_val)
overlap_indices = targets_indices & preds_indices
overlap = torch.count_nonzero(overlap_indices)
union = torch.count_nonzero(targets_indices) + torch.count_nonzero(preds_indices) - overlap
self.overlap_by_classes[label_val] += overlap
self.union_by_classes[label_val] += union
def compute(self):
iou_by_classes = self.overlap_by_classes / self.union_by_classes
if self._without_background_class:
iou_by_classes = iou_by_classes[1:]
if self._by_classes:
return iou_by_classes
return torch.mean(iou_by_classes)
without_background_classをtrueにすると背景クラス(0)を無視した値を返します。
by_classesにすると、出力を1つの値のfloatTensorではなく、クラスごとのfloatTensor(shape=[n_classes, ])になります。
テストコード下に置いておきます。
import torch
from mean_IoU import SegmentationIoU
n_classes = 3
a = torch.tensor([[[1, 0, 1, 0], [2, 2, 0, 0], [0, 0, 1, 2]]])
b = torch.tensor([[[0, 0, 1, 1], [2, 2, 0, 0], [0, 0, 1, 1]]])
metric = SegmentationIoU(n_classes=n_classes)
metric(a, b)
metric_value = metric.compute()
print(metric_value.item()) # Float value
metric = SegmentationIoU(n_classes=n_classes, by_classes=True)
metric(a, b)
metric_value = metric.compute()
print(metric_value) # Float tensor value (shape is n_claases)