diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 7e4c8977a7..c09da44b8d 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -5,19 +5,20 @@ use alloc::format; use alloc::string::String; use alloc::vec::Vec; use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError}; +use core::ops::{Index, IndexMut}; use cubecl_quant::scheme::QuantScheme; use half::{bf16, f16}; use num_traits::{Float, ToPrimitive}; use crate::{ - DType, Distribution, Element, ElementConversion, + AsIndex, DType, Distribution, Element, ElementConversion, quantization::{QuantValue, QuantizationStrategy, QuantizedBytes}, tensor::Bytes, }; -use rand::RngCore; - use super::quantization::{QuantLevel, QuantMode}; +use crate::indexing::ravel_dims; +use rand::RngCore; /// The things that can go wrong when manipulating tensor data. #[derive(Debug)] @@ -55,6 +56,32 @@ impl TensorData { } } + /// Gets an index view. + /// + /// # Example + /// ```rust, ignore + /// let value = data.index_view::()[[2, 3, 4]]; + /// ``` + pub fn as_index_view<'a, E: Element>(&'a self) -> TensorDataIndexView<'a, E> { + TensorDataIndexView { + data: self, + _phantom: core::marker::PhantomData, + } + } + + /// Gets a mutable index view. + /// + /// # Example + /// ```rust, ignore + /// data.mut_index_view::()[[2, 3, 4]] = 1.0; + /// ``` + pub fn as_mut_index_view<'a, E: Element>(&'a mut self) -> TensorDataIndexViewMut<'a, E> { + TensorDataIndexViewMut { + data: self, + _phantom: core::marker::PhantomData, + } + } + /// Creates a new quantized tensor data structure. pub fn quantized>>( value: Vec, @@ -882,6 +909,43 @@ impl core::fmt::Display for TensorData { } } +/// An [`Index`] view for [`TensorData`]. +pub struct TensorDataIndexView<'a, E: Element> { + data: &'a TensorData, + _phantom: core::marker::PhantomData<&'a E>, +} + +impl<'a, E: Element, I: AsIndex> Index<&[I]> for TensorDataIndexView<'a, E> { + type Output = E; + + fn index(&self, index: &[I]) -> &Self::Output { + let idx = ravel_dims(index, &self.data.shape); + &self.data.as_slice::().unwrap()[idx] + } +} + +/// A mutable [`Index`] view for [`TensorData`]. +pub struct TensorDataIndexViewMut<'a, E: Element> { + data: &'a mut TensorData, + _phantom: core::marker::PhantomData<&'a E>, +} + +impl<'a, E: Element, I: AsIndex> Index<&[I]> for TensorDataIndexViewMut<'a, E> { + type Output = E; + + fn index(&self, index: &[I]) -> &Self::Output { + let idx = ravel_dims(index, &self.data.shape); + &self.data.as_slice::().unwrap()[idx] + } +} + +impl<'a, E: Element, I: AsIndex> IndexMut<&[I]> for TensorDataIndexViewMut<'a, E> { + fn index_mut(&mut self, index: &[I]) -> &mut Self::Output { + let idx = ravel_dims(index, &self.data.shape); + &mut (self.data.as_mut_slice::().unwrap()[idx]) + } +} + /// The tolerance used to compare to floating point numbers. /// /// Generally, two numbers `x` and `y` are approximately equal if @@ -1097,6 +1161,39 @@ mod tests { assert_eq!(data.rank(), 3); } + #[test] + fn test_as_index_view() { + let shape = Shape::new([3, 5, 6]); + let data = TensorData::random::( + shape, + Distribution::Default, + &mut StdRng::from_os_rng(), + ); + + assert_eq!( + data.as_index_view::()[&[1, 2, 3]], + data.as_slice::().unwrap()[1 * 5 * 6 + 2 * 6 + 3] + ) + } + + #[test] + fn test_as_mut_index_view() { + let shape = Shape::new([3, 5, 6]); + let mut data = TensorData::random::( + shape, + Distribution::Default, + &mut StdRng::from_os_rng(), + ); + + assert_eq!( + data.as_index_view::()[&[1, 2, 3]], + data.as_slice::().unwrap()[1 * 5 * 6 + 2 * 6 + 3] + ); + + data.as_mut_index_view::()[&[1, 2, 3]] = 3.0; + assert_eq!(data.as_index_view::()[&[1, 2, 3]], 3.0); + } + #[test] fn into_vec_should_yield_same_value_as_iter() { let shape = Shape::new([3, 5, 6]); diff --git a/crates/burn-tensor/src/tensor/indexing/mod.rs b/crates/burn-tensor/src/tensor/indexing/mod.rs index 31e956215a..ef50f794cb 100644 --- a/crates/burn-tensor/src/tensor/indexing/mod.rs +++ b/crates/burn-tensor/src/tensor/indexing/mod.rs @@ -211,9 +211,52 @@ where } } +/// Compute the ravel index for the given coordinates. +/// +/// This returns the row-major order raveling. +/// +/// # Arguments +/// - `coords`: must be the same size len as `dims`. +/// - `dims`: must be the same len as `coords`. +/// +/// # Returns +/// - the ravel offset index. +pub fn ravel_dims(coords: &[I], dims: &[usize]) -> usize { + assert_eq!( + dims.len(), + coords.len(), + "Coordinate rank mismatch: expected {}, got {}", + dims.len(), + coords.len(), + ); + + let mut ravel_idx = 0; + let mut stride = 1; + + for (i, &dim) in dims.iter().enumerate().rev() { + let coord = canonicalize_index(coords[i], dim, false); + ravel_idx += coord * stride; + stride *= dim; + } + + ravel_idx +} + #[cfg(test)] mod tests { use super::*; + use alloc::vec; + + #[test] + fn test_ravel() { + let shape = vec![2, 3, 4, 5]; + + assert_eq!(ravel_dims(&[0, 0, 0, 0], &shape), 0); + assert_eq!( + ravel_dims(&[1, 2, 3, 4], &shape), + 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4 + ); + } #[test] fn test_wrap_idx() { diff --git a/crates/burn-tensor/src/tensor/shape.rs b/crates/burn-tensor/src/tensor/shape.rs index c1dbfcde95..4108ffed2c 100644 --- a/crates/burn-tensor/src/tensor/shape.rs +++ b/crates/burn-tensor/src/tensor/shape.rs @@ -1,4 +1,5 @@ -use crate::{Slice, SliceArg}; +use crate::indexing::ravel_dims; +use crate::{AsIndex, Slice, SliceArg}; use alloc::vec::Vec; use core::{ ops::{Deref, DerefMut, Index, IndexMut, Range}, @@ -75,6 +76,19 @@ impl Shape { self } + /// Compute the ravel index for the given coordinates. + /// + /// This returns the row-major order raveling. + /// + /// # Arguments + /// - `coords`: must be the same size as `self.rank()`. + /// + /// # Returns + /// - the ravel offset index. + pub fn ravel(&self, coords: &[I]) -> usize { + ravel_dims(coords, &self.dims) + } + /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. pub fn into_ranges(self) -> Vec> { self.into_iter().map(|d| 0..d).collect() @@ -630,6 +644,17 @@ mod tests { assert_eq!(&shape.dims, &[120]); } + #[test] + fn test_ravel() { + let shape = Shape::new([2, 3, 4, 5]); + + assert_eq!(shape.ravel(&[0, 0, 0, 0]), 0); + assert_eq!( + shape.ravel(&[1, 2, 3, 4]), + 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4 + ); + } + #[test] fn test_shape_insert_remove() { let dims = [2, 3, 4, 5];