103 changes: 100 additions & 3 deletions crates/burn-tensor/src/tensor/data.rs
@@ -5,19 +5,20 @@ use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
use core::ops::{Index, IndexMut};
use cubecl_quant::scheme::QuantScheme;
use half::{bf16, f16};
use num_traits::{Float, ToPrimitive};

use crate::{
DType, Distribution, Element, ElementConversion,
AsIndex, DType, Distribution, Element, ElementConversion,
quantization::{QuantValue, QuantizationStrategy, QuantizedBytes},
tensor::Bytes,
};

use rand::RngCore;

use super::quantization::{QuantLevel, QuantMode};
use crate::indexing::ravel_dims;
use rand::RngCore;

/// The things that can go wrong when manipulating tensor data.
#[derive(Debug)]
@@ -55,6 +56,32 @@ impl TensorData {
}
}

/// Gets a read-only index view of the tensor data, indexable by multi-dimensional coordinates.
///
/// # Example
/// ```rust, ignore
/// let value = data.as_index_view::<f32>()[&[2, 3, 4]];
/// ```
pub fn as_index_view<'a, E: Element>(&'a self) -> TensorDataIndexView<'a, E> {
TensorDataIndexView {
data: self,
_phantom: core::marker::PhantomData,
}
}

/// Gets a mutable index view of the tensor data, indexable by multi-dimensional coordinates.
///
/// # Example
/// ```rust, ignore
/// data.as_mut_index_view::<f32>()[&[2, 3, 4]] = 1.0;
/// ```
pub fn as_mut_index_view<'a, E: Element>(&'a mut self) -> TensorDataIndexViewMut<'a, E> {
TensorDataIndexViewMut {
data: self,
_phantom: core::marker::PhantomData,
}
}

/// Creates a new quantized tensor data structure.
pub fn quantized<E: Element, S: Into<Vec<usize>>>(
value: Vec<E>,
@@ -882,6 +909,43 @@ impl core::fmt::Display for TensorData {
}
}

/// An [`Index`] view for [`TensorData`].
pub struct TensorDataIndexView<'a, E: Element> {
data: &'a TensorData,
_phantom: core::marker::PhantomData<&'a E>,
}

impl<'a, E: Element, I: AsIndex> Index<&[I]> for TensorDataIndexView<'a, E> {
type Output = E;

fn index(&self, index: &[I]) -> &Self::Output {
let idx = ravel_dims(index, &self.data.shape);
&self.data.as_slice::<E>().unwrap()[idx]
}
}

/// A mutable [`Index`] view for [`TensorData`].
pub struct TensorDataIndexViewMut<'a, E: Element> {
data: &'a mut TensorData,
_phantom: core::marker::PhantomData<&'a E>,
}

impl<'a, E: Element, I: AsIndex> Index<&[I]> for TensorDataIndexViewMut<'a, E> {
type Output = E;

fn index(&self, index: &[I]) -> &Self::Output {
let idx = ravel_dims(index, &self.data.shape);
&self.data.as_slice::<E>().unwrap()[idx]
}
}

impl<'a, E: Element, I: AsIndex> IndexMut<&[I]> for TensorDataIndexViewMut<'a, E> {
fn index_mut(&mut self, index: &[I]) -> &mut Self::Output {
let idx = ravel_dims(index, &self.data.shape);
&mut (self.data.as_mut_slice::<E>().unwrap()[idx])
}
}
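As an illustration of the index-view API introduced above, here is a minimal usage sketch. It assumes the existing `TensorData::new(values, shape)` constructor and row-major storage; the concrete values are hypothetical and not part of the diff.

```rust
// Illustrative sketch: build a small 2 x 3 tensor (values 0.0..=5.0, row-major),
// then read and write individual elements through the index views.
let mut data = TensorData::new((0..6).map(|v| v as f32).collect::<Vec<_>>(), [2, 3]);

// Read through the immutable view: coordinate [1, 2] maps to offset 1 * 3 + 2 = 5.
assert_eq!(data.as_index_view::<f32>()[&[1, 2]], 5.0);

// Write through the mutable view and read the change back.
data.as_mut_index_view::<f32>()[&[1, 2]] = 42.0;
assert_eq!(data.as_index_view::<f32>()[&[1, 2]], 42.0);
```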

/// The tolerance used to compare to floating point numbers.
///
/// Generally, two numbers `x` and `y` are approximately equal if
@@ -1097,6 +1161,39 @@ mod tests {
assert_eq!(data.rank(), 3);
}

#[test]
fn test_as_index_view() {
let shape = Shape::new([3, 5, 6]);
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::from_os_rng(),
);

assert_eq!(
data.as_index_view::<f32>()[&[1, 2, 3]],
data.as_slice::<f32>().unwrap()[1 * 5 * 6 + 2 * 6 + 3]
)
}

#[test]
fn test_as_mut_index_view() {
let shape = Shape::new([3, 5, 6]);
let mut data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::from_os_rng(),
);

assert_eq!(
data.as_index_view::<f32>()[&[1, 2, 3]],
data.as_slice::<f32>().unwrap()[1 * 5 * 6 + 2 * 6 + 3]
);

data.as_mut_index_view::<f32>()[&[1, 2, 3]] = 3.0;
assert_eq!(data.as_index_view::<f32>()[&[1, 2, 3]], 3.0);
}

#[test]
fn into_vec_should_yield_same_value_as_iter() {
let shape = Shape::new([3, 5, 6]);
43 changes: 43 additions & 0 deletions crates/burn-tensor/src/tensor/indexing/mod.rs
@@ -211,9 +211,52 @@
}
}

/// Compute the raveled (flat) index for the given coordinates.
///
/// This returns the row-major (C-order) raveling.
///
/// # Arguments
/// - `coords`: the coordinates; must have the same length as `dims`.
/// - `dims`: the dimension sizes; must have the same length as `coords`.
///
/// # Returns
/// - The raveled offset index.
pub fn ravel_dims<I: AsIndex>(coords: &[I], dims: &[usize]) -> usize {
assert_eq!(
dims.len(),
coords.len(),
"Coordinate rank mismatch: expected {}, got {}",
dims.len(),
coords.len(),
);

let mut ravel_idx = 0;
let mut stride = 1;

for (i, &dim) in dims.iter().enumerate().rev() {
let coord = canonicalize_index(coords[i], dim, false);
ravel_idx += coord * stride;
stride *= dim;
}

ravel_idx
}
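For reference, the loop above is equivalent to a dot product of the coordinates with row-major strides, where `strides[i]` is the product of `dims[i + 1..]`. A standalone sketch (not part of the diff) illustrating that equivalence for the shape used in the tests below:

```rust
// Row-major strides for dims [2, 3, 4, 5] are [60, 20, 5, 1]:
// strides[i] is the product of dims[i + 1..].
let dims = [2usize, 3, 4, 5];
let coords = [1usize, 2, 3, 4];

let mut strides = [0usize; 4];
let mut stride = 1;
for i in (0..dims.len()).rev() {
    strides[i] = stride;
    stride *= dims[i];
}

// The raveled index is the dot product of coordinates and strides:
// 1 * 60 + 2 * 20 + 3 * 5 + 4 * 1 = 119.
let offset: usize = coords.iter().zip(strides.iter()).map(|(c, s)| c * s).sum();
assert_eq!(offset, ravel_dims(&coords, &dims));
```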

#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;

#[test]
fn test_ravel() {
let shape = vec![2, 3, 4, 5];

assert_eq!(ravel_dims(&[0, 0, 0, 0], &shape), 0);
assert_eq!(
ravel_dims(&[1, 2, 3, 4], &shape),
1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
);
}

#[test]
fn test_wrap_idx() {
27 changes: 26 additions & 1 deletion crates/burn-tensor/src/tensor/shape.rs
@@ -1,4 +1,5 @@
use crate::{Slice, SliceArg};
use crate::indexing::ravel_dims;
use crate::{AsIndex, Slice, SliceArg};
use alloc::vec::Vec;
use core::{
ops::{Deref, DerefMut, Index, IndexMut, Range},
@@ -75,6 +76,19 @@ impl Shape {
self
}

/// Compute the raveled (flat) index for the given coordinates.
///
/// This returns the row-major (C-order) raveling.
///
/// # Arguments
/// - `coords`: the coordinates; must have length equal to `self.rank()`.
///
/// # Returns
/// - The raveled offset index.
pub fn ravel<I: AsIndex>(&self, coords: &[I]) -> usize {
ravel_dims(coords, &self.dims)
}
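A short usage sketch of the `Shape::ravel` wrapper, assuming only the existing `Shape::new` constructor; it mirrors the row-major formula exercised in the test further down:

```rust
let shape = Shape::new([2, 3, 4]);

// Row-major offset of coordinate [1, 2, 3]: 1 * (3 * 4) + 2 * 4 + 3 = 23.
assert_eq!(shape.ravel(&[1, 2, 3]), 23);
```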

/// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
pub fn into_ranges(self) -> Vec<Range<usize>> {
self.into_iter().map(|d| 0..d).collect()
@@ -630,6 +644,17 @@ mod tests {
assert_eq!(&shape.dims, &[120]);
}

#[test]
fn test_ravel() {
let shape = Shape::new([2, 3, 4, 5]);

assert_eq!(shape.ravel(&[0, 0, 0, 0]), 0);
assert_eq!(
shape.ravel(&[1, 2, 3, 4]),
1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
);
}

#[test]
fn test_shape_insert_remove() {
let dims = [2, 3, 4, 5];