Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

关于ResRep模型性能对比 #20

Open
wangxd15 opened this issue Jul 17, 2023 · 4 comments
Open

关于ResRep模型性能对比 #20

wangxd15 opened this issue Jul 17, 2023 · 4 comments

Comments

@wangxd15
Copy link

你好,最近刚好用到ResRep剪枝,我看本框架和原始ResRep论文的实现方式稍有差异, 本框架直接移除选中的卷积通道层但原论文是对选中通道施加惩罚因子使其逐渐趋向0,或者说反向传播过程中对保留和移除卷积通道层施加不同的梯度更新策略。
if isinstance(nn_object, Compactor): lasso_grad = value.data * ((value.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5)) value.grad.data.add_(self.config["lasso_decay"], lasso_grad)
请问实际测试中有比对两种方案的性能差异么~

@gdh1995
Copy link
Collaborator

gdh1995 commented Jul 23, 2023

抱歉原论文作者已经毕业了,resrep实现的同学也毕业了,我只能说我的印象:实践上现在的代码泛用性还不错。

另外 lasso_grad 这不是算了个分母嘛,目的就是都让 compactor 这一层的一部分列更小对吧?为什么你觉得它是“直接移除选中的卷积通道层”?

这段代码每次迭代的 backward 后都会执行,而“直接移除选中的卷积通道层”是隔一阵子才执行一次。具体来说,是先warmup,然后每隔 prune_interval 次迭代才去检查有哪些可以删的列,进而删除的。

@Annmixiu
Copy link

我对这部分代码的理解是:先有目的的减小compactor中的某些参数,再在if self.variable_dict["prune_iteration"] % self.config["prune_interval"] == 0的条件下执行检查和删除参数,而不是直接移除选中的卷积通道层,希望对题主有帮助。

@wangxd15
Copy link
Author

wangxd15 commented Jul 24, 2023

抱歉,可能我表达不够清晰。此框架下反向传播过程采用 损失函数 + 通道稀疏惩罚,不区分保留和需要移除的通道或者列。该框架反向传播

            if isinstance(nn_object, Compactor):
                lasso_grad = value.data * (
                    (value.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5)
                )
                value.grad.data.add_(self.config["lasso_decay"], lasso_grad)

if self.variable_dict["prune_iteration"] % self.config["prune_interval"] == 0:条件下通过排序删除模长最小的通道。

ResRep源码通过模长排序确定Mask(需保留和待移除的列),然后在反向传播中对保留的列采用对抗,对需要移除的列惩罚,让其模长逐渐趋向为0。

    for compactor_param, mask in compactor_mask_dict.items():
        compactor_param.grad.data = mask * compactor_param.grad.data
        lasso_grad = compactor_param.data * ((compactor_param.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))
        compactor_param.grad.data.add_(resrep_config.lasso_strength, lasso_grad)

两者之间存在Mask的差异,Mask也是原作者Ding强调的记忆遗忘的差别。

综上因此咨询两者训练方式的性能差异。

@gdh1995
Copy link
Collaborator

gdh1995 commented Jul 24, 2023

感谢指出这个问题。不过超出我能力范围了。

@ZizhouJia @DingXiaoH 有闲功夫请看看 #20 (comment) 这个问题

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants