-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathcumsum.rs
30 lines (26 loc) · 929 Bytes
/
cumsum.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#![cfg(feature = "cumsum")]
use crate::{
candle::{shape::Dim, Result, Tensor},
F,
};
impl F {
    /// Returns the cumulative sum of elements of `input` along dimension `dim`.
    ///
    /// Along `dim`, output element `i` is the sum of input elements `0..=i`.
    /// The output has the same shape as `input`. If `input` has size zero
    /// along `dim`, the input is returned unchanged (an empty tensor of the
    /// same shape).
    ///
    /// [https://pytorch.org/docs/stable/generated/torch.cumsum.html](https://pytorch.org/docs/stable/generated/torch.cumsum.html)
    ///
    /// # Errors
    ///
    /// Returns an error if `dim` is not a valid dimension of `input`, or if
    /// any underlying tensor operation (`narrow`, addition, `cat`) fails.
    pub fn cumsum<D: Dim>(input: &Tensor, dim: D) -> Result<Tensor> {
        let dim = dim.to_index(input.shape(), "cumsum")?;
        let dim_size = input.dim(dim)?;
        // `Tensor::cat` refuses an empty list of tensors, so a zero-length
        // dimension must be handled before the loop: its cumulative sum is
        // just the (empty) input itself.
        if dim_size == 0 {
            return Ok(input.clone());
        }
        // Carry a running sum that is one element wide along `dim` and add
        // the next input slice at each step. This needs one slice-sized
        // addition per position, instead of re-adding shrinking windows of
        // the whole tensor on every iteration.
        let mut running = input.narrow(dim, 0, 1)?;
        let mut slices = Vec::with_capacity(dim_size);
        slices.push(running.clone());
        for i in 1..dim_size {
            let next = input.narrow(dim, i, 1)?;
            running = (&running + &next)?;
            slices.push(running.clone());
        }
        Tensor::cat(&slices, dim)
    }
}