Skip to content

Commit

Permalink
Changes following review
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Nov 21, 2024
1 parent 3822363 commit 2d15a1b
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 83 deletions.
15 changes: 8 additions & 7 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::LinalgFn;
use crate::ndarray::Dimension;

use super::cast::cast;

Expand Down Expand Up @@ -455,11 +456,11 @@ impl EvalOp for OptBinByScalar {
.last()
.context("Cannot use by_scalar when no trailing dimensions are unary")?;

let iterating_shape = a.shape()[..first_unary_axis].to_vec();
let iterating_shape = &a.shape()[..first_unary_axis];
if !iterating_shape.is_empty() {
for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
let mut view = TensorView::at_prefix(&a, &it_coords)?;
let b_view = TensorView::at_prefix(&b, &it_coords)?;
for it_coords in tract_ndarray::indices(iterating_shape) {
let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
(self.eval_fn)(&mut view, &b_view)?;
}
Expand Down Expand Up @@ -579,9 +580,9 @@ impl EvalOp for OptBinUnicast {
if let Some(first_non_unary_axis) = first_non_unary_axis {
// Iterate on outter dimensions and evaluate with unicast subviews
let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
let mut view = TensorView::at_prefix(&a, &it_coords)?;
debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.len()..]);
for it_coords in tract_ndarray::indices(iterating_shape) {
let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
(self.eval_fn)(&mut view, &b_view)?;
}
} else {
Expand Down
1 change: 0 additions & 1 deletion core/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ pub mod macros;
pub mod element_wise;
#[macro_use]
pub mod binary;
//pub mod binary_new;

pub mod array;
pub mod cast;
Expand Down
1 change: 0 additions & 1 deletion data/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ pub mod internal {
pub use crate::opaque::{ OpaquePayload, OpaqueFact };
pub use crate::prelude::*;
pub use crate::tensor::view::TensorView;
pub use crate::tensor::indices::iter_indices;
pub use crate::tensor::Approximation;
pub use crate::tensor::vector_size;
pub use anyhow::{anyhow, bail, ensure, format_err, Context as TractErrorContext};
Expand Down
1 change: 0 additions & 1 deletion data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::hash::Hash;
use std::ops::Range;
use std::sync::Arc;

pub mod indices;
pub mod litteral;
pub mod view;

Expand Down
73 changes: 0 additions & 73 deletions data/src/tensor/indices.rs

This file was deleted.

0 comments on commit 2d15a1b

Please sign in to comment.