Does the code support single-step inference? #23

Open
USTCYYX opened this issue Jan 17, 2024 · 65 comments
Labels
question Further information is requested

Comments

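@USTCYYX

USTCYYX commented Jan 17, 2024

Hello. After reading the code, I found that inference runs in multi-step mode. Can I use
from spikingjelly.clock_driven import functional
functional.set_step_mode(net, 's')
to conveniently switch the network to single-step inference?
If not, could you suggest a workable way to run inference in single-step mode? Thank you!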

@fangwei123456
Owner

The SJ framework from that period does not seem to support switching step modes.

@fangwei123456 fangwei123456 added the question Further information is requested label Jan 18, 2024
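@USTCYYX
Author

USTCYYX commented Jan 18, 2024

I see that every neuron used in the code has a MultiStep prefix, such as MultiStepIFNode. Are there corresponding single-step neurons? If I replace all multi-step neurons with single-step ones, will single-step inference work? In SJ, the single-step/multi-step distinction seems to concern only the neuron layers, not the other layers.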

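@fangwei123456
Owner

That works. Replace the neurons with single-step neurons and unwrap the stateless layers that are wrapped in SeqToANNContainer; the network will then run in single-step mode.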

@USTCYYX
Author

USTCYYX commented Jan 18, 2024

I could not find the documentation for this version of SJ. Could you point me to the documentation for its neuron classes?

@fangwei123456
Owner

This is probably version 0.0.0.0.4:
https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.4/

@USTCYYX
Author

USTCYYX commented Jan 18, 2024

To check my understanding of your suggestion, here is a typical code excerpt:
self.conv2 = layer.SeqToANNContainer(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.MultiStepIFNode(detach_reset=True)
To switch to single-step, should it be changed like this?
self.conv2 = nn.Sequential(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.IFNode(detach_reset=True)

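@fangwei123456
Owner

Yes, that is all you need.

One small problem that may arise: after replacing SeqToANNContainer with Sequential, loading the trained weights may fail. Because SeqToANNContainer merges the T and N dimensions, you could instead keep SeqToANNContainer in single-step mode and register hooks on it: during the forward pass, unsqueeze(0) the input to add a T=1 dimension, then squeeze(0) the output to remove it.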
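The hook idea described above can be sketched as follows. This is only an illustration under stated assumptions, not the repository's code: `TimeFoldingContainer` is a hypothetical stand-in that mimics how a SeqToANNContainer-style wrapper folds [T, N, ...] into [T*N, ...], and the two hook functions are names made up for this example. The only library calls used are the standard PyTorch `register_forward_pre_hook` and `register_forward_hook`.

```python
import torch
from torch import nn


class TimeFoldingContainer(nn.Module):
    """Hypothetical stand-in for a SeqToANNContainer-style wrapper."""

    def __init__(self, *layers):
        super().__init__()
        self.module = nn.Sequential(*layers)

    def forward(self, x_seq):
        # x_seq: [T, N, ...]; fold T into the batch dimension, run the
        # stateless layers, then restore the leading [T, N] dimensions.
        t, n = x_seq.shape[0], x_seq.shape[1]
        y = self.module(x_seq.flatten(0, 1))
        return y.reshape(t, n, *y.shape[1:])


def add_time_dim(module, args):
    # Pre-hook: turn a single-step input [N, ...] into [1, N, ...].
    return (args[0].unsqueeze(0),)


def drop_time_dim(module, args, output):
    # Post-hook: turn the output [1, N, ...] back into [N, ...].
    return output.squeeze(0)


container = TimeFoldingContainer(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
).eval()
keys_before = list(container.state_dict().keys())

container.register_forward_pre_hook(add_time_dim)
container.register_forward_hook(drop_time_dim)

x = torch.rand(4, 3, 16, 16)  # single-step input [N, C, H, W]
with torch.no_grad():
    y = container(x)

print(y.shape)  # the hooks add and remove the T=1 dimension
```

Because the wrapper itself is left in place, the parameter names in `state_dict()` do not change, which is why this route should avoid the weight-loading problem that swapping in `nn.Sequential` can cause.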

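@USTCYYX
Author

USTCYYX commented Jan 18, 2024

“Because SeqToANNContainer merges the T and N dimensions, you could register hooks on SeqToANNContainer in single-step mode: unsqueeze(0) the input to add a T=1 dimension during the forward pass, then squeeze(0) the output to remove it.”
Do you mean I should not change it to nn.Sequential? Then every input would be [1, N, C, H, W], where the 1 exists only so that SeqToANNContainer can do its merging and carries no real meaning.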

@USTCYYX
Author

USTCYYX commented Jan 18, 2024

A normal single-step input would be [N, C, H, W], which is then reshaped to [1, N, C, H, W] to suit SeqToANNContainer.

@fangwei123456
Owner

Yes. If you are not loading weights, you do not need to worry about this.

https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.14/activation_based/container.html

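@USTCYYX
Author

USTCYYX commented Jan 18, 2024

OK, thank you. I do need to load weights: I train in multi-step mode and then run inference in single-step mode. I have also noticed what looks like another problem.
If I switch to single-step now, the code becomes:
self.conv2 = layer.SeqToANNContainer(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.IFNode(detach_reset=True)
With this approach, self.conv2 outputs [1, N, C, H, W].
Can a single-step neuron accept a five-dimensional tensor such as [1, N, C, H, W]? By rights, a single-step neuron should be fed [N, C, H, W].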

@fangwei123456
Owner

fangwei123456 commented Jan 18, 2024

No. So you need to squeeze(0) the previous layer's output, or this layer's input, to remove that dimension.

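@USTCYYX
Author

USTCYYX commented Jan 18, 2024

Then could I do the following instead: change nothing, keep the multi-step neurons and SeqToANNContainer, and run:

for image in dataset:
    total_fr = 0.
    image = image.unsqueeze(0)  # add the T=1 dimension once, outside the time loop
    for t in range(T):
        output = net(image)
        total_fr += output.squeeze(0)
    functional.reset_net(net)

The key question is whether multi-step neurons support being called repeatedly like this and retain their membrane potential across calls. Normally, multi-step neurons do not have to deal with this kind of repeated input.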

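@fangwei123456
Owner

“Can multi-step neurons support being called repeatedly like this and retain their membrane potential across calls?”

The latest SJ framework can. I am not sure about the version used in this repo; I suggest checking its source code.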
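What "retaining the membrane potential across calls" means can be illustrated with a framework-free toy. `ToyIFNode` below is a hypothetical integrate-and-fire neuron written for this example; it is not the SpikingJelly implementation, and whether the version used in this repo behaves the same way still has to be checked in its source. The sketch shows that feeding a sequence in one call and feeding it one step per call give the same spikes as long as the state is only reset at the end.

```python
class ToyIFNode:
    """Hypothetical integrate-and-fire neuron that keeps its membrane potential."""

    def __init__(self, v_threshold=1.0, v_reset=0.0):
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.v = v_reset

    def step(self, x):
        # Charge, fire if the threshold is reached, then hard-reset.
        self.v += x
        spike = 1.0 if self.v >= self.v_threshold else 0.0
        if spike:
            self.v = self.v_reset
        return spike

    def multi_step(self, x_seq):
        # Multi-step mode is the single-step update applied along time.
        return [self.step(x) for x in x_seq]

    def reset(self):
        self.v = self.v_reset


inputs = [0.5, 0.5, 0.5, 0.25, 0.75, 0.5]

neuron = ToyIFNode()
spikes_multi = neuron.multi_step(inputs)  # one call over the whole sequence

neuron.reset()
spikes_single = [neuron.multi_step([x])[0] for x in inputs]  # T=1 per call

neuron.reset()
spikes_forgetful = []
for x in inputs:
    spikes_forgetful.append(neuron.step(x))
    neuron.reset()  # wrongly resetting after every step loses the state

print(spikes_multi, spikes_single, spikes_forgetful)
```

The third run shows why the reset belongs after the whole time loop: resetting inside the loop discards the accumulated potential and the neuron never fires.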

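@USTCYYX
Author

USTCYYX commented Jan 18, 2024

Does the monitor in this version not support viewing the voltage? I just want to look at the voltage at each time step. The source code does not include the cext.neuron part, and I could not find it in the documentation.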

@fangwei123456
Owner

The monitor in the latest version does support viewing the voltage, by accessing the member variable.

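@USTCYYX
Author

USTCYYX commented Jan 18, 2024

Hello. Based on my tests, multi-step neurons do appear to keep their voltage until the next reset. The SJ version you pointed to cannot show the membrane voltage, so I could only judge from the output spikes. Specifically, I used the code below: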

import torch
from torch import nn
from spikingjelly.cext import neuron as cext_neuron
from spikingjelly.clock_driven import layer, functional

x = torch.rand([1000, 1, 784])* 2 + 1
x = x.cuda()

net = nn.Sequential(
                layer.SeqToANNContainer(
                    nn.Linear(784, 100, bias=False)
                ),
                cext_neuron.MultiStepIFNode(alpha=2.0),
                layer.SeqToANNContainer(
                    nn.Linear(100, 10, bias=False)
                ),
                cext_neuron.MultiStepIFNode(alpha=2.0),
            )

net=net.cuda()
out=net(x)

functional.reset_net(net)

for t in range(1000):
    out1=net(x[t,:,:].unsqueeze(dim=0)).squeeze(dim=0)
    print(out1.equal(out[t]))

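The network outputs match over at least 1000 time steps, so my preliminary conclusion is that the state is retained.
However, when I tested sew_resnet, single-step and multi-step did not seem to agree. My test code is below: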

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')
    
    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.to('cuda:0')
    out=model(x)
    print(out)
    
#    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
#    model1.to('cuda:0')
#    out_all=0.
#    for t in range(3):
#        out1=model1(x)/3.
#        out_all+=out1
#    print(out_all)

if __name__ == "__main__":
    main()

Here I use two networks, model and model1, because T must be set in advance: the model repeats the input T times internally.

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x.unsqueeze_(0)
        x = x.repeat(self.T, 1, 1, 1, 1)
        x = self.sn1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 2)
        return self.fc(x.mean(dim=0))

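This is the sew_resnet source. My understanding is that the first two layers act as an input encoding layer, and that the final output is the time-averaged voltage. With my test method, the outputs of model and model1 should therefore match, but they do not. I hope you can tell me why.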

@USTCYYX
Author

USTCYYX commented Jan 18, 2024

I modified the code. With the pretrained weights loaded, the single-step and multi-step outputs of the code below finally agree:

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')
    checkpoint = torch.load('sew18_checkpoint_319.pth', map_location='cuda:0')
    
    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')
    model.eval()
    out=model(x)
#    print(out)
    
    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
    model1.load_state_dict(checkpoint['model'])
    model1.to('cuda:0')
    out_all=0.
    model1.eval()
    for t in range(3):
        out1=model1(x)/3.
        out_all+=out1
#    print(out_all)
    
    result=torch.where(torch.abs(out-out_all)<0.01,True,False)
    print(result)
if __name__ == "__main__":
    main()

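Output: all True, i.e. every element differs by less than 0.01. These errors are probably irreducible, an inherent discrepancy between single-step and multi-step execution.
However, the code below, which does not load pretrained weights, still fails to produce matching outputs: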

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')
    
    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.to('cuda:0')
    model.eval()
    out=model(x)
    print(out)
    
    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
    model1.to('cuda:0')
    model1.eval()
    out_all=0.
    for t in range(3):
        out1=model1(x)/3.
        out_all+=out1
    print(out_all)

if __name__ == "__main__":
    main()

Do you know why this last script cannot produce matching outputs?

@fangwei123456
Owner

return self.fc(x.mean(dim=0))

Could this be the cause? The forward pass is written for multi-step mode by default, so the final output is averaged directly.

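@USTCYYX
Author

USTCYYX commented Jan 19, 2024

I do not fully follow, but the fc layer is linear, so averaging before the layer and averaging after it should give the same result, shouldn't they? That holds even with a bias.
Also, how do I use the dvsgesture weights? There seems to be a key-name mismatch.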

    checkpoint = torch.load('checkpoint_191.pth', map_location='cuda:0')
    
    model = smodels.__dict__['PlainNet'](connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')

Error:
Missing key(s) in state_dict: "conv.0.0.module.1.weight", "conv.0.0.module.1.bias", "conv.0.0.module.1.running_mean", "conv.0.0.module.1.running_var", "conv.1.conv.0.0.module.1.weight", "conv.1.conv.0.0.module.1.bias", "conv.1.conv.0.0.module.1.running_mean", "conv.1.conv.0.0.module.1.running_var", "conv.1.conv.1.0.module.1.weight", "conv.1.conv.1.0.module.1.bias", "conv.1.conv.1.0.module.1.running_mean", "conv.1.conv.1.0.module.1.running_var", "conv.3.conv.0.0.module.1.weight", "conv.3.conv.0.0.module.1.bias", "conv.3.conv.0.0.module.1.running_mean", "conv.3.conv.0.0.module.1.running_var", "conv.3.conv.1.0.module.1.weight", "conv.3.conv.1.0.module.1.bias", "conv.3.conv.1.0.module.1.running_mean", "conv.3.conv.1.0.module.1.running_var", "conv.5.conv.0.0.module.1.weight", "conv.5.conv.0.0.module.1.bias", "conv.5.conv.0.0.module.1.running_mean", "conv.5.conv.0.0.module.1.running_var", "conv.5.conv.1.0.module.1.weight", "conv.5.conv.1.0.module.1.bias", "conv.5.conv.1.0.module.1.running_mean", "conv.5.conv.1.0.module.1.running_var", "conv.7.conv.0.0.module.1.weight", "conv.7.conv.0.0.module.1.bias", "conv.7.conv.0.0.module.1.running_mean", "conv.7.conv.0.0.module.1.running_var", "conv.7.conv.1.0.module.1.weight", "conv.7.conv.1.0.module.1.bias", "conv.7.conv.1.0.module.1.running_mean", "conv.7.conv.1.0.module.1.running_var", "conv.9.conv.0.0.module.1.weight", "conv.9.conv.0.0.module.1.bias", "conv.9.conv.0.0.module.1.running_mean", "conv.9.conv.0.0.module.1.running_var", "conv.9.conv.1.0.module.1.weight", "conv.9.conv.1.0.module.1.bias", "conv.9.conv.1.0.module.1.running_mean", "conv.9.conv.1.0.module.1.running_var", "conv.11.conv.0.0.module.1.weight", "conv.11.conv.0.0.module.1.bias", "conv.11.conv.0.0.module.1.running_mean", "conv.11.conv.0.0.module.1.running_var", "conv.11.conv.1.0.module.1.weight", "conv.11.conv.1.0.module.1.bias", "conv.11.conv.1.0.module.1.running_mean", "conv.11.conv.1.0.module.1.running_var", "conv.13.conv.0.0.module.1.weight", 
"conv.13.conv.0.0.module.1.bias", "conv.13.conv.0.0.module.1.running_mean", "conv.13.conv.0.0.module.1.running_var", "conv.13.conv.1.0.module.1.weight", "conv.13.conv.1.0.module.1.bias", "conv.13.conv.1.0.module.1.running_mean", "conv.13.conv.1.0.module.1.running_var".
Unexpected key(s) in state_dict: "conv.0.0.module.1.0.weight", "conv.0.0.module.1.0.bias", "conv.0.0.module.1.0.running_mean", "conv.0.0.module.1.0.running_var", "conv.0.0.module.1.0.num_batches_tracked", "conv.1.conv.0.0.module.1.0.weight", "conv.1.conv.0.0.module.1.0.bias", "conv.1.conv.0.0.module.1.0.running_mean", "conv.1.conv.0.0.module.1.0.running_var", "conv.1.conv.0.0.module.1.0.num_batches_tracked", "conv.1.conv.1.0.module.1.0.weight", "conv.1.conv.1.0.module.1.0.bias", "conv.1.conv.1.0.module.1.0.running_mean", "conv.1.conv.1.0.module.1.0.running_var", "conv.1.conv.1.0.module.1.0.num_batches_tracked", "conv.3.conv.0.0.module.1.0.weight", "conv.3.conv.0.0.module.1.0.bias", "conv.3.conv.0.0.module.1.0.running_mean", "conv.3.conv.0.0.module.1.0.running_var", "conv.3.conv.0.0.module.1.0.num_batches_tracked", "conv.3.conv.1.0.module.1.0.weight", "conv.3.conv.1.0.module.1.0.bias", "conv.3.conv.1.0.module.1.0.running_mean", "conv.3.conv.1.0.module.1.0.running_var", "conv.3.conv.1.0.module.1.0.num_batches_tracked", "conv.5.conv.0.0.module.1.0.weight", "conv.5.conv.0.0.module.1.0.bias", "conv.5.conv.0.0.module.1.0.running_mean", "conv.5.conv.0.0.module.1.0.running_var", "conv.5.conv.0.0.module.1.0.num_batches_tracked", "conv.5.conv.1.0.module.1.0.weight", "conv.5.conv.1.0.module.1.0.bias", "conv.5.conv.1.0.module.1.0.running_mean", "conv.5.conv.1.0.module.1.0.running_var", "conv.5.conv.1.0.module.1.0.num_batches_tracked", "conv.7.conv.0.0.module.1.0.weight", "conv.7.conv.0.0.module.1.0.bias", "conv.7.conv.0.0.module.1.0.running_mean", "conv.7.conv.0.0.module.1.0.running_var", "conv.7.conv.0.0.module.1.0.num_batches_tracked", "conv.7.conv.1.0.module.1.0.weight", "conv.7.conv.1.0.module.1.0.bias", "conv.7.conv.1.0.module.1.0.running_mean", "conv.7.conv.1.0.module.1.0.running_var", "conv.7.conv.1.0.module.1.0.num_batches_tracked", "conv.9.conv.0.0.module.1.0.weight", "conv.9.conv.0.0.module.1.0.bias", "conv.9.conv.0.0.module.1.0.running_mean", 
"conv.9.conv.0.0.module.1.0.running_var", "conv.9.conv.0.0.module.1.0.num_batches_tracked", "conv.9.conv.1.0.module.1.0.weight", "conv.9.conv.1.0.module.1.0.bias", "conv.9.conv.1.0.module.1.0.running_mean", "conv.9.conv.1.0.module.1.0.running_var", "conv.9.conv.1.0.module.1.0.num_batches_tracked", "conv.11.conv.0.0.module.1.0.weight", "conv.11.conv.0.0.module.1.0.bias", "conv.11.conv.0.0.module.1.0.running_mean", "conv.11.conv.0.0.module.1.0.running_var", "conv.11.conv.0.0.module.1.0.num_batches_tracked", "conv.11.conv.1.0.module.1.0.weight", "conv.11.conv.1.0.module.1.0.bias", "conv.11.conv.1.0.module.1.0.running_mean", "conv.11.conv.1.0.module.1.0.running_var", "conv.11.conv.1.0.module.1.0.num_batches_tracked", "conv.13.conv.0.0.module.1.0.weight", "conv.13.conv.0.0.module.1.0.bias", "conv.13.conv.0.0.module.1.0.running_mean", "conv.13.conv.0.0.module.1.0.running_var", "conv.13.conv.0.0.module.1.0.num_batches_tracked", "conv.13.conv.1.0.module.1.0.weight", "conv.13.conv.1.0.module.1.0.bias", "conv.13.conv.1.0.module.1.0.running_mean", "conv.13.conv.1.0.module.1.0.running_var", "conv.13.conv.1.0.module.1.0.num_batches_tracked".

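@fangwei123456
Owner

It is probably not an fc/mean problem.

“However, the code below, which does not load pretrained weights, still fails to produce matching outputs:”

The weight initializations differ.

“a key-name mismatch”

That is caused by SeqToANNContainer. The fix was mentioned earlier: #23 (comment)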
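The point about weight initialization can be illustrated without any framework. In the sketch below, `TinyLinear` is a hypothetical toy layer and the single seeded `random.Random` stands in for a global RNG that is seeded once at the top of a script: the second model draws the next numbers from the same stream, so its weights differ from the first model's unless they are copied over explicitly.

```python
import random


class TinyLinear:
    """Hypothetical one-output linear layer with randomly initialised weights."""

    def __init__(self, in_features, rng):
        self.weight = [rng.uniform(-1.0, 1.0) for _ in range(in_features)]

    def __call__(self, x):
        return sum(w * xi for w, xi in zip(self.weight, x))

    def state_dict(self):
        return {"weight": list(self.weight)}

    def load_state_dict(self, state):
        self.weight = list(state["weight"])


rng = random.Random(2020)    # seeded once, like the test script
model = TinyLinear(4, rng)
model1 = TinyLinear(4, rng)  # draws the next numbers from the same stream

x = [1.0, 2.0, 3.0, 4.0]
differs_before = model(x) != model1(x)

model1.load_state_dict(model.state_dict())  # share the weights explicitly
matches_after = model(x) == model1(x)

print(differs_before, matches_after)
```

Applied to the scripts in this thread, the analogous step would be copying one model's state into the other (or loading the same checkpoint into both, as the pretrained variant already does) before comparing outputs.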

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

I tried it. Over long runs, single-step and multi-step always drift apart by small errors. Perhaps this is unavoidable.

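@fangwei123456
Owner

The error cannot be avoided. Even for a single conv layer in PyTorch, feeding two batches concatenated together versus feeding them separately and concatenating the outputs gives very slightly different results.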
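A likely contributor to such differences (an assumption here; the thread does not name the cause) is that floating-point addition is not associative, so a kernel that accumulates the same terms in a different order can round differently. A minimal demonstration with plain Python floats:

```python
import math

a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # accumulate in one order
right = a + (b + c)  # same terms, different grouping

exactly_equal = left == right
nearly_equal = math.isclose(left, right)

print(left, right, exactly_equal, nearly_equal)
```

This is why comparisons between single-step and multi-step outputs are better done with a tolerance than with exact equality.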

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

Regarding the dvsgesture model problem mentioned above: I simplified the weight-loading procedure from your code. Did I get something wrong?
Yours:

    model = smodels.__dict__[args.model](args.connect_f)
    print("Creating model")

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    if args.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.amp:
        scaler = amp.GradScaler()
    else:
        scaler = None

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        max_test_acc1 = checkpoint['max_test_acc1']
        test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']

Mine:

    checkpoint = torch.load('checkpoint_191.pth', map_location='cuda:0')
    
    model = smodels.__dict__['PlainNet'](connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')

@fangwei123456
Owner

Print the keys of the checkpoint and of the model's state_dict, and compare the differences.
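The comparison suggested above can be done with plain set arithmetic. The helper name `diff_state_dict_keys` is made up for this sketch, and the key names are a small sample copied from the error message earlier in the thread.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) the way a strict load would report them."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model wants, file lacks
    unexpected = sorted(checkpoint_keys - model_keys)  # file has, model lacks
    return missing, unexpected


# Key names copied from the error message in this thread.
checkpoint_keys = [
    "conv.0.0.module.1.0.weight",
    "conv.0.0.module.1.0.bias",
    "conv.0.0.module.1.0.num_batches_tracked",
]
model_keys = [
    "conv.0.0.module.1.weight",
    "conv.0.0.module.1.bias",
]

missing, unexpected = diff_state_dict_keys(checkpoint_keys, model_keys)
for key in missing:
    print("missing   :", key)
for key in unexpected:
    print("unexpected:", key)
```

With the real objects, the two arguments would be `checkpoint['model'].keys()` and `model.state_dict().keys()`.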

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

They are listed above. The main difference is that the checkpoint keys carry an extra .0 and an extra .num_batches_tracked.

@fangwei123456
Owner

The .0 may be caused by SeqToANNContainer? In the new framework,

https://github.com/fangwei123456/spikingjelly/blob/cb1cee00334ebeca101a155aee694252f85543a8/spikingjelly/activation_based/layer.py#L35

SeqToANNContainer inherits from Sequential, so the state dict no longer has the extra .0 prefix and matches Sequential.

@fangwei123456
Owner

I suggest modifying the model definition in a similar way so that its keys match the checkpoint.
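A different route from the one suggested above, shown only as a hedged sketch, is to leave the model definition alone and rename the checkpoint keys before loading. The rename rule below is inferred solely from the error message in this thread (an extra `.0` after `module.1`, plus `num_batches_tracked` entries that the model does not list); `remap_checkpoint_keys` is a hypothetical helper, the `module.0` key and the string values are illustrative placeholders, and the rule should be verified against the real key lists before use.

```python
import re


def remap_checkpoint_keys(state):
    """Rename checkpoint keys to the names the model in this thread expects."""
    remapped = {}
    for key, value in state.items():
        if key.endswith(".num_batches_tracked"):
            continue  # reported as unexpected and absent from the model's keys
        # "...module.1.0.weight" -> "...module.1.weight"
        new_key = re.sub(r"(\.module\.1)\.0\.", r"\1.", key)
        remapped[new_key] = value
    return remapped


old_state = {
    "conv.0.0.module.0.weight": "conv-weight",  # illustrative, already matching
    "conv.0.0.module.1.0.weight": "bn-weight",
    "conv.1.conv.0.0.module.1.0.running_mean": "bn-mean",
    "conv.0.0.module.1.0.num_batches_tracked": "counter",
}
new_state = remap_checkpoint_keys(old_state)
print(sorted(new_state))
```

The remapped dictionary would then be passed to `model.load_state_dict(...)` in place of `checkpoint['model']`.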

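@USTCYYX
Author

USTCYYX commented Jan 19, 2024

Hello. I think the error is still rather large, especially over many time steps, so running single-step inference directly on the existing code may not be realistic.
Migrating to the new SJ version also looks difficult to me. Have you tried reproducing all of the SEW work with the new SJ version? Would training SEW on the same datasets with the new SJ give much worse accuracy?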

@fangwei123456
Owner

Yes, and then run it in single-step mode.

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

One question: conv3x3 in the new sew_resnet calls the layer version:

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return layer.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

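This differs from the one in this repo.
If I modify the old code in the same way, can it switch freely between single-step and multi-step? I have also noticed an accuracy gap between single-step and multi-step during training.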

@fangwei123456
Owner

Modules from nn cannot switch between single-step and multi-step.
Only those from layer can.

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

Oh, no wonder the new SEWResNet replaced everything with layer.
I have now changed nn to layer as well, except for nn.Sequential. It looks like this:

import torch
import torch.nn as nn
from spikingjelly.activation_based.neuron import ParametricLIFNode
from spikingjelly.activation_based import layer

def conv3x3(in_channels, out_channels):
    return nn.Sequential(
            layer.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            layer.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

def conv1x1(in_channels, out_channels):
    return nn.Sequential(
            layer.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            layer.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

class SEWBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, connect_f=None):
        super(SEWBlock, self).__init__()
        self.connect_f = connect_f
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        if self.connect_f == 'ADD':
            out += x
        elif self.connect_f == 'AND':
            out *= x
        elif self.connect_f == 'IAND':
            out = x * (1. - out)
        else:
            raise NotImplementedError(self.connect_f)

        return out

class PlainBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(PlainBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)

class BasicBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(BasicBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            layer.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
            layer.BatchNorm2d(in_channels),
        )
        self.sn = ParametricLIFNode(init_tau=2.0, detach_reset=True)

    def forward(self, x: torch.Tensor):
        return self.sn(x + self.conv(x))


class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels


            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.MaxPool2d(k_pool, k_pool))

        conv.append(layer.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)

Can this now switch freely between single-step and multi-step?
I am unsure about maxpool and flatten, because the new SEWResNet uses torch.flatten rather than layer.Flatten:

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.sn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        if self.avgpool.step_mode == 's':
            x = torch.flatten(x, 1)
        elif self.avgpool.step_mode == 'm':
            x = torch.flatten(x, 2)
        
        x = self.fc(x)

        return x

@fangwei123456
Owner

It can now.

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

Why does the new SEWResNet use torch.flatten(x, 1) and torch.flatten(x, 2), with a check on self.avgpool.step_mode? Will using layer.Flatten directly in my modified code cause unpredictable errors?

@USTCYYX
Author

USTCYYX commented Jan 19, 2024

That is, conv.append(layer.Flatten(2)) in the modified code.

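@fangwei123456
Owner

“Why does the new SEWResNet use torch.flatten(x, 1) and torch.flatten(x, 2)?”

It depends on whether the input is [T, N, C, H, W] or [N, C, H, W], so the starting dimension for flatten differs. layer.Flatten distinguishes the two cases according to the single-step/multi-step mode.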

@USTCYYX
Author

USTCYYX commented Jan 27, 2024

I tried running it on DVSCIFAR10 and got an error:

Traceback (most recent call last):
  File "train.py", line 289, in <module>
    main()
  File "train.py", line 219, in main
    out_fr = net(frame)
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sewresnet/smodels.py", line 126, in forward
    return self.out(x.mean(0))
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x1 and 128x10)

Part of the code is below:

class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels


            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.MaxPool2d(k_pool, k_pool))

        conv.append(layer.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)

@USTCYYX
Author

USTCYYX commented Jan 27, 2024

The dimensions seem wrong.

@fangwei123456
Owner

conv.append(layer.Flatten(2))
Try changing it to 1? If it still errors, check the output shape of each layer.
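The shapes in the traceback can be reproduced with a little bookkeeping. The sketch below is framework-free and rests on two assumptions that the thread does not state outright: that in multi-step mode the flatten start dimension is applied to each time step's [N, C, H, W] tensor, and that the batch size was 16 (which is what 2048 / 128 suggests). T=4 is arbitrary, and the helper names are made up for this example.

```python
from math import prod


def flatten_shape(shape, start_dim):
    """Shape after flattening all dimensions from start_dim onward."""
    return list(shape[:start_dim]) + [prod(shape[start_dim:])]


def multi_step_flatten_shape(seq_shape, start_dim):
    # Assumed behaviour: fold [T, N, ...] to [T*N, ...], flatten each sample,
    # then restore the leading [T, N] dimensions.
    t, n = seq_shape[0], seq_shape[1]
    folded = [t * n] + list(seq_shape[2:])
    return [t, n] + flatten_shape(folded, start_dim)[1:]


T, N = 4, 16                  # assumed; N=16 matches 2048 / 128
features = [T, N, 128, 1, 1]  # 128x128 input after seven 2x2 poolings

for start_dim in (2, 1):
    after_flatten = multi_step_flatten_shape(features, start_dim)
    after_mean = after_flatten[1:]  # x.mean(0) removes T
    rows, cols = prod(after_mean[:-1]), after_mean[-1]
    print(start_dim, after_flatten, f"linear sees {rows}x{cols}")
```

With start dimension 2 the linear layer receives a trailing dimension of 1, which matches the reported 2048x1 operand; with start dimension 1 it receives 128 features per sample, matching the 128x10 weight.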

@USTCYYX
Author

USTCYYX commented Jan 27, 2024

Thank you very much! It seems to run now. But if I want to switch to single-step, what else needs to change in the code above?

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

I think it should become

    def forward(self, x: torch.Tensor):
        total = 0.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        for t in range(x.size(0)):  # feed one time step at a time
            y = self.conv(x[t])
            total += self.out(y)
        return total / x.size(0)

But I am not sure how to change the earlier part:

        conv.append(layer.Flatten(1))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)

@fangwei123456
Owner

fangwei123456 commented Jan 27, 2024

functional.set_step_mode acts on the whole network; just switch it to single-step. See the latest tutorial for details:

https://spikingjelly.readthedocs.io/zh-cn/latest/activation_based/basic_concept.html#id4

@fangwei123456
Owner

In single-step mode the last layer should be return self.out(x), without mean(0).
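Dropping mean(0) inside the model and averaging the per-step outputs outside is consistent with the multi-step result because a linear layer commutes with averaging: W·mean(x_t) + b equals mean(W·x_t + b). The sketch below checks this numerically with plain lists; the weights, bias, and spike vectors are made-up values, and the helper names are hypothetical.

```python
def linear(x, weight, bias):
    """y = W x + b for a single sample, written with plain lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def mean_over_time(seq):
    steps = len(seq)
    return [sum(values) / steps for values in zip(*seq)]


weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
x_seq = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]  # T=3 spike vectors

# Multi-step style: average over time first, then apply the layer once.
mean_then_fc = linear(mean_over_time(x_seq), weight, bias)

# Single-step style: apply the layer at every step, then average the outputs.
fc_then_mean = mean_over_time([linear(x, weight, bias) for x in x_seq])

gap = max(abs(a - b) for a, b in zip(mean_then_fc, fc_then_mean))
print(mean_then_fc, fc_then_mean, gap)
```

Any remaining gap is rounding error only, so the two orderings are interchangeable up to floating-point tolerance.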

@USTCYYX
Author

USTCYYX commented Jan 27, 2024

Do you mean writing it like this?

    def forward(self, x: torch.Tensor):
        total = 0.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        for t in range(x.size(0)):  # feed one time step at a time
            y = self.conv(x[t])
            total += self.out(y)
        return total

Here total is the accumulated voltage.

@fangwei123456
Owner

That works. First apply functional.set_step_mode to self so that all submodules are in single-step mode.

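@USTCYYX
Author

USTCYYX commented Jan 30, 2024

Hello. I want to train on cifar10 and cifar100 with spikingjelly, using the example VGG models in spikingjelly/activation_based/model/spiking_vgg.py. Are these models all derived from the ANN VGG architectures, with the activation functions replaced by spiking neurons?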

@USTCYYX
Author

USTCYYX commented Jan 30, 2024

The spiking_vgg documentation is somewhat incomplete. Is this usage correct?
net = models.spiking_vgg.__dict__[args.model](neuron=neuron.IFNode(), num_classes=num_classes, surrogate_function=surrogate.ATan())

@fangwei123456
Owner

Replace neuron=neuron.IFNode() with neuron=neuron.IFNode

@USTCYYX
Author

USTCYYX commented Jan 31, 2024

Training vgg16 directly gave fairly poor results, so I want to try adding tdBN to see whether I can reach the results reported in the paper (Going Deeper With Directly-Trained Larger Spiking Neural Networks). I defined net like this:

spiking_vgg.__dict__[args.model](spiking_neuron=neuron.IFNode, num_classes=num_classes,surrogate_function=surrogate.ATan(),norm_layer=layer.ThresholdDependentBatchNorm2d(alpha=1., v_th=1.))

But it raised an error:
TypeError: __init__() missing 1 required positional argument: 'num_features'
The error presumably comes from instantiating layer.ThresholdDependentBatchNorm2d directly, but spiking_vgg does expose norm_layer as a customization hook, so I do not know how to change it.

@fangwei123456
Owner

norm_layer=layer.ThresholdDependentBatchNorm2d?

@fangwei123456
Owner

fangwei123456 commented Feb 1, 2024

I think norm_layer needs to be a callable object, not an already-constructed module.
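One common way to satisfy a "callable, not an instance" parameter when the layer needs extra constructor arguments is to bind those arguments ahead of time, for example with `functools.partial`. The sketch below uses a hypothetical stand-in class, `ToyThresholdNorm`, whose parameter order is an assumption; the real signature of `layer.ThresholdDependentBatchNorm2d` should be checked before applying the same idea to it.

```python
from functools import partial


class ToyThresholdNorm:
    """Hypothetical stand-in for a norm layer that needs extra arguments."""

    def __init__(self, alpha, v_th, num_features):
        self.alpha, self.v_th, self.num_features = alpha, v_th, num_features


def build_block(num_features, norm_layer):
    # Model builders typically call norm_layer(num_features) themselves.
    return norm_layer(num_features)


# Passing the bare class fails: 64 is bound to alpha and the rest is missing.
try:
    build_block(64, ToyThresholdNorm)
    bare_class_failed = False
except TypeError:
    bare_class_failed = True

# Binding the extra arguments first yields a callable that takes num_features.
norm_factory = partial(ToyThresholdNorm, 1.0, 1.0)
norm = build_block(64, norm_factory)
print(bare_class_failed, norm.alpha, norm.v_th, norm.num_features)
```

The extra arguments are bound positionally here so that the remaining positional argument lands on `num_features`; a small `lambda num_features: ...` wrapper would serve the same purpose.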

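@USTCYYX
Author

USTCYYX commented Feb 1, 2024

I tried norm_layer=layer.ThresholdDependentBatchNorm2d; it complains about the v_th input: TypeError: __init__() missing 1 required positional argument: 'v_th'
Separately, I am now training on cifar100 with SJ's built-in sew_resnet18. Because cifar100 images are small, I replaced the 7x7 convolution kernel in the first layer with a 3x3 kernel. The problem is severe overfitting: 99% accuracy on the training set but only about 55% on the test set. I have seen ANN resnet18 reach about 75.61%. Do you have any tricks to reduce overfitting and improve accuracy?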

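@fangwei123456
Owner

CIFAR datasets overfit very easily; generally the only remedy is improving the network architecture. Tricks such as L2 regularization have little effect. Data augmentation helps somewhat; see the code below:
https://github.com/fangwei123456/Parallel-Spiking-Neuron/blob/main/cifar10/train_cf10.py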

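@USTCYYX
Author

USTCYYX commented Feb 1, 2024

If I am to improve the network, the only option I see is adding dropout, but the original ResNet paper does not use dropout. Alternatively I could use Wide-7B-Net, the small network from your paper designed for DVScifar10, though I do not know whether that is feasible.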

@USTCYYX
Author

USTCYYX commented Feb 1, 2024

Have you ever run sewresnet on cifar100? Could you share some experience?

@fangwei123456
Owner

I have not tried, but you could try the CIFAR10 network.

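@USTCYYX
Author

USTCYYX commented Feb 2, 2024

Hello. The data augmentation you suggested does help, but it seems too strong: even the training set is now somewhat underfit. Is there a parameter that adjusts the augmentation strength? I am referring to the augmentation used in
https://github.com/fangwei123456/Parallel-Spiking-Neuron/blob/main/cifar10/train_cf10.py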

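@USTCYYX
Author

USTCYYX commented Feb 2, 2024

The awkward part is that test accuracy is now higher than training accuracy, probably because the augmentation is too strong. Your augmentation was designed for cifar10, but I am using cifar100, which may not need augmentation this strong. Could you tell me whether the augmentation code has parameters I can adjust to weaken its effect?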

@fangwei123456
Owner

Take a look at the constructor parameters of the data augmentation classes.
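One quick way to list what a class's constructor accepts is the standard-library `inspect.signature`. In the sketch below, `ToyCutout` is a hypothetical augmentation class with made-up parameter names; it only stands in for whatever classes the linked training script defines.

```python
import inspect


class ToyCutout:
    """Hypothetical augmentation; the parameter names are illustrative only."""

    def __init__(self, length=16, probability=0.5):
        self.length = length
        self.probability = probability


def constructor_defaults(cls):
    """Map each constructor parameter (except self) to its default value."""
    params = inspect.signature(cls.__init__).parameters
    return {name: p.default for name, p in params.items() if name != "self"}


print(constructor_defaults(ToyCutout))
```

Running the same helper on a real augmentation class shows which knobs exist and what their defaults are, which is a reasonable starting point for weakening the augmentation.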
