Binary search for lp #4

Closed
miaoxiaodaiblack opened this issue Mar 21, 2021 · 1 comment

@miaoxiaodaiblack

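Hello!
I would like to use the lp and Up computed by the code in 'simple_verification.py' to implement the binary search from your CROWN paper. I tried modifying the code as follows: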
def test(input, model):
    eps = 0
    gap_gx = 100
    eps_LB = -1
    eps_UB = 1
    counter = 0
    is_pos = True
    is_neg = True

    # perform binary search
    eps_gx_UB = 1000000.0
    eps_gx_LB = 0.0
    is_pos = True
    is_neg = True
    # eps = eps_gx_LB*2
    # eps = args.eps

    while eps_gx_UB - eps_gx_LB > 0.00001:
        ptb = PerturbationLpNorm(norm=2, eps=eps)
        image = BoundedTensor(input, ptb)
        pred = model(image)
        label = torch.argmax(pred, dim=1).cpu().numpy()
        # for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']:
        lb, ub = model.compute_bounds(x=(image,), method='IBP+backward')
        gap_gx = torch.min(lb)
        lb = lb.detach().cpu().numpy()
        ub = ub.detach().cpu().numpy()
        print("Bounding method:", method)
        for i in range(N):
            print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
            for j in range(n_classes):
                indicator = '(ground-truth)' if j == true_label[i] else ''
                print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format(
                    j=j, l=lb[i][j], u=ub[i][j], ind=indicator))
        print()
        if gap_gx > 0:
            if gap_gx < 0.01:
                eps_gx_LB = eps
                return eps
                break
            if is_pos:  # so far always > 0, haven't found eps_UB
                eps_gx_LB = eps
                eps *= 10
            else:
                eps_gx_LB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_neg = False
        else:
            if is_neg:  # so far always < 0, haven't found eps_LB
                eps_gx_UB = eps
                eps /= 10
            else:
                eps_gx_UB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_pos = False
        counter += 1
        if counter >= 500:
            return eps
            break
    print("[L2][binary search] step = {}, eps = {:.5f}, gap_gx = {:.2f}".format(counter, eps, gap_gx))

However, the result is not what I expected. Could you tell me what went wrong?
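
A few things in the snippet above look like they could explain the unexpected result, judging only from the code as posted. The search starts at `eps = 0`, and both `eps *= 10` and `eps /= 10` leave zero unchanged, so the loop can never leave `eps = 0` until the counter runs out. `gap_gx = torch.min(lb)` takes the smallest lower bound over all logits rather than a lower bound on the margin between the true class and the other classes, which is likely what the search should test (auto_LiRPA's `compute_bounds` can, as far as I know, take a specification matrix `C` for this; treat that as an assumption to check against the library docs). Finally, `method`, `N`, `n_classes` and `true_label` are not defined inside `test` unless they are globals. Below is a minimal sketch of the bracketing-then-bisection search with the bound computation factored out. The function name `largest_certified_eps` and the callable `margin_lower_bound` are hypothetical and not part of any library; the sketch assumes the margin lower bound does not increase as `eps` grows.

```python
from typing import Callable, Optional


def largest_certified_eps(
    margin_lower_bound: Callable[[float], float],
    eps_init: float = 0.01,
    eps_max: float = 1e6,
    tol: float = 1e-5,
    max_steps: int = 500,
) -> float:
    """Approximate the largest eps for which margin_lower_bound(eps) > 0.

    margin_lower_bound is a hypothetical callable supplied by the caller.
    It is assumed to be non-increasing in eps.
    """
    if eps_init <= 0:
        # Starting at 0 would stall: 0 * 10 and 0 / 10 are both 0.
        raise ValueError("eps_init must be positive")

    lo = 0.0                      # largest eps known to be certified
    hi: Optional[float] = None    # smallest eps known to fail
    eps = eps_init

    for _ in range(max_steps):
        if margin_lower_bound(eps) > 0:
            lo = eps
        else:
            hi = eps

        if hi is None:
            # No failing eps found yet: grow geometrically up to eps_max.
            if lo >= eps_max:
                break
            eps = min(lo * 10.0, eps_max)
        elif hi - lo <= tol:
            # Bracket is tight enough.
            break
        elif lo == 0.0:
            # No certified eps found yet: shrink geometrically.
            eps = hi / 10.0
        else:
            # Both ends known: bisect.
            eps = 0.5 * (lo + hi)

    return lo


# Toy stand-in for the verifier: the margin crosses zero at eps = 0.25.
toy_margin = lambda eps: 1.0 - 4.0 * eps
certified = largest_certified_eps(toy_margin)
```

In the setting of the post, `margin_lower_bound(eps)` would wrap the `PerturbationLpNorm`, `BoundedTensor` and `model.compute_bounds` calls and return the smallest lower bound of the true-class logit minus each other logit. Returning `lo` rather than the last tried `eps` means the result is always a value that was actually certified.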

@huanzhang12
Owner

See #5
