Importing a PyTorch neural network with Tanh activation via ONNX does not work #109

Closed
caelorza opened this issue May 29, 2023 · 4 comments

@caelorza

Using Tanh as the activation function raises the following error when the exported ONNX model is loaded with OMLT. Other activation functions, such as Sigmoid or ReLU, work fine.

Exception: Unhandled node type Tanh

ONNX version: 1.13.1
OMLT version: 1.1
PyTorch version: 2.0.1

```python
import tempfile
from collections import OrderedDict

import torch
from torch import nn

from omlt.io import write_onnx_model_with_bounds, load_onnx_neural_network_with_bounds
from omlt import OmltBlock, OffsetScaling
from omlt.neuralnet import FullSpaceNNFormulation, NetworkDefinition

if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Running on GPU")
else:
    device = torch.device('cpu')
    print("Running on CPU")


class DNN(nn.Module):
    """Fully connected net: n_layers hidden Linear layers of width n_neu, each followed by `activ`."""

    def __init__(self, n_in, n_out, n_neu, n_layers, activ="Tanh"):
        super().__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.n_neu = n_neu
        self.n_layers = n_layers
        # Look up the activation class by name, e.g. nn.Tanh, nn.Sigmoid, nn.ReLU
        self.activ = getattr(nn, activ)

        layer_list = []

        # Input layer followed by an activation
        layer_list.append(('layer_%d' % 0, nn.Linear(self.n_in, self.n_neu)))
        layer_list.append(('activation_%d' % 0, self.activ()))

        # Remaining hidden layers, each followed by an activation
        for i in range(1, self.n_layers):
            layer_list.append(('layer_%d' % i, nn.Linear(self.n_neu, self.n_neu)))
            layer_list.append(('activation_%d' % i, self.activ()))

        # Linear output layer (no activation)
        layer_list.append(('layer_%d' % n_layers, nn.Linear(self.n_neu, self.n_out)))

        self.dnn = nn.Sequential(OrderedDict(layer_list))

    def forward(self, x):
        return self.dnn(x)


if __name__ == "__main__":

    n_in = 2
    n_out = 2
    n_neu = 5
    n_layers = 3

    model = DNN(n_in, n_out, n_neu, n_layers, activ="Tanh").to(device)

    # Export the model to a temporary ONNX file with a dynamic batch dimension
    x = torch.randn(10, n_in, requires_grad=True).to(device)
    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
        torch.onnx.export(
            model,
            x,
            f,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

    # Store the input bounds alongside the exported model (the model itself was written above)
    input_bounds = [(-1, 1) for _ in range(n_in)]
    write_onnx_model_with_bounds(f.name, None, input_bounds)
    print(f"Wrote PyTorch model to {f.name}")
    pytorch_model = f.name

    # With OMLT 1.1 this raises: Exception: Unhandled node type Tanh
    network_definition = load_onnx_neural_network_with_bounds(pytorch_model)
```
@caelorza (author) commented May 30, 2023

The same error also occurs when using Softplus as the activation function.

@zzygith commented Sep 24, 2023

I ran into the same problem. I can't find Tanh in `_ACTIVATION_OP_TYPES` in onnx_parser.py. Any ideas yet?
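To illustrate why a missing entry in an activation list surfaces as "Unhandled node type Tanh", here is a simplified, hypothetical sketch of a parser that dispatches on each node's op type. It is not OMLT's actual onnx_parser.py: the `Node` class, `SUPPORTED_ACTIVATIONS`, `LINEAR_OP_TYPES` and `classify_nodes` are illustrative names chosen here, and the whitelist contents are assumptions based only on the report that Sigmoid and ReLU work.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for an ONNX graph node."""
    op_type: str  # e.g. "Gemm", "Relu", "Tanh"
    name: str


# Hypothetical whitelists, standing in for something like _ACTIVATION_OP_TYPES
SUPPORTED_ACTIVATIONS = frozenset({"Relu", "Sigmoid"})
LINEAR_OP_TYPES = frozenset({"Gemm", "MatMul"})


def classify_nodes(nodes, activations=SUPPORTED_ACTIVATIONS):
    """Label each node 'linear' or 'activation'; any other op type is rejected."""
    kinds = []
    for node in nodes:
        if node.op_type in LINEAR_OP_TYPES:
            kinds.append("linear")
        elif node.op_type in activations:
            kinds.append("activation")
        else:
            # Mirrors the generic Exception reported in this issue
            raise Exception(f"Unhandled node type {node.op_type}")
    return kinds


if __name__ == "__main__":
    graph = [Node("Gemm", "layer_0"), Node("Tanh", "activation_0"), Node("Gemm", "layer_1")]
    try:
        classify_nodes(graph)
    except Exception as err:
        print(err)  # Unhandled node type Tanh
    # Once Tanh (and Softplus) are recognized, the same graph is accepted
    print(classify_nodes(graph, SUPPORTED_ACTIVATIONS | {"Tanh", "Softplus"}))
```

In this sketch, supporting a new activation means both adding its op type to the whitelist and, in a real MILP/NLP formulation library, providing a formulation for it; the whitelist alone only gets the node past the parser.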

@rmisener (member)

@juan-campos fixed this bug in PR #121, which we've now merged. Would you take a look and see if this solves your problem? @juan-campos also added some tests to try to make sure we don't reintroduce this bug!

@zzygith commented Sep 26, 2023

The problem has been solved. Thanks for the wonderful work.
