Do I need to implement deserialization myself for Sequential?
#1407
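```rust
use burn::module::Module;
use burn::nn::conv::Conv2d;
use burn::tensor::activation::relu;
use burn::tensor::{backend::Backend, Tensor};

pub struct LoreDetectModel<B: Backend> {
    // ... (other fields elided in the original post)
    wh: SequentialConv2d<B>,
}

/// A stack of Conv2d layers with a ReLU between consecutive layers.
#[derive(Module, Debug)]
struct SequentialConv2d<B: Backend> {
    layers: Vec<Conv2d<B>>,
}

impl<B: Backend> SequentialConv2d<B> {
    fn new() -> Self {
        Self { layers: Vec::new() }
    }

    fn from(conv2ds: Vec<Conv2d<B>>) -> Self {
        Self { layers: conv2ds }
    }

    fn add(&mut self, layer: Conv2d<B>) {
        self.layers.push(layer);
    }

    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut x = x;
        for (index, layer) in self.layers.iter().enumerate() {
            x = layer.forward(x);
            // Apply ReLU after every layer except the last one.
            if index < self.layers.len() - 1 {
                x = relu(x);
            }
        }
        x
    }
}
```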
Sorry about the bad error message. We will improve it, and thanks for filing a bug (#1390).
No, you do not need to implement any deserialization. You just need to match the keys. If you need to adjust the keys, you can use key remapping. Please see this book section: https://burn.dev/book/import/pytorch-model.html#adjusting-the-source-model-architecture.
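Before writing a remap rule, it can help to diff the tensor names stored in the file against the names the Burn model expects, so the mismatch is visible at a glance. The sketch below is a hypothetical, std-only helper (`diff_keys` is not a Burn API), the key strings are made up to follow the `model.wh` naming used in this thread, and how you obtain the two key lists is left out.

```rust
use std::collections::BTreeSet;

/// Hypothetical helper (not part of Burn): compares the keys found in the source
/// file with the keys the Burn model expects.
/// Returns (keys the model needs but the file lacks, keys the file has but the model ignores).
fn diff_keys<'a>(source: &[&'a str], expected: &[&'a str]) -> (Vec<&'a str>, Vec<&'a str>) {
    let src: BTreeSet<&str> = source.iter().copied().collect();
    let exp: BTreeSet<&str> = expected.iter().copied().collect();
    // Present in the model's expected keys, absent from the file.
    let missing: Vec<&'a str> = exp.difference(&src).copied().collect();
    // Present in the file, never read by the model.
    let unused: Vec<&'a str> = src.difference(&exp).copied().collect();
    (missing, unused)
}

fn main() {
    // Made-up key names, following the `model.wh` naming used in the reply.
    let source = ["model.wh.0.weight", "model.wh.0.bias"];
    let expected = ["model.wh.layers.0.weight", "model.wh.layers.0.bias"];
    let (missing, unused) = diff_keys(&source, &expected);
    println!("missing: {:?}", missing);
    println!("unused:  {:?}", unused);
}
```

When every expected key shows up as "missing" and every file key as "unused", with only a path segment differing, that is usually a sign the keys need a rename rather than custom deserialization.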
In your example, you have a new `layers` attribute in your model that is not in the source, so you'll need to insert `layers` after `wh` in the key path. Basically, you'll need to replace `model.wh` with `model.wh.layers`. I recommend this online regex tool, which works with Rust regex: https://rregex.dev/?version=1.10&method=replace

Hope this helps!
P.S. The book online is for Burn 0.12.1, but the main branch has updated documentation and fixes for the import.
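For the `model.wh` to `model.wh.layers` rename described in the reply, one possible regex rule is shown below; the linked book section shows pattern/replacement pairs like this being passed to `LoadArgs::with_key_remap`. The snippet assumes that usage but does not call Burn: `remap_wh` is a hypothetical std-only stand-in that mimics what this single rule does to each key, and the example keys are made up.

```rust
/// Assumed pattern/replacement pair for the key remap (test it in a Rust-regex
/// tool first; the exact loader call depends on your Burn version).
const PATTERN: &str = r"^model\.wh\.(.*)";
const REPLACEMENT: &str = "model.wh.layers.$1";

/// Hypothetical std-only emulation of the rule above: if a key starts with
/// `model.wh.`, keep the captured remainder and re-emit it under
/// `model.wh.layers.`; any other key passes through unchanged.
fn remap_wh(key: &str) -> String {
    match key.strip_prefix("model.wh.") {
        Some(rest) => format!("model.wh.layers.{rest}"),
        None => key.to_string(),
    }
}

fn main() {
    println!("{PATTERN} -> {REPLACEMENT}");
    for key in ["model.wh.0.weight", "model.wh.1.bias", "model.other.weight"] {
        println!("{key} => {}", remap_wh(key));
    }
}
```

The trailing `\.` in the pattern (and the trailing `.` in the prefix) keeps the rule from also matching a sibling key such as `model.wh2.weight`.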