
Issue with type mismatch in get_sparse_C and backward_general functions #80

Open

929937690 opened this issue Nov 10, 2024 · 0 comments
I encountered a type mismatch when using the alpha-beta-CROWN method. Specifically, when get_sparse_C is called with unstable_size > crown_batch_size, the variable newC is set to the string 'Patches' instead of an actual Patches object. This causes an AssertionError in the type check in backward_general, since the expected type is Patches but newC is a string.

The error message produced is: AssertionError: <class 'str'>

I believe this issue could also arise if newC is set to the string 'eye' rather than an eyeC object, leading to similar assertion errors.

Is this behavior intentional, or is there a recommended workaround to prevent these type mismatches? I'd appreciate any insight into handling these cases; so far, only the 'Patches' string assignment has triggered the error for me.
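The mismatch is easy to reproduce in isolation. The sketch below uses hypothetical stand-in classes (Patches, eyeC) rather than the real auto_LiRPA types, and a simplified assert, but it fails in the same way, with AssertionError: <class 'str'>:

```python
class Patches:  # stand-in for auto_LiRPA's Patches class (not the real one)
    pass

class eyeC:  # stand-in for auto_LiRPA's eyeC class (not the real one)
    pass

def infeasible_check(C):
    """Mirrors the type check in backward_general, simplified:
    the real code also accepts torch.Tensor and OneHotC."""
    if isinstance(C, Patches):
        return 'patches branch'
    # A sentinel string like 'Patches' falls through to here and fails.
    assert isinstance(C, eyeC), type(C)
    return 'dense branch'

print(infeasible_check(Patches()))   # -> patches branch
try:
    infeasible_check('Patches')      # the sentinel string from get_sparse_C
except AssertionError as err:
    print('AssertionError:', err)    # -> AssertionError: <class 'str'>
```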

# bound_general.py
def compute_intermediate_bounds():
    ...
    sparse_C = self.get_sparse_C(node, ref_intermediate)
    ...
    ... = self.backward_general()
    ...
    

# backward_bound.py
def get_sparse_C():
    ...
    if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int(
            os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0:
        ...
            if not reduced_dim:
                if dim > crown_batch_size:
                    newC = 'eye'
                else:
                    newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device)
    elif node.patches_start and node.mode == "patches":
        if sparse_intermediate_bounds:
            ...
            elif unstable_size > crown_batch_size:
                newC = 'Patches'
                reduced_dim = True
    ...
    else:
        ...
        
        if not reduced_dim:
            ...
            if dim > crown_batch_size:
                newC = 'eye'
            else:
                newC = torch.eye(dim, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1
                ).view(batch_size, dim, *node.output_shape[1:])
    ...
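For reference, the dense fallback in the final branch builds one identity specification matrix per batch element. A plain-Python sketch of the same construction (a hypothetical helper name, using nested lists instead of torch.eye/expand, and ignoring the final view to node.output_shape) is:

```python
def make_dense_eye_C(batch_size, dim):
    """Plain-Python equivalent of the torch.eye(dim).unsqueeze(0).expand(...)
    construction in get_sparse_C: one (dim x dim) identity matrix per batch
    element, i.e. a structure of shape (batch_size, dim, dim)."""
    eye = [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]
    return [eye for _ in range(batch_size)]

C = make_dense_eye_C(2, 3)
print(len(C), len(C[0]), len(C[0][0]))  # -> 2 3 3
```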

def backward_general():
    ...
    if self.infeasible_bounds is None:
        if isinstance(C, Patches):
            self.infeasible_bounds = torch.full((C.shape[1],), False, device=device)
        else:
            # If C is a Tensor/eyeC/OneHotC object, we take this branch
            assert isinstance(C, (torch.Tensor, eyeC, OneHotC)), type(C)
            self.infeasible_bounds = torch.full((C.shape[0],), False, device=device)
    ...
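One possible workaround, sketched below, is to handle the sentinel strings explicitly before the type check. This is a hypothetical guard, not the project's fix: init_infeasible_bounds and its batch_size parameter are invented names, and plain lists stand in for torch.full:

```python
class Patches:  # stand-in for auto_LiRPA's Patches class (not the real one)
    def __init__(self, shape):
        self.shape = shape

def init_infeasible_bounds(C, batch_size):
    """Hypothetical tolerant version of the check in backward_general.
    The real code would use torch.full((n,), False, device=device);
    lists are used here to keep the sketch self-contained."""
    if isinstance(C, str):
        # Sentinel from get_sparse_C ('Patches' or 'eye'): C carries no
        # shape, so the batch size must come from somewhere else.
        assert C in ('Patches', 'eye'), C
        return [False] * batch_size
    if isinstance(C, Patches):
        return [False] * C.shape[1]
    # Dense/eyeC/OneHotC case in the real code: batch size is C.shape[0].
    return [False] * C.shape[0]

print(init_infeasible_bounds('Patches', 4))  # -> [False, False, False, False]
```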

The error is shown in the following screenshot:

(Screenshot from 2024-11-09 21-59-05)
