1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
| import torch import torch.nn as nn
class SimpleTransformer(nn.Module): def __init__(self, input_dim=512, num_layers=6, nhead=8): super().__init__() encoder_layer = nn.TransformerEncoderLayer( d_model=input_dim, nhead=nhead, dim_feedforward=2048, dropout=0.1, activation="relu", batch_first=True, )
self.transformer_encoder = nn.TransformerEncoder( encoder_layer, num_layers=num_layers )
def forward(self, x): x = self.input_proj(x) output = self.transformer_encoder(x) return output
model = SimpleTransformer(input_dim=512, num_layers=2, nhead=8) model.eval()
dummy_input = torch.randn(2, 10, 512)
torch.onnx.export( model, (dummy_input,), "transformer_encoder.onnx", do_constant_folding=True, input_names=["input"], output_names=["output"], dynamo=True, )
print("ONNX model exported...TransformerEncoder导出onnx问题解决1. 问题说明在使用Pytorch的TransformerEncoder时,导出onnx会将时序长度固定,导致没法采用变长输入,例如下面的简单例子复现了这个问题: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
| import torch import torch.nn as nn
class SimpleTransformer(nn.Module): def __init__(self, input_dim=512, num_layers=6, nhead=8): super().__init__() encoder_layer = nn.TransformerEncoderLayer( d_model=input_dim, nhead=nhead, dim_feedforward=2048, dropout=0.1, activation="relu", batch_first=True, )
self.transformer_encoder = nn.TransformerEncoder( encoder_layer, num_layers=num_layers )
def forward(self, x): x = self.input_proj(x) output = self.transformer_encoder(x) return output
model = SimpleTransformer(input_dim=512, num_layers=2, nhead=8) model.eval()
dummy_input = torch.randn(2, 10, 512)
torch.onnx.export( model, (dummy_input,), "transformer_encoder.onnx", do_constant_folding=True, input_names=["input"], output_names=["output"], dynamo=True, )
print("ONNX model exported... |
|