Commit a11aeba: remove default features
bokutotu committed Nov 1, 2024
1 parent c32541c

Showing 6 changed files with 28 additions and 7 deletions.
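
Every file in the diff applies the same pattern: with GPU support no longer a default feature, each item that exists only for the CUDA backend is compiled behind the `nvidia` cargo feature, and everything that references such an item carries the same gate. A minimal sketch of the idea, using hypothetical names rather than the crate's actual API:

// Always compiled: the CPU path.
pub fn cpu_forward() {}

// Compiled only with `--features nvidia`; `GpuWeights` is an
// illustrative stand-in for the cuDNN-backed types in this diff.
#[cfg(feature = "nvidia")]
pub struct GpuWeights;

// Any item that mentions a gated type needs the same gate,
// otherwise CPU-only builds fail to resolve the name.
#[cfg(feature = "nvidia")]
pub fn gpu_forward(_weights: GpuWeights) {}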

zenu-autograd/src/nn/rnns/weights.rs (5 additions, 1 deletion)

@@ -1,7 +1,10 @@
 use std::fmt::Debug;
 
 use rand_distr::{Distribution, StandardNormal};
-use zenu_matrix::{device::Device, nn::rnn::RNNWeights as RNNWeightsMat, num::Num};
+use zenu_matrix::{device::Device, num::Num};
 
+#[cfg(feature = "nvidia")]
+use zenu_matrix::nn::rnn::RNNWeights as RNNWeightsMat;
+
 use crate::{
     creator::{rand::normal, zeros::zeros},
@@ -61,6 +64,7 @@ pub struct RNNWeights<T: Num, D: Device, C: CellType> {
     _cell: std::marker::PhantomData<C>,
 }
 
+#[cfg(feature = "nvidia")]
 impl<T: Num, D: Device, C: CellType> From<RNNWeightsMat<T, D>> for RNNWeights<T, D, C> {
     fn from(weights: RNNWeightsMat<T, D>) -> Self {
         Self {

zenu-matrix/src/lib.rs (0 additions, 2 deletions)

@@ -1,5 +1,3 @@
 #![expect(clippy::module_name_repetitions, clippy::module_inception)]
 
-use device::cpu::Cpu;
-use memory_pool::MemPool;
 

zenu-matrix/src/nn/batch_norm.rs (2 additions, 0 deletions)

@@ -25,6 +25,7 @@ pub struct BatchNorm2dConfig<T: Num> {
 
 impl<T: Num> BatchNorm2dConfig<T> {
     #[must_use]
+    #[allow(unused_variables)]
     pub fn new(dim: DimDyn) -> Self {
         BatchNorm2dConfig::<T> {
             #[cfg(feature = "nvidia")]
@@ -42,6 +43,7 @@ pub struct BatchNorm2dBackwardConfig<T> {
 
 impl<T: Num> BatchNorm2dBackwardConfig<T> {
     #[must_use]
+    #[allow(unused_variables)]
     pub fn new(dim: DimDyn) -> Self {
         BatchNorm2dBackwardConfig::<T> {
             #[cfg(feature = "nvidia")]
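
In both constructors `dim` is consumed only by the `#[cfg(feature = "nvidia")]` field initializer, so CPU-only builds would hit the `unused_variables` lint on the parameter; that is what the added `#[allow(unused_variables)]` suppresses. A self-contained sketch of the situation, with a hypothetical field standing in for the collapsed initializer:

pub struct Config {
    // The only field that reads `dim` exists solely in nvidia builds.
    #[cfg(feature = "nvidia")]
    device_desc: usize, // illustrative stand-in for a cuDNN descriptor
}

impl Config {
    #[allow(unused_variables)] // `dim` is unused when the feature is off
    pub fn new(dim: usize) -> Self {
        Config {
            #[cfg(feature = "nvidia")]
            device_desc: dim,
        }
    }
}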

zenu-matrix/src/nn/conv2d/mod.rs (2 additions, 1 deletion)

@@ -69,7 +69,8 @@ pub struct Conv2dConfig<T: Num> {
     _phantom: std::marker::PhantomData<T>,
 }
 
-#[expect(clippy::too_many_arguments, clippy::missing_panics_doc)]
+#[expect(clippy::too_many_arguments)]
+#[allow(unused_variables)]
 #[must_use]
 pub fn create_conv_descriptor<T: Num>(
     input_shape: &[usize],

zenu-matrix/src/nn/pool2d.rs (0 additions, 1 deletion)

@@ -22,7 +22,6 @@ pub struct Pool2dConfig<T: Num> {
 
 impl<T: Num> Pool2dConfig<T> {
     #[must_use]
-    #[expect(clippy::missing_panics_doc)]
     pub fn new(
         kernel: (usize, usize),
         stride: (usize, usize),
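
Note the asymmetry between the two suppression attributes in these files: `#[expect(...)]` raises `unfulfilled_lint_expectations` whenever the expected lint does not fire, which breaks down once firing depends on a feature flag. That is presumably why the `clippy::missing_panics_doc` expectations in conv2d and pool2d are simply dropped (the panicking path would now be compiled only under `nvidia`), while the new `unused_variables` suppressions use `#[allow]`, which stays silent whether or not the lint fires.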

zenu-matrix/tests/rnn.rs (19 additions, 2 deletions)

@@ -5,8 +5,12 @@ use zenu_test::read_test_case_from_json_val;
 
 use zenu_matrix::device::cpu::Cpu;
 use zenu_matrix::matrix::Owned;
-use zenu_matrix::{device::nvidia::Nvidia, dim::DimDyn, matrix::Matrix, nn::rnn::*};
+use zenu_matrix::{dim::DimDyn, matrix::Matrix};
 
+#[cfg(feature = "nvidia")]
+use zenu_matrix::{device::nvidia::Nvidia, nn::rnn::*};
+
+#[cfg(feature = "nvidia")]
 fn get_rnn_weights_from_json(
     matrix_map: &std::collections::HashMap<String, Matrix<Owned<f32>, DimDyn, Cpu>>,
     num_layers: usize,
@@ -60,6 +64,7 @@ fn get_rnn_weights_from_json(
     weights
 }
 
+#[cfg(feature = "nvidia")]
 fn assert_grad(expected: &[RNNWeights<f32, Cpu>], actual: &[RNNWeights<f32, Cpu>]) {
     for (expected, actual) in expected.iter().zip(actual.iter()) {
         assert_mat_eq_epsilon!(expected.input_weight(), actual.input_weight(), 5e-3);
@@ -69,6 +74,7 @@ fn assert_grad(expected: &[RNNWeights<f32, Cpu>], actual: &[RNNWeights<f32, Cpu>])
     }
 }
 
+#[cfg(feature = "nvidia")]
 fn before_run(
     map: &HashMap<String, Matrix<Owned<f32>, DimDyn, Cpu>>,
     bidirectional: bool,
@@ -81,6 +87,7 @@ fn before_run(
     (input_size, hidden_size, batch_size)
 }
 
+#[cfg(feature = "nvidia")]
 fn init_weights(
     desc: &RNNDescriptor<f32>,
     map: &HashMap<String, Matrix<Owned<f32>, DimDyn, Cpu>>,
@@ -97,6 +104,7 @@
     w
 }
 
+#[cfg(feature = "nvidia")]
 fn rnn(json_path: String, num_layers: usize, bidirectional: bool) {
     let matrix_map = read_test_case_from_json_val!(json_path);
     let (input_size, hidden_size, batch_size) = before_run(&matrix_map, bidirectional);
@@ -140,6 +148,7 @@ fn rnn(json_path: String, num_layers: usize, bidirectional: bool) {
     assert_grad(&weights_grad, &params);
 }
 
+#[cfg(feature = "nvidia")]
 fn lstm(json_path: String, num_layers: usize, bidirectional: bool) {
     let matrix_map = read_test_case_from_json_val!(json_path);
     let (input_size, hidden_size, batch_size) = before_run(&matrix_map, bidirectional);
@@ -185,6 +194,7 @@ fn lstm(json_path: String, num_layers: usize, bidirectional: bool) {
     assert_grad(&weights_grad, &params);
 }
 
+#[cfg(feature = "nvidia")]
 fn gru(json_path: String, num_layers: usize, bidirectional: bool) {
     let matrix_map = read_test_case_from_json_val!(json_path);
     let (input_size, hidden_size, batch_size) = before_run(&matrix_map, bidirectional);
@@ -230,6 +240,7 @@ fn gru(json_path: String, num_layers: usize, bidirectional: bool) {
     assert_grad(&weights_grad, &params);
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_lstm_small() {
     lstm(
@@ -239,6 +250,7 @@ fn test_lstm_small() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_lstm_medium() {
     lstm(
@@ -247,7 +259,7 @@ fn test_lstm_medium() {
         false,
     );
 }
-
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_lstm_bidirectional() {
     lstm(
@@ -257,6 +269,7 @@ fn test_lstm_bidirectional() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_rnn_seq_len_1() {
     rnn(
@@ -266,6 +279,7 @@ fn test_rnn_seq_len_1() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_rnn_seq_len_3() {
     rnn(
@@ -275,6 +289,7 @@ fn test_rnn_seq_len_3() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_rnn_seq_len_5() {
     rnn(
@@ -284,6 +299,7 @@ fn test_rnn_seq_len_5() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_rnn_seq_len_5_num_layer_2_bidirectional() {
     rnn(
@@ -293,6 +309,7 @@ fn test_rnn_seq_len_5_num_layer_2_bidirectional() {
     );
 }
 
+#[cfg(feature = "nvidia")]
 #[test]
 fn test_gru_small() {
     gru("../test_data_json/gru_small.json".to_string(), 2, false);
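
Since every helper and test in this file now sits behind the `nvidia` gate, a plain `cargo test` compiles the file to an empty test set. Running the RNN/LSTM/GRU cases requires opting in, e.g. with something along the lines of `cargo test -p zenu-matrix --features nvidia` on a CUDA-capable machine (the exact invocation depends on the workspace setup).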
