[Pytorch] torchvision標準のResNetをbackboneとして使う

 今回は、torchvisionに入っているResNetをバックボーンとして使用する方法について。

torchvisionには有名なモデルがいくつかある


 pytorch/torchvisionには標準でResNetやResNeXt、MobileNetなどがあります。

https://pytorch.org/docs/stable/torchvision/models.html

 学習済みの重みも容易にロードできるためfine-tuningしやすいです。

 標準のモデルは分類用のモデルのため、そのままだとバックボーンとしては使えません。

 中間レイヤーをうまく使えばBackboneとして特徴量を抽出できます。

ResNetバックボーン実装


 torchvisionのresnetの中身を見てみると、conv1あたりの前処理の後に、特徴量抽出部分がありlayer1~4が該当します。ここの中間出力をバックボーンの出力として使います。

 出力は4つのtensorで、それぞれ画像の大きさがwidth, heightが1/4, 1/8, 1/16, 1/32になっています。

 また、resnet_typeによってスケール可能にしてあります。

 出力チャンネル数はモデルのスケールによって違っており、以下の通りです。

  • resnet18, 34は[64, 128, 256, 512]
  • 50以上は[256, 512, 1024, 2048]
import torchvision
from torch import nn

class ResNetBackBone(nn.Module):
    def __init__(self, resnet_type="resnet18", pretrained=True):
        super().__init__()
        if resnet_type == "resnet18":
            self.resnet_model = torchvision.models.resnet18(pretrained=pretrained)
        elif resnet_type == "resnet34":
            self.resnet_model = torchvision.models.resnet34(pretrained=pretrained)
        elif resnet_type == "resnet50":
            self.resnet_model = torchvision.models.resnet50(pretrained=pretrained)
        elif resnet_type == "resnet101":
            self.resnet_model = torchvision.models.resnet101(pretrained=pretrained)
        elif resnet_type == "resnet152":
            self.resnet_model = torchvision.models.resnet152(pretrained=pretrained)

    def forward(self, x):
        x = self.resnet_model.conv1(x)
        x = self.resnet_model.bn1(x)
        x = self.resnet_model.relu(x)
        x = self.resnet_model.maxpool(x)
        out1 = self.resnet_model.layer1(x)  # width, heightは1/4
        out2 = self.resnet_model.layer2(out1)  # width, heightは1/8
        out3 = self.resnet_model.layer3(out2)  # width, heightは1/16
        out4 = self.resnet_model.layer4(out3)  # width, heightは1/32
        return out1, out2, out3, out4

他にもResNeXtやVGGもバックボーンとして使える


 上記のようにモデルの中身を見て中間レイヤーの出力だけ取得するようにすれば、ほかのモデルもバックボーンとして使えます。

 ResNeXtに関しては、ResNetとインスタンス変数が同じのため、forwardを組み替えずにこれだけでいけます。

self.resnet_model = torchvision.models.resnext50_32x4d(pretrained=pretrained)

コメントを残す

メールアドレスが公開されることはありません。