-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for grouped convolution #93
Conversation
Build failure for |
8823714
to
949fe5b
Compare
The majority of changes related to automatic formatting changes after the Emacs Python linter started working out of nowhere. You can ignore the Python changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing too crazy here. It's quite a lot of code for a single review though, so please try to keep in mind you can and should push formatting changes, whitespace, typos, etc. separately first (perhaps directly to main), then put a PR for review with the interesting changes.
// SAFETY: `init_heap` must be called once only | ||
unsafe { init_heap() }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this reverting a prior upstream change? Seems strange. I think init_heap is what it's supposed to be atm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to use call directly with bsp
@@ -28,7 +29,7 @@ unsafe fn ffi_data_import( | |||
kernel_width: usize, | |||
kernel_order: *const c_char, | |||
) -> (Tensor3<i8>, Tensor4<i8>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You do have to wonder if some of these could or should be combined into structs even for the FFI, e.g.,
typedef struct data_t {
int8_t *data,
size_t channels,
size_t height,
size_t width,
char *order,
} data_t;`
data_t input { .data = ... }
...
ffi_data_import(input, kernel)
Might not make sense to commit to refactoring at this point in the project though unless we have some reason to continue the work next year :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, if I or someone else wants to revisit this for optimizations at some point, this should be done. However, it's hard to justify now.
None, | ||
); | ||
|
||
let input_order_string = unsafe { CStr::from_ptr(input_order).to_str().unwrap_unchecked() }; | ||
let _input_order_string = unsafe { CStr::from_ptr(input_order).to_str().unwrap_unchecked() }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused? Just remove it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
|
||
fn conv_test() { | ||
sprintln!("conv_test: enter"); | ||
let din: Vec<i8> = vec![ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tag this entry with #[rustfmt::skip]
to make sure it doesn't get autoformatted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tagged
@@ -179,6 +179,50 @@ pub fn conv2d_bias_relu<T: DlaOutput + Clone>( | |||
) | |||
} | |||
|
|||
pub fn grouped_conv2d<T: DlaOutput + Clone>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Public API function: consider documenting what the function is trying to achieve. Possibly even # Arguments
as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API documentation added
@@ -183,6 +183,56 @@ impl<T: Clone> Tensor3<T> { | |||
self.order | |||
} | |||
|
|||
/// Concatenates a Tensor along the least significant axis (axis=2) by interleaving the tensors | |||
pub fn concat_interleaved(tensors: Vec<Tensor3<T>>) -> Tensor3<T> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function does not seem to require ownership of the input buffer tensors
. Consider only borrowing a slice for the duration of the function call: concat_interleaved(tensors: &[Tensor3<T>])
. You'll need to add the borrow symbol &
on the callsite as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar thing might apply to other functions around here, I didn't review those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the actual difference between a vec and a slice, but this works alright with this change. You probably know better :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vilukissa68 highlighting for resurrect.
You could use borrowed Vec as well but borrowed slice (&[T]) is just more general than Vec. It reads: "a view to any contiguous memory" without specifying if it should be allocated as a Vec specifically, ... or array, or Hashmap. This allows the function to be called with any of the above types.
766c845
to
a00d3cd
Compare
No description provided.