
[Public issue] AssertionError assert repr_input_shape in self.clients in hybrid model with multiple submodules #844

Open
gy-cao opened this issue Aug 22, 2024 · 7 comments


@gy-cao

gy-cao commented Aug 22, 2024

Hi!

I am trying to use a hybrid model in which multiple submodules are evaluated in FHE.

I tried it with two submodules, and the circuits compile fine.

However, when I run inference, the first submodule appears to be evaluated correctly, but evaluating the second one fails on the client side at `assert repr_input_shape in self.clients` with an `AssertionError`.

Roughly, this is what I do on the client side:

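```python
from pathlib import Path

import torch
from concrete.ml.torch.hybrid_model import HybridFHEMode, HybridFHEModel

# model.pth is the file saved by save_and_clear_private_info
# (the UNet class below must be importable for torch.load to unpickle it)
model = torch.load("./compiled_models/net/model.pth")

# Submodules offloaded to FHE; model_name (defined elsewhere) must match the server side
module_names = ["up1.pixel_shuffle", "final"]

hybrid_model = HybridFHEModel(
    model,
    module_names,
    server_remote_address="http://0.0.0.0:8000",
    model_name=model_name,
    verbose=False,
)
path_to_clients = Path(__file__).parent / "clients"
hybrid_model.init_client(path_to_clients=path_to_clients)
hybrid_model.set_fhe_mode(HybridFHEMode.REMOTE)
```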
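The variable name `repr_input_shape` suggests that the client keeps one FHE client per remote submodule and per input shape, and looks it up using a string form of the incoming tensor's shape. The sketch below is a hypothetical, simplified model of such a lookup (the class `ShapeKeyedClients` and its methods are illustrative names, not concrete-ml's API or implementation). It shows how a submodule called with a shape it has no client for trips exactly this kind of assertion.

```python
from typing import Dict, Tuple


class ShapeKeyedClients:
    """Hypothetical stand-in for one remote submodule's client registry.

    Illustrative only: assumes one FHE client per input shape, keyed by the
    string form of the shape. This is NOT concrete-ml's actual code.
    """

    def __init__(self) -> None:
        self.clients: Dict[str, str] = {}

    def register(self, input_shape: Tuple[int, ...], client: str) -> None:
        # Assumption: a client exists only for shapes seen when circuits were compiled/saved
        self.clients[str(tuple(input_shape))] = client

    def lookup(self, input_shape: Tuple[int, ...]) -> str:
        repr_input_shape = str(tuple(input_shape))
        # Same kind of check as the failing line: the exact shape must be known
        assert repr_input_shape in self.clients, f"no client for {repr_input_shape}"
        return self.clients[repr_input_shape]


final_clients = ShapeKeyedClients()
final_clients.register((1, 32, 128, 128), "client for final @ 128x128")
print(final_clients.lookup((1, 32, 128, 128)))  # known shape: lookup succeeds
try:
    final_clients.lookup((1, 32, 64, 64))  # shape that was never registered
except AssertionError as err:
    print("AssertionError:", err)
```

Under this assumption, a reasonable first check is to compare the shapes each offloaded submodule actually receives at inference time with the shapes used when compiling and saving the clients.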

I am using the hybrid model on a UNet. The submodules I offload to FHE are `up1.pixel_shuffle` and `final`. My UNet is defined as follows:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Custom pixel shuffle built from view/permute
# (same rearrangement as nn.PixelShuffle: (B, C*r^2, H, W) -> (B, C, H*r, W*r))
class CustomPixelShuffle(nn.Module):
    def __init__(self, upscale_factor):
        super(CustomPixelShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        upscale_factor = self.upscale_factor
        channels //= (upscale_factor ** 2)

        x = x.view(batch_size, channels, upscale_factor, upscale_factor, height, width)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(batch_size, channels, height * upscale_factor, width * upscale_factor)
        return x

# Upsampling block: pixel shuffle followed by a 3x3 convolution
class CustomUpsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor):
        super(CustomUpsample, self).__init__()
        self.scale_factor = scale_factor
        self.pixel_shuffle = CustomPixelShuffle(scale_factor)
        self.conv = nn.Conv2d(
            in_channels // (scale_factor ** 2), out_channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        x = self.pixel_shuffle(x)
        x = self.conv(x)
        return x

# Define a standard convolutional layer with batch normalization and ReLU activation
class Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Define a UNet architecture with standard components
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Encoder path
        self.enc1 = nn.Sequential(
            Conv2d(in_channels=1, out_channels=32),
            Conv2d(in_channels=32, out_channels=32)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2 = nn.Sequential(
            Conv2d(in_channels=32, out_channels=64),
            Conv2d(in_channels=64, out_channels=64)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3 = nn.Sequential(
            Conv2d(in_channels=64, out_channels=128),
            Conv2d(in_channels=128, out_channels=128)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4 = nn.Sequential(
            Conv2d(in_channels=128, out_channels=256),
            Conv2d(in_channels=256, out_channels=256)
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5 = nn.Sequential(
            Conv2d(in_channels=256, out_channels=512),
            Conv2d(in_channels=512, out_channels=512)
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck (central) layer
        self.bottleneck = nn.Sequential(
            Conv2d(in_channels=512, out_channels=1024),
            Conv2d(in_channels=1024, out_channels=1024)
        )

        # Upsampling path with CustomUpsample and decoder layers
        self.up5 = CustomUpsample(in_channels=1024, out_channels=512, scale_factor=2)
        self.dec5 = nn.Sequential(
            Conv2d(in_channels=1024, out_channels=512),
            Conv2d(in_channels=512, out_channels=512)
        )

        self.up4 = CustomUpsample(in_channels=512, out_channels=256, scale_factor=2)
        self.dec4 = nn.Sequential(
            Conv2d(in_channels=512, out_channels=256),
            Conv2d(in_channels=256, out_channels=256)
        )

        self.up3 = CustomUpsample(in_channels=256, out_channels=128, scale_factor=2)
        self.dec3 = nn.Sequential(
            Conv2d(in_channels=256, out_channels=128),
            Conv2d(in_channels=128, out_channels=128)
        )

        self.up2 = CustomUpsample(in_channels=128, out_channels=64, scale_factor=2)
        self.dec2 = nn.Sequential(
            Conv2d(in_channels=128, out_channels=64),
            Conv2d(in_channels=64, out_channels=64)
        )

        self.up1 = CustomUpsample(in_channels=64, out_channels=32, scale_factor=2)
        self.dec1 = nn.Sequential(
            Conv2d(in_channels=64, out_channels=32),
            Conv2d(in_channels=32, out_channels=32)
        )

        # Final convolutional layer for output
        self.final = nn.Conv2d(
            in_channels=32, out_channels=1, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        # Forward pass through the network
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        enc5 = self.enc5(self.pool4(enc4))

        bottleneck = self.bottleneck(self.pool5(enc5))

        up5 = self.up5(bottleneck)
        dec5 = self.dec5(torch.cat((up5, enc5), dim=1))

        up4 = self.up4(dec5)
        dec4 = self.dec4(torch.cat((up4, enc4), dim=1))

        up3 = self.up3(dec4)
        dec3 = self.dec3(torch.cat((up3, enc3), dim=1))

        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat((up2, enc2), dim=1))

        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat((up1, enc1), dim=1))

        output = torch.sigmoid(self.final(dec1))
        return output

#model = UNet()
#print(model)
#for (k, _) in model.named_modules():
#    print(k)
```
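Because the two offloaded submodules sit at different depths of this UNet, they receive tensors of different shapes. The stdlib-only sketch below (no torch needed; the function `trace_submodule_inputs` and the constants are hypothetical names) follows `forward()` by hand to compute the input shape of every `up*.pixel_shuffle` and of `final` for a given input size. It assumes the channel counts from the definition above and 4D inputs of shape (batch, channels, height, width).

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, int, int, int]  # (batch, channels, height, width)

# Channel counts read off the UNet definition above
ENCODER_CHANNELS = [32, 64, 128, 256, 512]  # outputs of enc1..enc5
BOTTLENECK_CHANNELS = 1024
SCALE = 2  # scale_factor of every CustomUpsample


def trace_submodule_inputs(batch: int, height: int, width: int) -> Dict[str, Shape]:
    """Return the input shape each up*.pixel_shuffle and `final` would receive."""
    skips: List[Tuple[int, int, int]] = []  # (channels, h, w) of enc1..enc5
    h, w = height, width
    for channels in ENCODER_CHANNELS:
        skips.append((channels, h, w))
        h, w = h // 2, w // 2  # MaxPool2d(kernel_size=2) floors odd sizes

    shapes: Dict[str, Shape] = {}
    c = BOTTLENECK_CHANNELS
    for level, (skip_c, skip_h, skip_w) in zip([5, 4, 3, 2, 1], reversed(skips)):
        # up{level} starts with its pixel shuffle, so this is that submodule's input
        shapes[f"up{level}.pixel_shuffle"] = (batch, c, h, w)
        h, w = h * SCALE, w * SCALE  # pixel shuffle: (C*r^2, H, W) -> (C, rH, rW)
        if (h, w) != (skip_h, skip_w):
            # torch.cat with the skip connection would fail on mismatched sizes
            raise ValueError("height and width must be multiples of 32")
        # The upsample conv outputs skip_c channels, the concat doubles that,
        # and dec{level} brings it back down to skip_c.
        c = skip_c
    shapes["final"] = (batch, c, h, w)
    return shapes


for name, shape in trace_submodule_inputs(1, 128, 128).items():
    print(name, shape)
```

For a 128x128 input with batch size 1, this gives (1, 64, 64, 64) for `up1.pixel_shuffle` and (1, 32, 128, 128) for `final`. If clients are indeed tied to input shapes, each of the two submodules needs a client for its own shape, and changing the image size or batch size between compilation and inference changes both.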

For the server setup and the circuit compilation, I am essentially following this example: https://github.com/zama-ai/concrete-ml/tree/main/use_case_examples/hybrid_model. Could you please check why this happens? Thanks!

P.S. If you need my full code, just DM me (Gan) in the CML channel of the FHE Discord; I am pretty active there. Thanks!

Best,
Gan

@bcm-at-zama
Collaborator

Thanks for opening the issue.

For future reference: this is related to what was discussed in https://discord.com/channels/901152454077452399/1276101101267189783

@bcm-at-zama
Collaborator

@gy-cao: yes, having the full code always helps. If you can put it in the issue or in a new GitHub repository, that would help. Thanks!

@gy-cao
Author

gy-cao commented Aug 22, 2024


Hi @bcm-at-zama, thanks! Is it OK if I send you the code via Discord so that you can forward it to the team?

@bcm-at-zama
Collaborator

As long as it's Python code (no zip files), yes! Another option is to create a private GitHub repository and share it with me.

@bcm-at-zama bcm-at-zama changed the title AssertionError assert repr_input_shape in self.clients in hybrid model with multiple submodules [Public issue] AssertionError assert repr_input_shape in self.clients in hybrid model with multiple submodules Aug 22, 2024
@bcm-at-zama
Collaborator

I am going to open a private issue for this, where we will investigate. We'll let you know when we find out.

@gy-cao
Author

gy-cao commented Aug 22, 2024


Hi @bcm-at-zama, not sure if this is a typo, but the issue is marked as "public issue" in the title.

@bcm-at-zama
Collaborator

Yes, I marked #844 as "[Public issue]" to clearly distinguish it from the working issue we opened on our internal repository, which you don't have access to.

Labels: None yet
Projects: None yet
Development: No branches or pull requests
2 participants