TorchScript対応するには通常のPyTorchで記載したforwardなどが動かない場合があります。
中間コードにするため、主にtype hint・型関連に注意する必要があります。
PyTorch-Lightningを使用するとModelCheckpointなど使うときにtorchscriptを使用するため、修正が必要な場合があります。
目次
torch.nn.ModuleListにindexが使えない(Can’t index nn.ModuleList in script function)
TorchScriptでは、以下のようにModuleListからindexなどでLayerを取得することができません。
import torch
module_list = torch.nn.ModuleList(nn.Moduleのリスト...)
layer1 = module_list[1] # ここでエラー
indexで取得できないため、iteratorとして使用する必要があります。
module_list[i], module_list2[i]みたいにindex使って複数のModuleListを扱っている場合はzipを使うと良いです。
import torch
module_list = torch.nn.ModuleList(nn.Moduleのリスト...)
for layer in module_list:
# 処理
# もしくはenumerate
for i, layer in enumerate(module_list):
# 処理
module_list2 = torch.nn.ModuleList(nn.Moduleのリスト...)
# 2つ以上ならzipを使う
for i, (layer1, layer2) in zip(module_list, module_list2):
# 処理
type hintがないとtorch.Tensor型と判別される
torchscriptではしっかりtype hintを記載することを求められます。
type hintの書かれていない引数はtorch.Tensorと判断されてしまうため、それ以外の型だとRuntimeErrorが起きます。
https://pytorch.org/docs/stable/jit_language_reference.html#default-types
以下のように、明示的にtype hintを記載しましょう。
import torch
...
# これだとRuntimeError
def forward(self, x, aux):
if not aux:
return self.model(x)[0]
out1, out2 = self.model(x)
...
# こうする
def forward(self, x: torch.Tensor, aux: bool):
if not aux:
return self.model(x)[0]
out1, out2 = self.model(x)
ListとTupleを明示的に使い分ける
これは元の実装が良くないパターンが多いかと思いますが、torchscriptではTupleが入るべきところにListが入るとエラーが起こります。Tupleはstaticであることを求められます。
これもtype hintしっかり記載することと通じるところがありますね。
import torch
class TestModel(torch.nn.Module):
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
...
model = TestModel()
tensor_list = [torch.Tensorのリスト]
model(tensor_list) # エラーになる
# こうする
class TestModel_(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
...
# もしくはこうする
model(tensor_list[0], tensor_list[1])