[PyTorch-Lightning] TorchScript対応で注意・修正が必要な点

 TorchScript対応するには通常のPyTorchで記載したforwardなどが動かない場合があります。

 中間コードにするため、主にtype hint・型関連に注意する必要があります。

 PyTorch-Lightningを使用するとModelCheckpointなど使うときにtorchscriptを使用するため、修正が必要な場合があります。

torch.nn.ModuleListにindexが使えない(Can’t index nn.ModuleList in script function)


 TorchScriptでは、以下のようにModuleListからindexなどでLayerを取得することができません。

import torch

module_list = torch.nn.ModuleList(nn.Moduleのリスト...)
layer1 = module_list[1]  # ここでエラー

 indexで取得できないため、iteratorとして使用する必要があります。

 module_list[i], module_list2[i]みたいにindex使って複数のModuleListを扱っている場合はzipを使うと良いです。

import torch

module_list = torch.nn.ModuleList(nn.Moduleのリスト...)

for layer in module_list:
  # 処理

# もしくはenumerate
for i, layer in enumerate(module_list):
  # 処理

module_list2 = torch.nn.ModuleList(nn.Moduleのリスト...)

# 2つ以上ならzipを使う
for i, (layer1, layer2) in zip(module_list, module_list2):
  # 処理

 

type hintがないとtorch.Tensor型と判別される


 torchscriptではしっかりtype hintを記載することを求められます。

 type hintの書かれていない引数はtorch.Tensorと判断されてしまうため、それ以外の型だとRuntimeErrorが起きます。

https://pytorch.org/docs/stable/jit_language_reference.html#default-types

 以下のように、明示的にtype hintを記載しましょう。

import torch
...
# これだとRuntimeError
def forward(self, x, aux):
  if not aux:
    return self.model(x)[0]
  out1, out2 = self.model(x)

...
# こうする
def forward(self, x: torch.Tensor, aux: bool):
  if not aux:
    return self.model(x)[0]
  out1, out2 = self.model(x)

ListとTupleを明示的に使い分ける


 これは元の実装が良くないパターンが多いかと思いますが、torchscriptではTupleが入るべきところにListが入るとエラーが起こります。Tupleはstaticであることを求められます。

 これもtype hintしっかり記載することと通じるところがありますね。

import torch

class TestModel(torch.nn.Module):
  def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
    ...

model = TestModel()
tensor_list = [torch.Tensorのリスト]
model(tensor_list)  # エラーになる


# こうする
class TestModel_(torch.nn.Module):
  def forward(self, x: List[torch.Tensor]):
    ...

# もしくはこうする
model(tensor_list[0], tensor_list[1])

コメントを残す

メールアドレスが公開されることはありません。