diff --git a/Cargo.toml b/Cargo.toml index 90bc83e46f..2b4274a106 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,6 +134,7 @@ cairo-lang-sierra-type-size.workspace = true cairo-lang-utils.workspace = true educe.workspace = true itertools.workspace = true +lambdaworks-math.workspace = true lazy_static.workspace = true libc.workspace = true libloading.workspace = true diff --git a/benches/benches.rs b/benches/benches.rs index a779efc342..ed4656b66f 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -88,6 +88,7 @@ fn criterion_benchmark(c: &mut Criterion) { compare(c, "programs/benches/dict_snapshot.cairo"); compare(c, "programs/benches/dict_insert.cairo"); compare(c, "programs/benches/factorial_2M.cairo"); + compare(c, "programs/benches/factorial_2M_inv.cairo"); compare(c, "programs/benches/fib_2M.cairo"); compare(c, "programs/benches/linear_search.cairo"); compare(c, "programs/benches/logistic_map.cairo"); diff --git a/programs/benches/factorial_2M_inv.cairo b/programs/benches/factorial_2M_inv.cairo new file mode 100644 index 0000000000..187597f9bd --- /dev/null +++ b/programs/benches/factorial_2M_inv.cairo @@ -0,0 +1,15 @@ +fn factorial_inv(value: felt252, n: felt252) -> felt252 { + if (n == 1) { + value + } else { + factorial_inv(felt252_div(value, n.try_into().unwrap()), n - 1) + } +} + +fn main() { + let result = factorial_inv(0x4d6e41de886ac83938da3456ccf1481182687989ead34d9d35236f0864575a0, 2_000_000); + assert( + result == 1, + 'invalid result' + ); +} diff --git a/src/arch/aarch64.rs b/src/arch/aarch64.rs index 556de0c584..7279871e80 100644 --- a/src/arch/aarch64.rs +++ b/src/arch/aarch64.rs @@ -10,7 +10,11 @@ #![cfg(target_arch = "aarch64")] use super::AbiArgument; -use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use crate::{ + error::Error, + starknet::U256, + utils::{get_integer_layout, montgomery::MontyBytes}, +}; use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; @@ 
-197,7 +201,8 @@ impl AbiArgument for Felt { if buffer.len() >= 56 { align_to(buffer, get_integer_layout(252).align()); } - buffer.extend_from_slice(&self.to_bytes_le()); + + buffer.extend_from_slice(&self.to_monty_bytes_le()); Ok(()) } } @@ -489,8 +494,10 @@ mod test { buffer, [0; 40] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); @@ -504,8 +511,10 @@ mod test { buffer, [0; 48] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); @@ -519,8 +528,10 @@ mod test { buffer, [0; 64] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); } diff --git a/src/arch/x86_64.rs b/src/arch/x86_64.rs index 59e274affd..f08a1b6941 100644 --- a/src/arch/x86_64.rs +++ b/src/arch/x86_64.rs @@ -10,7 +10,11 @@ #![cfg(target_arch = "x86_64")] use super::AbiArgument; -use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use crate::{ + error::Error, + starknet::U256, + utils::{get_integer_layout, montgomery::MontyBytes}, +}; use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; @@ -159,7 +163,7 @@ impl AbiArgument for Felt { align_to(buffer, get_integer_layout(252).align()); } - buffer.extend_from_slice(&self.to_bytes_le()); + buffer.extend_from_slice(&self.to_monty_bytes_le()); Ok(()) } } @@ -240,8 +244,10 @@ mod test { buffer, [0; 24] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 
184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); @@ -255,8 +261,10 @@ mod test { buffer, [0; 32] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); @@ -270,8 +278,10 @@ mod test { buffer, [0; 48] .into_iter() - .chain([0xFF; 31]) - .chain([0x00]) + .chain([ + 37, 0, 126, 115, 253, 255, 255, 255, 255, 15, 51, 1, 0, 0, 0, 0, 128, 111, 255, + 255, 255, 255, 255, 255, 184, 2, 94, 171, 212, 255, 255, 7 + ]) .collect::>() ); } diff --git a/src/executor.rs b/src/executor.rs index 55ee41eddd..b29db67c7f 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -454,11 +454,17 @@ fn parse_result( #[cfg(target_arch = "aarch64")] Ok(Value::Felt252({ + use crate::utils::montgomery; + let data = unsafe { std::mem::transmute::<&mut [u64; 4], &mut [u8; 32]>(&mut ret_registers) }; data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - starknet_types_core::felt::Felt::from_bytes_le(data) + + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + montgomery::felt_from_monty_bytes(&data) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")? 
})) } }, diff --git a/src/executor/contract.rs b/src/executor/contract.rs index 8cec669b05..e7730e2403 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -51,7 +51,9 @@ use crate::{ types::TypeBuilder, utils::{ decode_error_message, generate_function_name, get_integer_layout, get_types_total_size, - libc_free, libc_malloc, BuiltinCosts, + libc_free, libc_malloc, + montgomery::{self, MontyBytes}, + BuiltinCosts, }, OptLevel, }; @@ -501,7 +503,8 @@ impl AotContractExecutor { }; for (idx, elem) in args.iter().enumerate() { - let f = elem.to_bytes_le(); + let f = elem.to_monty_bytes_le(); + unsafe { std::ptr::copy_nonoverlapping( f.as_ptr().cast::(), @@ -667,9 +670,15 @@ impl AotContractExecutor { let cur_elem_ptr = unsafe { array_ptr.byte_add(elem_stride * i as usize) }; let mut data = unsafe { cur_elem_ptr.cast::<[u8; 32]>().read() }; + data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - array_value.push(Felt::from_bytes_le(&data)); + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + array_value.push( + montgomery::felt_from_monty_bytes(&data) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?, + ); } unsafe { diff --git a/src/libfuncs/bool.rs b/src/libfuncs/bool.rs index afa1c490b4..d536b43098 100644 --- a/src/libfuncs/bool.rs +++ b/src/libfuncs/bool.rs @@ -5,7 +5,7 @@ use crate::{ error::{panic::ToNativeAssertError, Result}, metadata::MetadataStorage, types::TypeBuilder, - utils::ProgramRegistryExt, + utils::{montgomery, ProgramRegistryExt}, }; use cairo_lang_sierra::{ extensions::{ @@ -200,9 +200,10 @@ pub fn build_bool_to_felt252<'ctx, 'this>( let value = entry.arg(0)?; let tag_value = entry.extract_value(context, location, value, tag_ty, 0)?; - let result = entry.extui(tag_value, felt252_ty, location)?; + // Convert into Montgomery representation. 
+ let felt = montgomery::mlir::monty_transform(context, entry, tag_value, felt252_ty, location)?; - helper.br(entry, 0, &[result], location) + helper.br(entry, 0, &[felt], location) } #[cfg(test)] diff --git a/src/libfuncs/bounded_int.rs b/src/libfuncs/bounded_int.rs index 6a2d8e3837..acc4b40471 100644 --- a/src/libfuncs/bounded_int.rs +++ b/src/libfuncs/bounded_int.rs @@ -7,7 +7,7 @@ use crate::{ metadata::MetadataStorage, native_assert, types::TypeBuilder, - utils::RangeExt, + utils::{montgomery, RangeExt}, }; use cairo_lang_sierra::{ extensions::{ @@ -149,6 +149,17 @@ fn build_add<'ctx, 'this>( rhs_value }; + let lhs_value = if lhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, lhs_value, lhs_value.r#type(), location)? + } else { + lhs_value + }; + let rhs_value = if rhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, rhs_value, rhs_value.r#type(), location)? + } else { + rhs_value + }; + // Addition and get the result value on the desired range let res_value = entry.addi(lhs_value, rhs_value, location)?; let res_value = if compute_width > dst_width { @@ -248,6 +259,17 @@ fn build_sub<'ctx, 'this>( rhs_value }; + let lhs_value = if lhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, lhs_value, lhs_value.r#type(), location)? + } else { + lhs_value + }; + let rhs_value = if rhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, rhs_value, rhs_value.r#type(), location)? + } else { + rhs_value + }; + let compile_time_val = entry.const_int_from_type(context, location, compile_time_val, compute_ty)?; // First we do -> intermediate_res = Ad - Bd @@ -357,6 +379,8 @@ fn build_mul<'ctx, 'this>( let lhs_offset = entry.const_int_from_type(context, location, lhs_range.lower, compute_ty)?; entry.addi(lhs_value, lhs_offset, location)? + } else if lhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, lhs_value, lhs_value.r#type(), location)? 
} else { lhs_value }; @@ -364,6 +388,8 @@ fn build_mul<'ctx, 'this>( let rhs_offset = entry.const_int_from_type(context, location, rhs_range.lower, compute_ty)?; entry.addi(rhs_value, rhs_offset, location)? + } else if rhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, rhs_value, rhs_value.r#type(), location)? } else { rhs_value }; @@ -489,6 +515,8 @@ fn build_div_rem<'ctx, 'this>( let lhs_offset = entry.const_int_from_type(context, location, lhs_range.lower, compute_ty)?; entry.addi(lhs_value, lhs_offset, location)? + } else if lhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, lhs_value, lhs_value.r#type(), location)? } else { lhs_value }; @@ -496,6 +524,8 @@ fn build_div_rem<'ctx, 'this>( let rhs_offset = entry.const_int_from_type(context, location, rhs_range.lower, compute_ty)?; entry.addi(rhs_value, rhs_offset, location)? + } else if rhs_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce(context, entry, rhs_value, rhs_value.r#type(), location)? } else { rhs_value }; @@ -578,7 +608,6 @@ fn build_constrain<'ctx, 'this>( info: &BoundedIntConstrainConcreteLibfunc, ) -> Result<()> { let range_check = super::increment_builtin_counter(context, entry, location, entry.arg(0)?)?; - let src_value: Value = entry.arg(1)?; let src_ty = registry.get_type(&info.param_signatures()[1].ty)?; let src_range = src_ty.integer_range(registry)?; @@ -589,6 +618,18 @@ fn build_constrain<'ctx, 'this>( src_range.zero_based_bit_width() }; + let src_value: Value = if src_ty.is_felt252(registry)? { + montgomery::mlir::monty_reduce( + context, + entry, + entry.arg(1)?, + IntegerType::new(context, 252).into(), + location, + )? + } else { + entry.arg(1)? + }; + let lower_range = registry .get_type(&info.branch_signatures()[0].vars[1].ty)? .integer_range(registry)?; @@ -783,6 +824,7 @@ fn build_is_zero<'ctx, 'this>( } else { entry.const_int_from_type(context, location, 0, src_value.r#type())? 
}; + let src_is_zero = entry.cmpi(context, CmpiPredicate::Eq, src_value, k0, location)?; helper.cond_br( diff --git a/src/libfuncs/bytes31.rs b/src/libfuncs/bytes31.rs index 82a71ac76e..948d2e14bd 100644 --- a/src/libfuncs/bytes31.rs +++ b/src/libfuncs/bytes31.rs @@ -4,7 +4,7 @@ use super::LibfuncHelper; use crate::{ error::{Error, Result}, metadata::MetadataStorage, - utils::ProgramRegistryExt, + utils::{montgomery, ProgramRegistryExt}, }; use cairo_lang_sierra::{ extensions::{ @@ -22,7 +22,7 @@ use melior::{ cf, }, helpers::{ArithBlockExt, BuiltinBlockExt}, - ir::{Attribute, Block, BlockLike, Location, Value}, + ir::{r#type::IntegerType, Attribute, Block, BlockLike, Location, Value}, Context, }; use num_bigint::BigUint; @@ -96,7 +96,11 @@ pub fn build_to_felt252<'ctx, 'this>( )?; let value: Value = entry.arg(0)?; - let result = entry.extui(value, felt252_ty, location)?; + let result = { + let result = entry.extui(value, felt252_ty, location)?; + + montgomery::mlir::monty_transform(context, entry, result, felt252_ty, location)? 
+ }; helper.br(entry, 0, &[result], location) } @@ -116,7 +120,13 @@ pub fn build_from_felt252<'ctx, 'this>( let range_check: Value = super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 3)?; - let value: Value = entry.arg(1)?; + let value: Value = montgomery::mlir::monty_reduce( + context, + entry, + entry.arg(1)?, + IntegerType::new(context, 252).into(), + location, + )?; let felt252_ty = registry.build_type(context, helper, metadata, &info.param_signatures()[1].ty)?; diff --git a/src/libfuncs/cast.rs b/src/libfuncs/cast.rs index 0171edaf0e..fc0e93d088 100644 --- a/src/libfuncs/cast.rs +++ b/src/libfuncs/cast.rs @@ -7,7 +7,10 @@ use crate::{ metadata::MetadataStorage, native_assert, native_panic, types::TypeBuilder, - utils::{RangeExt, HALF_PRIME, PRIME}, + utils::{ + montgomery::{self}, + RangeExt, HALF_PRIME, PRIME, + }, }; use cairo_lang_sierra::{ extensions::{ @@ -152,26 +155,49 @@ pub fn build_downcast<'ctx, 'this>( // [-P/2, P/2]. // 2. if it is a bounded_int, we need to offset the value to get the // actual value. - let src_value = if is_signed && src_ty.is_felt252(registry)? { - if src_range.upper.is_one() { - let adj_offset = - entry.const_int_from_type(context, location, PRIME.clone(), src_value.r#type())?; - entry.append_op_result(arith::subi(src_value, adj_offset, location))? + let src_value = if src_ty.is_felt252(registry)? { + let felt252_ty = IntegerType::new(context, 252).into(); + + let src_value = + montgomery::mlir::monty_reduce(context, entry, src_value, felt252_ty, location)?; + + if is_signed { + if src_range.upper.is_one() { + let adj_offset = entry.const_int_from_type( + context, + location, + PRIME.clone(), + src_value.r#type(), + )?; + entry.append_op_result(arith::subi(src_value, adj_offset, location))? 
+ } else { + let adj_offset = entry.const_int_from_type( + context, + location, + HALF_PRIME.clone(), + src_value.r#type(), + )?; + let is_negative = + entry.cmpi(context, CmpiPredicate::Ugt, src_value, adj_offset, location)?; + + let k_prime = entry.const_int_from_type( + context, + location, + PRIME.clone(), + src_value.r#type(), + )?; + let adj_value = + entry.append_op_result(arith::subi(src_value, k_prime, location))?; + + entry.append_op_result(arith::select( + is_negative, + adj_value, + src_value, + location, + ))? + } } else { - let adj_offset = entry.const_int_from_type( - context, - location, - HALF_PRIME.clone(), - src_value.r#type(), - )?; - let is_negative = - entry.cmpi(context, CmpiPredicate::Ugt, src_value, adj_offset, location)?; - - let k_prime = - entry.const_int_from_type(context, location, PRIME.clone(), src_value.r#type())?; - let adj_value = entry.append_op_result(arith::subi(src_value, k_prime, location))?; - - entry.append_op_result(arith::select(is_negative, adj_value, src_value, location))? + src_value } } else if src_ty.is_bounded_int(registry)? && src_range.lower != BigInt::ZERO { let dst_offset = entry.const_int_from_type( @@ -470,14 +496,22 @@ pub fn build_upcast<'ctx, 'this>( // When converting to a felt from a signed integer, we need to convert // the canonical signed integer representation, to the signed felt // representation: `negative = P - absolute`. - let dst_value = if dst_ty.is_felt252(registry)? && src_range.lower.sign() == Sign::Minus { - let k0 = entry.const_int(context, location, 0, 252)?; - let is_negative = entry.cmpi(context, CmpiPredicate::Slt, dst_value, k0, location)?; + let dst_value = if dst_ty.is_felt252(registry)? 
{ + let felt252_ty = IntegerType::new(context, 252).into(); + + let value = if src_range.lower.sign() == Sign::Minus { + let k0 = entry.const_int(context, location, 0, 252)?; + let is_negative = entry.cmpi(context, CmpiPredicate::Slt, dst_value, k0, location)?; + + let k_prime = entry.const_int(context, location, PRIME.clone(), 252)?; + let adj_value = entry.addi(dst_value, k_prime, location)?; - let k_prime = entry.const_int(context, location, PRIME.clone(), 252)?; - let adj_value = entry.addi(dst_value, k_prime, location)?; + entry.append_op_result(arith::select(is_negative, adj_value, dst_value, location))? + } else { + dst_value + }; - entry.append_op_result(arith::select(is_negative, adj_value, dst_value, location))? + montgomery::mlir::monty_transform(context, entry, value, felt252_ty, location)? } else { dst_value }; diff --git a/src/libfuncs/const.rs b/src/libfuncs/const.rs index d6c52c3e84..6652f1ca9b 100644 --- a/src/libfuncs/const.rs +++ b/src/libfuncs/const.rs @@ -2,12 +2,12 @@ use super::LibfuncHelper; use crate::{ - error::{Error, Result}, + error::{panic::ToNativeAssertError, Error, Result}, libfuncs::{r#enum::build_enum_value, r#struct::build_struct_value}, metadata::{realloc_bindings::ReallocBindingsMeta, MetadataStorage}, native_panic, types::TypeBuilder, - utils::{ProgramRegistryExt, RangeExt, PRIME}, + utils::{montgomery::monty_transform, ProgramRegistryExt, RangeExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -265,8 +265,10 @@ pub fn build_const_type_value<'ctx, 'this>( Sign::Minus => PRIME.clone() - value, _ => value, }; - - Ok(entry.const_int_from_type(context, location, value, inner_ty)?) + let monty_value = monty_transform(&value, &PRIME).to_native_assert_error(&format!( + "could not transform felt252: {value} to Montgomery form" + ))?; + Ok(entry.const_int_from_type(context, location, monty_value, inner_ty)?) 
} CoreTypeConcrete::Starknet( StarknetTypeConcrete::ClassHash(_) | StarknetTypeConcrete::ContractAddress(_), @@ -281,8 +283,10 @@ pub fn build_const_type_value<'ctx, 'this>( Sign::Minus => PRIME.clone() - value, _ => value, }; - - Ok(entry.const_int_from_type(context, location, value, inner_ty)?) + let monty_value = monty_transform(&value, &PRIME).to_native_assert_error(&format!( + "could not transform felt252: {value} to Montgomery form" + ))?; + Ok(entry.const_int_from_type(context, location, monty_value, inner_ty)?) } CoreTypeConcrete::Uint8(_) | CoreTypeConcrete::Uint16(_) diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index af79646f12..53da58cacd 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -2,9 +2,9 @@ use super::LibfuncHelper; use crate::{ - error::Result, + error::{panic::ToNativeAssertError, Result}, metadata::MetadataStorage, - utils::{ProgramRegistryExt, PRIME}, + utils::{montgomery, ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -19,12 +19,9 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::{ - arith::{self, CmpiPredicate}, - cf, - }, + dialect::arith::{self, CmpiPredicate}, helpers::{ArithBlockExt, BuiltinBlockExt}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value, ValueLike}, + ir::{r#type::IntegerType, Block, Location, Value, ValueLike}, Context, }; use num_bigint::{BigInt, Sign}; @@ -73,7 +70,6 @@ pub fn build_binary_operation<'ctx, 'this>( &info.branch_signatures()[0].vars[0].ty, )?; let i256 = IntegerType::new(context, 256).into(); - let i512 = IntegerType::new(context, 512).into(); let (op, lhs, rhs) = match info { Felt252BinaryOperationConcrete::WithVar(operation) => { @@ -86,9 +82,12 @@ pub fn build_binary_operation<'ctx, 'this>( .clone(), _ => operation.c.magnitude().clone(), }; + let monty_value = montgomery::monty_transform(&value, &PRIME).to_native_assert_error( + &format!("could not transform felt252: {value} to 
Montgomery form"), + )?; // TODO: Ensure that the constant is on the correct side of the operation. - let rhs = entry.const_int_from_type(context, location, value, felt252_ty)?; + let rhs = entry.const_int_from_type(context, location, monty_value, felt252_ty)?; (operation.operator, entry.arg(0)?, rhs) } @@ -101,7 +100,7 @@ pub fn build_binary_operation<'ctx, 'this>( let result = entry.addi(lhs, rhs, location)?; let prime = entry.const_int_from_type(context, location, PRIME.clone(), i256)?; - let result_mod = entry.append_op_result(arith::subi(result, prime, location))?; + let result_mod = entry.subi(result, prime, location)?; let is_out_of_range = entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; @@ -111,12 +110,13 @@ pub fn build_binary_operation<'ctx, 'this>( result, location, ))?; + entry.trunci(result, felt252_ty, location)? } Felt252BinaryOperator::Sub => { let lhs = entry.extui(lhs, i256, location)?; let rhs = entry.extui(rhs, i256, location)?; - let result = entry.append_op_result(arith::subi(lhs, rhs, location))?; + let result = entry.subi(lhs, rhs, location)?; let prime = entry.const_int_from_type(context, location, PRIME.clone(), i256)?; let result_mod = entry.addi(result, prime, location)?; @@ -131,148 +131,10 @@ pub fn build_binary_operation<'ctx, 'this>( entry.trunci(result, felt252_ty, location)? } Felt252BinaryOperator::Mul => { - let lhs = entry.extui(lhs, i512, location)?; - let rhs = entry.extui(rhs, i512, location)?; - let result = entry.muli(lhs, rhs, location)?; - - let prime = entry.const_int_from_type(context, location, PRIME.clone(), i512)?; - let result_mod = entry.append_op_result(arith::remui(result, prime, location))?; - let is_out_of_range = - entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; - - let result = entry.append_op_result(arith::select( - is_out_of_range, - result_mod, - result, - location, - ))?; - entry.trunci(result, felt252_ty, location)? 
+ montgomery::mlir::monty_mul(context, entry, lhs, rhs, felt252_ty, location)? } Felt252BinaryOperator::Div => { - // The extended euclidean algorithm calculates the greatest common divisor of two integers, - // as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b) - // We use this in felt division to find the modular inverse of a given number - // If a is the number we're trying to find the inverse of, we can do - // ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME) - // Hence for input a, we return x - // The input MUST be non-zero - // See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm - let start_block = helper.append_block(Block::new(&[(i512, location)])); - let loop_block = helper.append_block(Block::new(&[ - (i512, location), - (i512, location), - (i512, location), - (i512, location), - ])); - let negative_check_block = helper.append_block(Block::new(&[])); - // Block containing final result - let inverse_result_block = helper.append_block(Block::new(&[(i512, location)])); - // Egcd works by calculating a series of remainders, each the remainder of dividing the previous two - // For the initial setup, r0 = PRIME, r1 = a - // This order is chosen because if we reverse them, then the first iteration will just swap them - let prev_remainder = - start_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let remainder = start_block.arg(0)?; - // Similarly we'll calculate another series which starts 0,1,... 
and from which we will retrieve the modular inverse of a - let prev_inverse = start_block.const_int_from_type(context, location, 0, i512)?; - let inverse = start_block.const_int_from_type(context, location, 1, i512)?; - start_block.append_operation(cf::br( - loop_block, - &[prev_remainder, remainder, prev_inverse, inverse], - location, - )); - - //---Loop body--- - // Arguments are rem_(i-1), rem, inv_(i-1), inv - let prev_remainder = loop_block.arg(0)?; - let remainder = loop_block.arg(1)?; - let prev_inverse = loop_block.arg(2)?; - let inverse = loop_block.arg(3)?; - - // First calculate q = rem_(i-1)/rem_i, rounded down - let quotient = - loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?; - // Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i - let rem_times_quo = loop_block.muli(remainder, quotient, location)?; - let inv_times_quo = loop_block.muli(inverse, quotient, location)?; - let next_remainder = loop_block.append_op_result(arith::subi( - prev_remainder, - rem_times_quo, - location, - ))?; - let next_inverse = - loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?; - - // If r_(i+1) is 0, then inv_i is the inverse - let zero = loop_block.const_int_from_type(context, location, 0, i512)?; - let next_remainder_eq_zero = - loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?; - loop_block.append_operation(cf::cond_br( - context, - next_remainder_eq_zero, - negative_check_block, - loop_block, - &[], - &[remainder, next_remainder, inverse, next_inverse], - location, - )); - - // egcd sometimes returns a negative number for the inverse, - // in such cases we must simply wrap it around back into [0, PRIME) - // this suffices because |inv_i| <= divfloor(PRIME,2) - let zero = negative_check_block.const_int_from_type(context, location, 0, i512)?; - - let is_negative = negative_check_block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Slt, - inverse, 
- zero, - location, - )) - .result(0)? - .into(); - // if the inverse is < 0, add PRIME - let prime = - negative_check_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let wrapped_inverse = negative_check_block.addi(inverse, prime, location)?; - let inverse = negative_check_block.append_op_result(arith::select( - is_negative, - wrapped_inverse, - inverse, - location, - ))?; - negative_check_block.append_operation(cf::br( - inverse_result_block, - &[inverse], - location, - )); - - // Div Logic Start - // Fetch operands - let lhs = entry.extui(lhs, i512, location)?; - let rhs = entry.extui(rhs, i512, location)?; - // Calculate inverse of rhs, callling the inverse implementation's starting block - entry.append_operation(cf::br(start_block, &[rhs], location)); - // Fetch the inverse result from the result block - let inverse = inverse_result_block.arg(0)?; - // Peform lhs * (1/ rhs) - let result = inverse_result_block.muli(lhs, inverse, location)?; - // Apply modulo and convert result to felt252 - let result_mod = - inverse_result_block.append_op_result(arith::remui(result, prime, location))?; - let is_out_of_range = - inverse_result_block.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; - - let result = inverse_result_block.append_op_result(arith::select( - is_out_of_range, - result_mod, - result, - location, - ))?; - let result = inverse_result_block.trunci(result, felt252_ty, location)?; - - return helper.br(inverse_result_block, 0, &[result], location); + montgomery::mlir::monty_div(context, entry, lhs, rhs, felt252_ty, location)? 
} }; @@ -303,7 +165,10 @@ pub fn build_const<'ctx, 'this>( &info.branch_signatures()[0].vars[0].ty, )?; - let value = entry.const_int_from_type(context, location, value, felt252_ty)?; + let monty_value = montgomery::monty_transform(&value, &PRIME).to_native_assert_error( + &format!("could not transform felt252: {value} to Montgomery form"), + )?; + let value = entry.const_int_from_type(context, location, monty_value, felt252_ty)?; helper.br(entry, 0, &[value], location) } diff --git a/src/libfuncs/int.rs b/src/libfuncs/int.rs index c8f1e5e478..b8bb306db5 100644 --- a/src/libfuncs/int.rs +++ b/src/libfuncs/int.rs @@ -6,7 +6,7 @@ use crate::{ metadata::MetadataStorage, native_panic, types::TypeBuilder, - utils::{ProgramRegistryExt, PRIME}, + utils::{montgomery, ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -387,7 +387,6 @@ fn build_from_felt252<'ctx, 'this>( let value_ty = registry.get_type(&info.signature.branch_signatures[0].vars[1].ty)?; let threshold = value_ty.integer_range(registry)?; let threshold_size = threshold.size(); - let value_ty = value_ty.build( context, helper, @@ -396,7 +395,10 @@ fn build_from_felt252<'ctx, 'this>( &info.signature.branch_signatures[0].vars[1].ty, )?; + // We casting from a felt, so we need convert it from Montgomery back to + // its original representation. let input = entry.arg(1)?; + let input = montgomery::mlir::monty_reduce(context, entry, input, input.r#type(), location)?; // Handle signedness separately. let (is_in_range, value) = if threshold.lower.is_zero() { @@ -894,9 +896,12 @@ fn build_to_felt252<'ctx, 'this>( entry.append_op_result(arith::select(is_negative, neg_value, value, location))? } else { - entry.extui(entry.arg(0)?, felt252_ty, location)? + entry.arg(0)? }; + // We are casting to a felt, so we need convert it into Montgomery. 
+ let value = montgomery::mlir::monty_transform(context, entry, value, felt252_ty, location)?; + helper.br(entry, 0, &[value], location) } @@ -911,10 +916,15 @@ fn build_u128s_from_felt252<'ctx, 'this>( ) -> Result<()> { let target_ty = IntegerType::new(context, 128).into(); - let lo = entry.trunci(entry.arg(1)?, target_ty, location)?; + // We casting from a felt, so we need convert it from Montgomery back to + // its original representation. + let felt = entry.arg(1)?; + let lo = montgomery::mlir::monty_reduce(context, entry, felt, felt.r#type(), location)?; + + let k128 = entry.const_int_from_type(context, location, 128, felt.r#type())?; + let hi = entry.shrui(lo, k128, location)?; - let k128 = entry.const_int_from_type(context, location, 128, entry.arg(1)?.r#type())?; - let hi = entry.shrui(entry.arg(1)?, k128, location)?; + let lo = entry.trunci(lo, target_ty, location)?; let hi = entry.trunci(hi, target_ty, location)?; let k0 = entry.const_int_from_type(context, location, 0, target_ty)?; diff --git a/src/libfuncs/starknet.rs b/src/libfuncs/starknet.rs index 4c1248b36b..dabaf373b3 100644 --- a/src/libfuncs/starknet.rs +++ b/src/libfuncs/starknet.rs @@ -2,11 +2,11 @@ use super::LibfuncHelper; use crate::{ - error::{Error, Result}, + error::{panic::ToNativeAssertError, Error, Result}, ffi::get_struct_field_type_at, metadata::{drop_overrides::DropOverridesMeta, MetadataStorage}, starknet::handler::StarknetSyscallHandlerCallbacks, - utils::{get_integer_layout, ProgramRegistryExt, PRIME}, + utils::{get_integer_layout, montgomery, ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -378,10 +378,14 @@ pub fn build_class_hash_const<'ctx, 'this>( let value = entry.const_int( context, location, - match info.c.sign() { - Sign::Minus => &*PRIME - info.c.magnitude(), - _ => info.c.magnitude().clone(), - }, + montgomery::monty_transform( + &match info.c.sign() { + Sign::Minus => &*PRIME - info.c.magnitude(), + _ => info.c.magnitude().clone(), + }, + 
&PRIME, + ) + .to_native_assert_error("couldn't transform Felt into Montgomery space")?, 252, )?; @@ -403,6 +407,8 @@ pub fn build_class_hash_try_from_felt252<'ctx, 'this>( super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 3)?; let value = entry.arg(1)?; + let tmp_value = + montgomery::mlir::monty_reduce(context, entry, value, value.r#type(), location)?; let limit = entry.append_op_result(arith::constant( context, @@ -413,7 +419,7 @@ pub fn build_class_hash_try_from_felt252<'ctx, 'this>( .ok_or(Error::ParseAttributeError)?, location, ))?; - let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; + let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, tmp_value, limit, location)?; helper.cond_br( context, @@ -437,10 +443,14 @@ pub fn build_contract_address_const<'ctx, 'this>( let value = entry.const_int( context, location, - match info.c.sign() { - Sign::Minus => &*PRIME - info.c.magnitude(), - _ => info.c.magnitude().clone(), - }, + montgomery::monty_transform( + &match info.c.sign() { + Sign::Minus => &*PRIME - info.c.magnitude(), + _ => info.c.magnitude().clone(), + }, + &PRIME, + ) + .to_native_assert_error("couldn't transform Felt into Montgomery space")?, 252, )?; @@ -462,6 +472,8 @@ pub fn build_contract_address_try_from_felt252<'ctx, 'this>( super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 3)?; let value = entry.arg(1)?; + let tmp_value = + montgomery::mlir::monty_reduce(context, entry, value, value.r#type(), location)?; let limit = entry.append_op_result(arith::constant( context, @@ -472,7 +484,7 @@ pub fn build_contract_address_try_from_felt252<'ctx, 'this>( .ok_or(Error::ParseAttributeError)?, location, ))?; - let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; + let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, tmp_value, limit, location)?; helper.cond_br( context, @@ -819,10 +831,14 @@ pub fn 
build_storage_base_address_const<'ctx, 'this>( let value = entry.const_int( context, location, - match info.c.sign() { - Sign::Minus => &*PRIME - info.c.magnitude(), - _ => info.c.magnitude().clone(), - }, + montgomery::monty_transform( + &match info.c.sign() { + Sign::Minus => &*PRIME - info.c.magnitude(), + _ => info.c.magnitude().clone(), + }, + &PRIME, + ) + .to_native_assert_error("couldn't transform Felt into Montgomery space")?, 252, )?; @@ -843,6 +859,10 @@ pub fn build_storage_base_address_from_felt252<'ctx, 'this>( let range_check = super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 3)?; + let value = entry.arg(1)?; + let tmp_value = + montgomery::mlir::monty_reduce(context, entry, value, value.r#type(), location)?; + let k_limit = entry.append_op_result(arith::constant( context, Attribute::parse( @@ -853,19 +873,19 @@ pub fn build_storage_base_address_from_felt252<'ctx, 'this>( location, ))?; - let limited_value = entry.append_op_result(arith::subi(entry.arg(1)?, k_limit, location))?; + let limited_value = entry.append_op_result(arith::subi(tmp_value, k_limit, location))?; - let is_within_limit = entry.cmpi( - context, - CmpiPredicate::Ult, - entry.arg(1)?, - k_limit, - location, - )?; + let is_within_limit = entry.cmpi(context, CmpiPredicate::Ult, tmp_value, k_limit, location)?; let value = entry.append_op_result(arith::select( is_within_limit, - entry.arg(1)?, - limited_value, + value, + montgomery::mlir::monty_transform( + context, + entry, + limited_value, + limited_value.r#type(), + location, + )?, location, ))?; @@ -873,7 +893,7 @@ pub fn build_storage_base_address_from_felt252<'ctx, 'this>( } pub fn build_storage_address_from_base_and_offset<'ctx, 'this>( - _context: &'ctx Context, + context: &'ctx Context, _registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, @@ -881,10 +901,21 @@ pub fn build_storage_address_from_base_and_offset<'ctx, 'this>( _metadata: &mut MetadataStorage, _info: 
&SignatureOnlyConcreteLibfunc, ) -> Result<()> { - let offset = entry.extui(entry.arg(1)?, entry.argument(0)?.r#type(), location)?; - let addr = entry.addi(entry.arg(0)?, offset, location)?; + let felt252_ty = entry.arg(0)?.r#type(); - helper.br(entry, 0, &[addr], location) + let addr_base = + montgomery::mlir::monty_reduce(context, entry, entry.arg(0)?, felt252_ty, location)?; + let offset = entry.extui(entry.arg(1)?, felt252_ty, location)?; + let addr = entry.addi(addr_base, offset, location)?; + + helper.br( + entry, + 0, + &[montgomery::mlir::monty_transform( + context, entry, addr, felt252_ty, location, + )?], + location, + ) } pub fn build_storage_address_try_from_felt252<'ctx, 'this>( @@ -902,6 +933,8 @@ pub fn build_storage_address_try_from_felt252<'ctx, 'this>( super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 3)?; let value = entry.arg(1)?; + let tmp_value = + montgomery::mlir::monty_reduce(context, entry, value, value.r#type(), location)?; let limit = entry.append_op_result(arith::constant( context, @@ -912,7 +945,7 @@ pub fn build_storage_address_try_from_felt252<'ctx, 'this>( .ok_or(Error::ParseAttributeError)?, location, ))?; - let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; + let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, tmp_value, limit, location)?; helper.cond_br( context, diff --git a/src/metadata/trace_dump.rs b/src/metadata/trace_dump.rs index a3fb3833be..74f3403efb 100644 --- a/src/metadata/trace_dump.rs +++ b/src/metadata/trace_dump.rs @@ -228,7 +228,7 @@ pub mod trace_dump_runtime { use crate::{ starknet::ArrayAbi, types::TypeBuilder, - utils::{get_integer_layout, layout_repeat}, + utils::{get_integer_layout, layout_repeat, montgomery}, }; use crate::runtime::FeltDict; @@ -308,7 +308,8 @@ pub mod trace_dump_runtime { | CoreTypeConcrete::Starknet(StarknetTypeConcrete::ClassHash(_)) | CoreTypeConcrete::Starknet(StarknetTypeConcrete::StorageAddress(_)) | 
CoreTypeConcrete::Starknet(StarknetTypeConcrete::StorageBaseAddress(_)) => { - Value::Felt(Felt::from_bytes_le(value_ptr.cast().as_ref())) + let value = montgomery::felt_from_monty_bytes(value_ptr.cast().as_ref()).unwrap(); + Value::Felt(value) } CoreTypeConcrete::Uint8(_) => Value::U8(value_ptr.cast().read()), CoreTypeConcrete::Uint16(_) => Value::U16(value_ptr.cast().read()), @@ -340,14 +341,20 @@ pub mod trace_dump_runtime { let (x, layout) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; let (y, _) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; @@ -359,28 +366,40 @@ pub mod trace_dump_runtime { let (x0, layout) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; let (y0, layout) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; let (x1, layout) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; let (y1, _) = { let (layout, offset) = layout.extend(Layout::new::<[u128; 2]>()).unwrap(); ( - 
Felt::from_bytes_le(value_ptr.byte_add(offset).cast().as_ref()), + montgomery::felt_from_monty_bytes( + value_ptr.byte_add(offset).cast().as_ref(), + ) + .unwrap(), layout, ) }; @@ -688,11 +707,11 @@ pub mod trace_dump_runtime { ]) } }, - CoreTypeConcrete::Const(_) => todo!("CoreTypeConcrete::Const"), + CoreTypeConcrete::Const(info) => value_from_ptr(registry, &info.inner_ty, value_ptr), CoreTypeConcrete::Sint8(_) => Value::I8(value_ptr.cast().read()), - CoreTypeConcrete::Sint16(_) => todo!("CoreTypeConcrete::Sint16"), + CoreTypeConcrete::Sint16(_) => Value::I16(value_ptr.cast().read()), CoreTypeConcrete::Sint32(_) => Value::I32(value_ptr.cast().read()), - CoreTypeConcrete::Sint64(_) => todo!("CoreTypeConcrete::Sint64"), + CoreTypeConcrete::Sint64(_) => Value::I64(value_ptr.cast().read()), CoreTypeConcrete::Sint128(_) => Value::I128(value_ptr.cast().read()), CoreTypeConcrete::Nullable(info) => { let inner_ptr = value_ptr.cast::<*mut ()>().read(); @@ -720,7 +739,7 @@ pub mod trace_dump_runtime { ty: info.ty.clone(), }, }; - let k = Felt::from_bytes_le(k); + let k = montgomery::felt_from_monty_bytes(k).unwrap(); (k, v) }) .collect::>(); @@ -749,11 +768,11 @@ pub mod trace_dump_runtime { ty: info.ty.clone(), }, }; - let k = Felt::from_bytes_le(k); + let k = montgomery::felt_from_monty_bytes(k).unwrap(); (k, v) }) .collect::>(); - let key = Felt::from_bytes_le(value.key); + let key = montgomery::felt_from_monty_bytes(value.key).unwrap(); Value::FeltDictEntry { ty: info.ty.clone(), @@ -808,7 +827,7 @@ pub mod trace_dump_runtime { data[i] = v } - Value::Bytes31(Felt::from_bytes_le(&data)) + Value::Bytes31(montgomery::felt_from_monty_bytes(&data).unwrap()) } CoreTypeConcrete::IntRange(info) => { let type_info = registry.get_type(&info.ty).unwrap(); diff --git a/src/runtime.rs b/src/runtime.rs index 249d562fad..2dca160ad8 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,6 +1,9 @@ #![allow(non_snake_case)] -use crate::utils::BuiltinCosts; +use crate::utils::{ + 
montgomery::{self, MontyBytes}, + BuiltinCosts, +}; use cairo_lang_sierra_gas::core_libfunc_cost::{ DICT_SQUASH_REPEATED_ACCESS_COST, DICT_SQUASH_UNIQUE_KEY_COST, }; @@ -61,8 +64,10 @@ pub unsafe extern "C" fn cairo_native__libfunc__debug__print( let mut data = *data.add(i); data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let value = Felt::from_bytes_le(&data); - items.push(value); + items.push( + montgomery::felt_from_monty_bytes(&data) + .expect("Couldn't create felt from Montgomery bytes"), + ); } let value = format_for_debug(items.into_iter()); @@ -98,13 +103,15 @@ pub unsafe extern "C" fn cairo_native__libfunc__pedersen( lhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). rhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - // Convert to FieldElement. - let lhs = Felt::from_bytes_le(&lhs); - let rhs = Felt::from_bytes_le(&rhs); + // Convert the operands from Montgomery form back to Felt. + let lhs = montgomery::felt_from_monty_bytes(&lhs) + .expect("Couldn't create felt from Montgomery bytes"); + let rhs = montgomery::felt_from_monty_bytes(&rhs) + .expect("Couldn't create felt from Montgomery bytes"); // Compute pedersen hash and copy the result into `dst`. let res = starknet_types_core::hash::Pedersen::hash(&lhs, &rhs); - *dst = res.to_bytes_le(); + *dst = res.to_monty_bytes_le(); } /// Compute `hades_permutation(op0, op1, op2)` and replace the operands with the results. @@ -130,18 +137,18 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation( // Convert to FieldElement. let mut state = [ - Felt::from_bytes_le(op0), - Felt::from_bytes_le(op1), - Felt::from_bytes_le(op2), + montgomery::felt_from_monty_bytes(op0).expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(op1).expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(op2).expect("Couldn't create felt from Montgomery bytes"), ]; // Compute Poseidon permutation.
starknet_types_core::hash::Poseidon::hades_permutation(&mut state); // Write back the results. - *op0 = state[0].to_bytes_le(); - *op1 = state[1].to_bytes_le(); - *op2 = state[2].to_bytes_le(); + *op0 = state[0].to_monty_bytes_le(); + *op1 = state[1].to_monty_bytes_le(); + *op2 = state[2].to_monty_bytes_le(); } /// Felt252 type used in cairo native runtime @@ -319,7 +326,12 @@ pub unsafe extern "C" fn cairo_native__dict_squash( let no_big_keys = dict .mappings .keys() - .map(Felt::from_bytes_le) + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + .map(|b| { + montgomery::felt_from_monty_bytes(b) + .expect("Couldn't create felt from Montgomery bytes") + }) .all(|key| key < Felt::from(BigInt::from(1).shl(128))); let number_of_keys = dict.mappings.len() as u64; @@ -379,7 +391,11 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( point_ptr: &mut [[u8; 32]; 2], ) -> bool { point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let x = Felt::from_bytes_le(&point_ptr[0]); + + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + let x = montgomery::felt_from_monty_bytes(&point_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"); // https://github.com/starkware-libs/cairo/blob/aaad921bba52e729dc24ece07fab2edf09ccfa15/crates/cairo-lang-sierra-to-casm/src/invocations/ec.rs#L63 @@ -395,7 +411,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( match AffinePoint::new(x, y) { Ok(point) => { - point_ptr[1] = point.y().to_bytes_le(); + point_ptr[1] = point.y().to_monty_bytes_le(); true } Err(_) => false, @@ -418,13 +434,17 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_try_new_nz( point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). 
point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let x = Felt::from_bytes_le(&point_ptr[0]); - let y = Felt::from_bytes_le(&point_ptr[1]); + // Felts are represented in Montgomery form, so we need to convert them + // back to their original representation before operating. + let x = montgomery::felt_from_monty_bytes(&point_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"); + let y = montgomery::felt_from_monty_bytes(&point_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"); match AffinePoint::new(x, y) { Ok(point) => { - point_ptr[0] = point.x().to_bytes_le(); - point_ptr[1] = point.y().to_bytes_le(); + point_ptr[0] = point.x().to_monty_bytes_le(); + point_ptr[1] = point.y().to_monty_bytes_le(); true } Err(_) => false, @@ -453,8 +473,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init(state_ptr: &mu // We already made sure its a valid point. let state = AffinePoint::new_unchecked(random_x, random_y); - state_ptr[0] = state.x().to_bytes_le(); - state_ptr[1] = state.y().to_bytes_le(); + state_ptr[0] = state.x().to_monty_bytes_le(); + state_ptr[1] = state.y().to_monty_bytes_le(); state_ptr[2] = state_ptr[0]; state_ptr[3] = state_ptr[1]; } @@ -480,21 +500,28 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add( point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - // We use unchecked methods because the inputs must already be valid points. + // Here the points should already be checked as valid, so we can use unchecked. + // + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation.
let mut state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr[0]), - Felt::from_bytes_le(&state_ptr[1]), + montgomery::felt_from_monty_bytes(&state_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&state_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"), ); let point = AffinePoint::new_unchecked( - Felt::from_bytes_le(&point_ptr[0]), - Felt::from_bytes_le(&point_ptr[1]), + montgomery::felt_from_monty_bytes(&point_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&point_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"), ); state += &point; let state = state.to_affine().unwrap(); - state_ptr[0] = state.x().to_bytes_le(); - state_ptr[1] = state.y().to_bytes_le(); + state_ptr[0] = state.x().to_monty_bytes_le(); + state_ptr[1] = state.y().to_monty_bytes_le(); } /// Compute `ec_state_add_mul(state, scalar, point)` and store the state back. @@ -523,21 +550,30 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul( scalar_ptr[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). // Here the points should already be checked as valid, so we can use unchecked. + // + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. 
let mut state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr[0]), - Felt::from_bytes_le(&state_ptr[1]), + montgomery::felt_from_monty_bytes(&state_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&state_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"), ); let point = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&point_ptr[0]), - Felt::from_bytes_le(&point_ptr[1]), + montgomery::felt_from_monty_bytes(&point_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&point_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"), ); - let scalar = Felt::from_bytes_le(&scalar_ptr); + + let scalar = montgomery::felt_from_monty_bytes(&scalar_ptr) + .expect("Couldn't create felt from Montgomery bytes"); state += &point.mul(scalar); let state = state.to_affine().unwrap(); - state_ptr[0] = state.x().to_bytes_le(); - state_ptr[1] = state.y().to_bytes_le(); + state_ptr[0] = state.x().to_monty_bytes_le(); + state_ptr[1] = state.y().to_monty_bytes_le(); } /// Compute `ec_state_try_finalize_nz(state)` and store the result. @@ -562,12 +598,16 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( // We use unchecked methods because the inputs must already be valid points. 
let state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr[0]), - Felt::from_bytes_le(&state_ptr[1]), + montgomery::felt_from_monty_bytes(&state_ptr[0]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&state_ptr[1]) + .expect("Couldn't create felt from Montgomery bytes"), ); let random = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr[2]), - Felt::from_bytes_le(&state_ptr[3]), + montgomery::felt_from_monty_bytes(&state_ptr[2]) + .expect("Couldn't create felt from Montgomery bytes"), + montgomery::felt_from_monty_bytes(&state_ptr[3]) + .expect("Couldn't create felt from Montgomery bytes"), ); if state.x() == random.x() && state.y() == random.y() { @@ -576,8 +616,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( let point = &state - &random; let point = point.to_affine().unwrap(); - point_ptr[0] = point.x().to_bytes_le(); - point_ptr[1] = point.y().to_bytes_le(); + point_ptr[0] = point.x().to_monty_bytes_le(); + point_ptr[1] = point.y().to_monty_bytes_le(); true } @@ -809,7 +849,7 @@ mod tests { { let fd = file.as_raw_fd(); let data = felt252_short_str("hello world"); - let data = data.to_bytes_le(); + let data = data.to_monty_bytes_le(); unsafe { cairo_native__libfunc__debug__print(fd, &data, 1) }; } file.seek(std::io::SeekFrom::Start(0)).unwrap(); @@ -826,15 +866,17 @@ mod tests { #[test] fn test_pederesen() { let mut dst = [0; 32]; - let lhs = Felt::from(1).to_bytes_le(); - let rhs = Felt::from(3).to_bytes_le(); + let lhs = Felt::from(1).to_monty_bytes_le(); + let rhs = Felt::from(3).to_monty_bytes_le(); unsafe { cairo_native__libfunc__pedersen(&mut dst, &lhs, &rhs); } assert_eq!( - dst, + montgomery::felt_from_monty_bytes(&dst) + .unwrap() + .to_bytes_le(), [ 84, 98, 174, 134, 3, 124, 237, 179, 166, 110, 159, 98, 170, 35, 83, 237, 130, 154, 236, 0, 205, 134, 200, 185, 39, 92, 0, 228, 132, 217, 130, 5 @@ -844,26 +886,26 @@ mod tests { 
#[test] fn test_hades_permutation() { - let mut op0 = Felt::from(1).to_bytes_le(); - let mut op1 = Felt::from(1).to_bytes_le(); - let mut op2 = Felt::from(1).to_bytes_le(); + let mut op0 = Felt::from(1).to_monty_bytes_le(); + let mut op1 = Felt::from(1).to_monty_bytes_le(); + let mut op2 = Felt::from(1).to_monty_bytes_le(); unsafe { cairo_native__libfunc__hades_permutation(&mut op0, &mut op1, &mut op2); } assert_eq!( - Felt::from_bytes_le(&op0), + montgomery::felt_from_monty_bytes(&op0).unwrap(), Felt::from_hex("0x4ebdde1149fcacbb41e4fc342432a48c97994fd045f432ad234ae9279269779") .unwrap() ); assert_eq!( - Felt::from_bytes_le(&op1), + montgomery::felt_from_monty_bytes(&op1).unwrap(), Felt::from_hex("0x7f4cec57dd08b69414f7de7dffa230fc90fa3993673c422408af05831e0cc98") .unwrap() ); assert_eq!( - Felt::from_bytes_le(&op2), + montgomery::felt_from_monty_bytes(&op2).unwrap(), Felt::from_hex("0x5b5d00fd09caade43caffe70527fa84d5d9cd51e22c2ce115693ecbb5854d6a") .unwrap() ); @@ -939,22 +981,22 @@ mod tests { "874739451078007766457464989774322083649278607533249481151382481072868806602", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), Felt::from_dec_str( "152666792071518830868575557812948353041420400780739481342941381225525861407", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), Felt::from_dec_str( "874739451078007766457464989774322083649278607533249481151382481072868806602", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), Felt::from_dec_str( "152666792071518830868575557812948353041420400780739481342941381225525861407", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), ]; let point = [ @@ -962,12 +1004,12 @@ mod tests { "874739451078007766457464989774322083649278607533249481151382481072868806602", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), Felt::from_dec_str( "152666792071518830868575557812948353041420400780739481342941381225525861407", ) .unwrap() - .to_bytes_le(), + .to_monty_bytes_le(), ]; unsafe { @@ -980,7 +1022,7 @@ mod tests { 
"3324833730090626974525872402899302150520188025637965566623476530814354734325", ) .unwrap() - .to_bytes_le() + .to_monty_bytes_le() ); assert_eq!( state[1], @@ -988,7 +1030,7 @@ mod tests { "3147007486456030910661996439995670279305852583596209647900952752170983517249", ) .unwrap() - .to_bytes_le() + .to_monty_bytes_le() ); } } diff --git a/src/starknet.rs b/src/starknet.rs index 71e3d6a27c..d8659699f4 100644 --- a/src/starknet.rs +++ b/src/starknet.rs @@ -1,8 +1,11 @@ //! Starknet related code for `cairo_native` +use lambdaworks_math::errors::ByteConversionError; use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; +use crate::utils::montgomery; + pub type SyscallResult = std::result::Result>; #[repr(C)] @@ -14,8 +17,10 @@ pub struct ArrayAbi { pub capacity: u32, } -impl From<&ArrayAbi> for Vec { - fn from(value: &ArrayAbi) -> Self { +impl TryFrom<&ArrayAbi> for Vec { + type Error = ByteConversionError; + + fn try_from(value: &ArrayAbi) -> Result { unsafe { let since_offset = value.since as usize; let until_offset = value.until as usize; @@ -27,8 +32,8 @@ impl From<&ArrayAbi> for Vec { } } .iter() - .map(Felt::from) - .collect() + .map(Felt::try_from) + .collect::>() } } @@ -37,18 +42,22 @@ impl From<&ArrayAbi> for Vec { #[repr(C, align(16))] pub struct Felt252Abi(pub [u8; 32]); -impl From for Felt { - fn from(mut value: Felt252Abi) -> Felt { +impl TryFrom for Felt { + type Error = ByteConversionError; + + fn try_from(mut value: Felt252Abi) -> Result { value.0[31] &= 0x0F; - Felt::from_bytes_le(&value.0) + montgomery::felt_from_monty_bytes(&value.0) } } -impl From<&Felt252Abi> for Felt { - fn from(value: &Felt252Abi) -> Felt { +impl TryFrom<&Felt252Abi> for Felt { + type Error = ByteConversionError; + + fn try_from(value: &Felt252Abi) -> Result { let mut value = *value; value.0[31] &= 0x0F; - Felt::from_bytes_le(&value.0) + montgomery::felt_from_monty_bytes(&value.0) } } @@ -559,7 +568,7 @@ impl StarknetSyscallHandler for DummySyscallHandler { 
// TODO: Move to the correct place or remove if unused. See: https://github.com/lambdaclass/cairo_native/issues/1222 pub(crate) mod handler { use super::*; - use crate::utils::{libc_free, libc_malloc}; + use crate::utils::{libc_free, libc_malloc, montgomery::MontyBytes}; use std::{ alloc::Layout, fmt::Debug, @@ -991,7 +1000,10 @@ pub(crate) mod handler { err: ManuallyDrop::new(SyscallResultAbiErr { tag: 1u8, payload: unsafe { - let data: Vec<_> = e.iter().map(|x| Felt252Abi(x.to_bytes_le())).collect(); + let data: Vec<_> = e + .iter() + .map(|x| Felt252Abi(x.to_monty_bytes_le())) + .collect(); Self::alloc_mlir_array(&data) }, }), @@ -1010,7 +1022,7 @@ pub(crate) mod handler { Ok(x) => SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, - payload: ManuallyDrop::new(Felt252Abi(x.to_bytes_le())), + payload: ManuallyDrop::new(Felt252Abi(x.to_monty_bytes_le())), }), }, Err(e) => Self::wrap_error(&e), @@ -1037,28 +1049,29 @@ pub(crate) mod handler { block_info_ptr.as_mut().block_number = x.block_info.block_number; block_info_ptr.as_mut().block_timestamp = x.block_info.block_timestamp; block_info_ptr.as_mut().sequencer_address = - Felt252Abi(x.block_info.sequencer_address.to_bytes_le()); + Felt252Abi(x.block_info.sequencer_address.to_monty_bytes_le()); let mut tx_info_ptr = NonNull::new(libc_malloc(size_of::()) as *mut TxInfoAbi) .unwrap(); tx_info_ptr.as_mut().version = - Felt252Abi(x.tx_info.version.to_bytes_le()); + Felt252Abi(x.tx_info.version.to_monty_bytes_le()); tx_info_ptr.as_mut().account_contract_address = - Felt252Abi(x.tx_info.account_contract_address.to_bytes_le()); + Felt252Abi(x.tx_info.account_contract_address.to_monty_bytes_le()); tx_info_ptr.as_mut().max_fee = x.tx_info.max_fee; tx_info_ptr.as_mut().signature = Self::alloc_mlir_array( &x.tx_info .signature .into_iter() - .map(|x| Felt252Abi(x.to_bytes_le())) + .map(|x| Felt252Abi(x.to_monty_bytes_le())) .collect::>(), ); tx_info_ptr.as_mut().transaction_hash = - 
Felt252Abi(x.tx_info.transaction_hash.to_bytes_le()); + Felt252Abi(x.tx_info.transaction_hash.to_monty_bytes_le()); tx_info_ptr.as_mut().chain_id = - Felt252Abi(x.tx_info.chain_id.to_bytes_le()); - tx_info_ptr.as_mut().nonce = Felt252Abi(x.tx_info.nonce.to_bytes_le()); + Felt252Abi(x.tx_info.chain_id.to_monty_bytes_le()); + tx_info_ptr.as_mut().nonce = + Felt252Abi(x.tx_info.nonce.to_monty_bytes_le()); let mut execution_info_ptr = NonNull::new(libc_malloc(size_of::()) @@ -1067,11 +1080,11 @@ pub(crate) mod handler { execution_info_ptr.as_mut().block_info = block_info_ptr; execution_info_ptr.as_mut().tx_info = tx_info_ptr; execution_info_ptr.as_mut().caller_address = - Felt252Abi(x.caller_address.to_bytes_le()); + Felt252Abi(x.caller_address.to_monty_bytes_le()); execution_info_ptr.as_mut().contract_address = - Felt252Abi(x.contract_address.to_bytes_le()); + Felt252Abi(x.contract_address.to_monty_bytes_le()); execution_info_ptr.as_mut().entry_point_selector = - Felt252Abi(x.entry_point_selector.to_bytes_le()); + Felt252Abi(x.entry_point_selector.to_monty_bytes_le()); ManuallyDrop::new(execution_info_ptr) }, @@ -1106,33 +1119,34 @@ pub(crate) mod handler { block_info_ptr.as_mut().block_number = x.block_info.block_number; block_info_ptr.as_mut().block_timestamp = x.block_info.block_timestamp; block_info_ptr.as_mut().sequencer_address = - Felt252Abi(x.block_info.sequencer_address.to_bytes_le()); + Felt252Abi(x.block_info.sequencer_address.to_monty_bytes_le()); let mut tx_info_ptr = NonNull::new( libc_malloc(size_of::()) as *mut TxInfoV2Abi, ) .unwrap(); tx_info_ptr.as_mut().version = - Felt252Abi(x.tx_info.version.to_bytes_le()); + Felt252Abi(x.tx_info.version.to_monty_bytes_le()); tx_info_ptr.as_mut().signature = Self::alloc_mlir_array( &x.tx_info .signature .into_iter() - .map(|x| Felt252Abi(x.to_bytes_le())) + .map(|x| Felt252Abi(x.to_monty_bytes_le())) .collect::>(), ); tx_info_ptr.as_mut().max_fee = x.tx_info.max_fee; tx_info_ptr.as_mut().transaction_hash = - 
Felt252Abi(x.tx_info.transaction_hash.to_bytes_le()); + Felt252Abi(x.tx_info.transaction_hash.to_monty_bytes_le()); tx_info_ptr.as_mut().chain_id = - Felt252Abi(x.tx_info.chain_id.to_bytes_le()); - tx_info_ptr.as_mut().nonce = Felt252Abi(x.tx_info.nonce.to_bytes_le()); + Felt252Abi(x.tx_info.chain_id.to_monty_bytes_le()); + tx_info_ptr.as_mut().nonce = + Felt252Abi(x.tx_info.nonce.to_monty_bytes_le()); tx_info_ptr.as_mut().resource_bounds = Self::alloc_mlir_array( &x.tx_info .resource_bounds .into_iter() .map(|x| ResourceBoundsAbi { - resource: Felt252Abi(x.resource.to_bytes_le()), + resource: Felt252Abi(x.resource.to_monty_bytes_le()), max_amount: x.max_amount, max_price_per_unit: x.max_price_per_unit, }) @@ -1143,7 +1157,7 @@ pub(crate) mod handler { &x.tx_info .paymaster_data .into_iter() - .map(|x| Felt252Abi(x.to_bytes_le())) + .map(|x| Felt252Abi(x.to_monty_bytes_le())) .collect::>(), ); tx_info_ptr.as_mut().nonce_data_availability_mode = @@ -1154,20 +1168,20 @@ pub(crate) mod handler { &x.tx_info .account_deployment_data .into_iter() - .map(|x| Felt252Abi(x.to_bytes_le())) + .map(|x| Felt252Abi(x.to_monty_bytes_le())) .collect::>(), ); tx_info_ptr.as_mut().account_contract_address = - Felt252Abi(x.tx_info.account_contract_address.to_bytes_le()); + Felt252Abi(x.tx_info.account_contract_address.to_monty_bytes_le()); execution_info_ptr.as_mut().block_info = block_info_ptr; execution_info_ptr.as_mut().tx_info = tx_info_ptr; execution_info_ptr.as_mut().caller_address = - Felt252Abi(x.caller_address.to_bytes_le()); + Felt252Abi(x.caller_address.to_monty_bytes_le()); execution_info_ptr.as_mut().contract_address = - Felt252Abi(x.contract_address.to_bytes_le()); + Felt252Abi(x.contract_address.to_monty_bytes_le()); execution_info_ptr.as_mut().entry_point_selector = - Felt252Abi(x.entry_point_selector.to_bytes_le()); + Felt252Abi(x.entry_point_selector.to_monty_bytes_le()); ManuallyDrop::new(execution_info_ptr) }, @@ -1186,10 +1200,14 @@ pub(crate) mod handler { 
calldata: &ArrayAbi, deploy_from_zero: bool, ) { - let class_hash = Felt::from(class_hash); - let contract_address_salt = Felt::from(contract_address_salt); - - let calldata_vec: Vec<_> = calldata.into(); + let class_hash = + Felt::try_from(class_hash).expect("Couldn't create felt from Montgomery bytes"); + let contract_address_salt = Felt::try_from(contract_address_salt) + .expect("Couldn't create felt from Montgomery bytes"); + + let calldata_vec: Vec<_> = calldata + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(calldata); } @@ -1204,12 +1222,18 @@ pub(crate) mod handler { *result_ptr = match result { Ok(x) => { - let felts: Vec<_> = x.1.iter().map(|x| Felt252Abi(x.to_bytes_le())).collect(); + let felts: Vec<_> = + x.1.iter() + .map(|x| Felt252Abi(x.to_monty_bytes_le())) + .collect(); let felts_ptr = unsafe { Self::alloc_mlir_array(&felts) }; SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, - payload: ManuallyDrop::new((Felt252Abi(x.0.to_bytes_le()), felts_ptr)), + payload: ManuallyDrop::new(( + Felt252Abi(x.0.to_monty_bytes_le()), + felts_ptr, + )), }), } } @@ -1223,7 +1247,8 @@ pub(crate) mod handler { gas: &mut u64, class_hash: &Felt252Abi, ) { - let class_hash = Felt::from(class_hash); + let class_hash = + Felt::try_from(class_hash).expect("Couldn't create felt from Montgomery bytes"); let result = ptr.replace_class(class_hash, gas); *result_ptr = match result { @@ -1245,10 +1270,14 @@ pub(crate) mod handler { function_selector: &Felt252Abi, calldata: &ArrayAbi, ) { - let class_hash = Felt::from(class_hash); - let function_selector = Felt::from(function_selector); - - let calldata_vec: Vec = calldata.into(); + let class_hash = + Felt::try_from(class_hash).expect("Couldn't create felt from Montgomery bytes"); + let function_selector = Felt::try_from(function_selector) + .expect("Couldn't create felt from Montgomery bytes"); + + let calldata_vec: Vec = calldata + 
.try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(calldata); } @@ -1257,7 +1286,10 @@ pub(crate) mod handler { *result_ptr = match result { Ok(x) => { - let felts: Vec<_> = x.iter().map(|x| Felt252Abi(x.to_bytes_le())).collect(); + let felts: Vec<_> = x + .iter() + .map(|x| Felt252Abi(x.to_monty_bytes_le())) + .collect(); let felts_ptr = unsafe { Self::alloc_mlir_array(&felts) }; SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { @@ -1278,10 +1310,14 @@ pub(crate) mod handler { entry_point_selector: &Felt252Abi, calldata: &ArrayAbi, ) { - let address = Felt::from(address); - let entry_point_selector = Felt::from(entry_point_selector); - - let calldata_vec: Vec = calldata.into(); + let address = + Felt::try_from(address).expect("Couldn't create felt from Montgomery bytes"); + let entry_point_selector = Felt::try_from(entry_point_selector) + .expect("Couldn't create felt from Montgomery bytes"); + + let calldata_vec: Vec = calldata + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(calldata); } @@ -1290,7 +1326,10 @@ pub(crate) mod handler { *result_ptr = match result { Ok(x) => { - let felts: Vec<_> = x.iter().map(|x| Felt252Abi(x.to_bytes_le())).collect(); + let felts: Vec<_> = x + .iter() + .map(|x| Felt252Abi(x.to_monty_bytes_le())) + .collect(); let felts_ptr = unsafe { Self::alloc_mlir_array(&felts) }; SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { @@ -1310,14 +1349,15 @@ pub(crate) mod handler { address_domain: u32, address: &Felt252Abi, ) { - let address = Felt::from(address); + let address = + Felt::try_from(address).expect("Couldn't create felt from Montgomery bytes"); let result = ptr.storage_read(address_domain, address, gas); *result_ptr = match result { Ok(res) => SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, - payload: 
ManuallyDrop::new(Felt252Abi(res.to_bytes_le())), + payload: ManuallyDrop::new(Felt252Abi(res.to_monty_bytes_le())), }), }, Err(e) => Self::wrap_error(&e), @@ -1332,8 +1372,9 @@ pub(crate) mod handler { address: &Felt252Abi, value: &Felt252Abi, ) { - let address = Felt::from(address); - let value = Felt::from(value); + let address = + Felt::try_from(address).expect("Couldn't create felt from Montgomery bytes"); + let value = Felt::try_from(value).expect("Couldn't create felt from Montgomery bytes"); let result = ptr.storage_write(address_domain, address, value, gas); *result_ptr = match result { @@ -1354,8 +1395,12 @@ pub(crate) mod handler { keys: &ArrayAbi, data: &ArrayAbi, ) { - let keys_vec: Vec<_> = keys.into(); - let data_vec: Vec<_> = data.into(); + let keys_vec: Vec<_> = keys + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); + let data_vec: Vec<_> = data + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(keys); @@ -1382,8 +1427,11 @@ pub(crate) mod handler { to_address: &Felt252Abi, payload: &ArrayAbi, ) { - let to_address = Felt::from(to_address); - let payload_vec: Vec<_> = payload.into(); + let to_address = + Felt::try_from(to_address).expect("Couldn't create felt from Montgomery bytes"); + let payload_vec: Vec<_> = payload + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(payload); @@ -1691,13 +1739,18 @@ pub(crate) mod handler { gas: &mut u64, contract_address: &Felt252Abi, ) { - let result = ptr.get_class_hash_at(contract_address.into(), gas); + let result = ptr.get_class_hash_at( + contract_address + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"), + gas, + ); *result_ptr = match result { Ok(x) => SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, - payload: ManuallyDrop::new(Felt252Abi(x.to_bytes_le())), + 
payload: ManuallyDrop::new(Felt252Abi(x.to_monty_bytes_le())), }), }, Err(e) => Self::wrap_error(&e), @@ -1713,14 +1766,20 @@ pub(crate) mod handler { calldata: &ArrayAbi, signature: &ArrayAbi, ) { - let address = Felt::from(address); - let entry_point_selector = Felt::from(entry_point_selector); - - let calldata_vec: Vec = calldata.into(); + let address = + Felt::try_from(address).expect("Couldn't create felt from Montgomery bytes"); + let entry_point_selector = Felt::try_from(entry_point_selector) + .expect("Couldn't create felt from Montgomery bytes"); + + let calldata_vec: Vec = calldata + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(calldata); } - let signature_vec: Vec = signature.into(); + let signature_vec: Vec = signature + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(signature); } @@ -1735,7 +1794,10 @@ pub(crate) mod handler { *result_ptr = match result { Ok(x) => { - let felts: Vec<_> = x.iter().map(|x| Felt252Abi(x.to_bytes_le())).collect(); + let felts: Vec<_> = x + .iter() + .map(|x| Felt252Abi(x.to_monty_bytes_le())) + .collect(); let felts_ptr = unsafe { Self::alloc_mlir_array(&felts) }; SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { @@ -1755,8 +1817,10 @@ pub(crate) mod handler { selector: &Felt252Abi, input: &ArrayAbi, ) { - let selector = Felt::from(selector); - let input_vec: Vec<_> = input.into(); + let selector = Felt::from_bytes_le(&selector.0); + let input_vec: Vec<_> = input + .try_into() + .expect("Couldn't create array of felts from array of Montgomery bytes"); unsafe { Self::drop_mlir_array(input); @@ -1765,7 +1829,7 @@ pub(crate) mod handler { let result = ptr .cheatcode(selector, &input_vec) .into_iter() - .map(|x| Felt252Abi(x.to_bytes_le())) + .map(|x| Felt252Abi(x.to_monty_bytes_le())) .collect::>(); *result_ptr = unsafe { Self::alloc_mlir_array(&result) }; diff --git 
a/src/utils.rs b/src/utils.rs index 6398d96be5..5376731024 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -32,6 +32,7 @@ use std::{ use thiserror::Error; pub mod mem_tracing; +pub mod montgomery; mod program_registry_ext; mod range_ext; #[cfg(feature = "with-segfault-catcher")] diff --git a/src/utils/montgomery.rs b/src/utils/montgomery.rs new file mode 100644 index 0000000000..52a2b68cb7 --- /dev/null +++ b/src/utils/montgomery.rs @@ -0,0 +1,615 @@ +//! # Montgomery implementation for Felt252. +//! +//! This module holds utility functions for performing arithmetic operations +//! inside the Montgomery space. +//! +//! Representing felts in the Montgomery space allows for optimizations when +//! performing multiplication and division operations. This is because it +//! avoids having to perform modulo operations and even divisions. Montgomery +//! reduces these operations to shifts and simple arithmetic operations such as +//! additions and subtractions. +//! +//! The way this works is by representing a value as x' = x * r mod n. This +//! introduces a new constant `r` which, for performance reasons, is defined +//! as r = 2^{k} where k should be big enough to satisfy r > n. +//! +//! For more information check: https://en.wikipedia.org/wiki/Montgomery_modular_multiplication. + +use std::sync::LazyLock; + +use lambdaworks_math::{ + errors::{ByteConversionError, CreationError}, + traits::ByteConversion, + unsigned_integer::{ + element::{UnsignedInteger, U256}, + montgomery::MontgomeryAlgorithms, + }, +}; +use num_bigint::BigUint; +use num_traits::Num; +use starknet_types_core::felt::Felt; + +// R parameter for felts. R = 2^{256}, the smallest power of 2^{64} greater than the prime. +pub static MONTY_R: LazyLock = LazyLock::new(|| BigUint::from(1u64) << 256); +// R2 parameter for felts. R2 = 2^{256 * 2} mod prime. This value is a U256 instead of a +// BigUint to integrate with lambdaworks with ease.
+pub static MONTY_R2: LazyLock = LazyLock::new(|| { + UnsignedInteger::from_hex_unchecked( + "7FFD4AB5E008810FFFFFFFFFF6F800000000001330FFFFFFFFFFD737E000401", + ) +}); +// MU parameter for felts. MU = -prime^{-1} mod 2^{64}. This variant is used to +// allow a better integration with lambdaworks. +// Check: https://github.com/lambdaclass/lambdaworks/blob/main/crates/math/src/field/fields/montgomery_backed_prime_fields.rs#L60 +pub const MONTY_MU_U64: u64 = 18446744073709551615; +// MU parameter for felts. MU = prime^{-1} mod R. +pub static MONTY_MU_U256: LazyLock = LazyLock::new(|| { + BigUint::from_str_radix( + "f7ffffffffffffef000000000000000000000000000000000000000000000001", + 16, + ) + .expect("hardcoded mu constant should be valid") +}); + +pub trait MontyBytes { + fn to_monty_bytes_le(&self) -> [u8; 32]; +} + +impl MontyBytes for Felt { + /// Returns the raw bytes of a Felt, which are in Montgomery representation. + fn to_monty_bytes_le(&self) -> [u8; 32] { + let limbs = self.to_raw(); + let mut buffer = [0; 32]; + + for i in (0..4).rev() { + let bytes = limbs[i].to_le_bytes(); + let init = (3 - i) * 8; + buffer[init..init + 8].copy_from_slice(&bytes); + } + + buffer + } +} + +/// Utility function to convert Felt bytes in Montgomery form into a Felt with +/// its correct representation. +pub fn felt_from_monty_bytes(value: &[u8; 32]) -> Result { + let value = U256::from_bytes_le(value)?; + Ok(Felt::from_raw(value.limbs)) +} + +/// Computes the Montgomery reduction (REDC). +/// +/// Having a value `x' = x * r mod n`, the Montgomery reduction can be +/// interpreted as dividing `x'` by `r` mod `n`, such that `REDC(x') = x`. +/// +/// For more info on this operation check: +/// https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#The_REDC_algorithm.
+pub fn monty_reduction(x: &BigUint, modulus: &BigUint) -> Result { + let x = U256::from_hex(&x.to_str_radix(16))?; + let modulus = U256::from_hex(&modulus.to_str_radix(16))?; + + let reduced = MontgomeryAlgorithms::cios(&x, &U256::from_u64(1), &modulus, &MONTY_MU_U64); + + Ok(BigUint::from_bytes_le(&reduced.to_bytes_le())) +} + +/// Computes the Montgomery transform operation. +/// +/// To efficiently perform this operation, a precomputed `r^{2}` value is used. +/// This way `x' = REDC(x * r^{2})`. Since we are multiplying by `r^{2}`, and we want +/// `x' = x * r mod n`, we need to apply a reduction after multiplication. +/// +/// For more info on this operation check: +/// https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Arithmetic_in_Montgomery_form. +pub fn monty_transform(x: &BigUint, modulus: &BigUint) -> Result { + let x = U256::from_hex(&x.to_str_radix(16))?; + let modulus = U256::from_hex(&modulus.to_str_radix(16))?; + + let reduced = MontgomeryAlgorithms::cios(&x, &MONTY_R2, &modulus, &MONTY_MU_U64); + + Ok(BigUint::from_bytes_le(&reduced.to_bytes_le())) +} + +pub mod mlir { + use crate::{ + error::Result, + utils::{ + montgomery::{MONTY_MU_U256, MONTY_R, MONTY_R2}, + PRIME, + }, + }; + use melior::{ + dialect::{arith, ods, scf}, + helpers::{ArithBlockExt, BuiltinBlockExt}, + ir::{r#type::IntegerType, Block, BlockLike, Location, Region, Type, Value, ValueLike}, + Context, + }; + + /// Computes Montgomery multiplication in MLIR. + pub fn monty_mul<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + lhs: Value<'c, '_>, + rhs: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let i512 = IntegerType::new(context, 512).into(); + + let lhs = block.extui(lhs, i512, location)?; + let rhs = block.extui(rhs, i512, location)?; + + let t = block.muli(lhs, rhs, location)?; + + monty_reduce(context, block, t, res_ty, location) + } + + /// Computes Montgomery division in MLIR. 
+ pub fn monty_div<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + lhs: Value<'c, '_>, + rhs: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let inv_rhs = monty_inverse(context, block, rhs, location)?; + monty_mul(context, block, lhs, inv_rhs, res_ty, location) + } + + /// Computes the Montgomery modular inverse. + /// + /// The algorithm is given by B. S. Kaliski Jr. in "The Montgomery Inverse + /// and Its Applications". The algorithm consists of two phases: + /// 1. Compute x = a^{-1}2^{k} mod p, where n < k < 2n (denoted as + /// almost inverse). + /// 2. Correct the result from phase 1 so that x = a^{-1}2^{n} mod p. + /// The algorithm can also be checked here: + /// https://www.researchgate.net/publication/3044233_The_Montgomery_modular_inverse_-_Revisited. + fn monty_inverse<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + value: Value<'c, '_>, + location: Location<'c>, + ) -> Result> { + let value = block.extui(value, IntegerType::new(context, 256).into(), location)?; + let (r, k) = almost_inverse(context, block, value, location)?; + let inverse = inverse_correction(context, block, r, k, location)?; + + // Since value is already in its Montgomery form, we currently have that + // inverse = MontyInv(value) = (value * r)^{-1} mod n. Since we need + // inverse = value^{-1} * r mod n, we still need to perform one more + // correction. So, + // inverse = MontyProd(inverse, r^{2}) = value^{-1} * r mod n. + let r2 = block.const_int_from_type(context, location, *MONTY_R2, inverse.r#type())?; + monty_mul(context, block, inverse, r2, inverse.r#type(), location) + } + + /// Performs the Montgomery inverse correction. + /// + /// This algorithm represents phase 2 of B. S. Kaliski Jr.'s Montgomery + /// inverse which returns (a * 2^{m})^{-1} mod n, where `a` is the value + /// to invert and `m` the smallest value such that 2^{m} > n. + /// + /// NOTE(review): the correction loop runs `k - 256` times, which presumes + /// `k >= 256`; phase 1 is only documented to yield `k` above the modulus + /// bit length. Confirm the `k < 256` case cannot occur or is handled.
+ fn inverse_correction<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + r: Value<'c, '_>, + k: Value<'c, '_>, + location: Location<'c>, + ) -> Result> { + let i16 = IntegerType::new(context, 16).into(); + let i256 = IntegerType::new(context, 256).into(); + + let k0 = block.const_int(context, location, 0, 256)?; + let k0_i16 = block.const_int(context, location, 0, 16)?; + let k1_i16 = block.const_int(context, location, 1, 16)?; + let k1 = block.const_int(context, location, 1, 256)?; + let k256 = block.const_int(context, location, 256, 16)?; + + let loop_limit = block.subi(k, k256, location)?; + + let result = block.append_operation( + ods::scf::r#for( + context, + &[i256], + k0_i16, + loop_limit, + k1_i16, + &[r], + { + let region = Region::new(); + let loop_block = + region.append_block(Block::new(&[(i16, location), (i256, location)])); + + let r = loop_block.arg(1)?; + + let r_and_one = loop_block.andi(r, k1, location)?; + let is_r_even = loop_block.cmpi( + context, + arith::CmpiPredicate::Eq, + r_and_one, + k0, + location, + )?; + + let next_r = loop_block.append_op_result(scf::r#if( + is_r_even, + &[i256], + { + let region = Region::new(); + let block_then = region.append_block(Block::new(&[])); + + let result = block_then.shrui(r, k1, location)?; + + block_then.append_operation(scf::r#yield(&[result], location)); + + region + }, + { + let region = Region::new(); + let block_else = region.append_block(Block::new(&[])); + + let prime = + block_else.const_int(context, location, PRIME.clone(), 256)?; + + let result = block_else.addi(r, prime, location)?; + let result = block_else.shrui(result, k1, location)?; + + block_else.append_operation(scf::r#yield(&[result], location)); + + region + }, + location, + ))?; + + loop_block.append_operation(scf::r#yield(&[next_r], location)); + + region + }, + location, + ) + .into(), + ); + + Ok(result.result(0)?.into()) + } + + /// Performs a first approach to the Montgomery Inverse. 
+ /// + /// This algorithm represents phase 1 of B. S. Kaliski Jr.'s Montgomery + /// inverse which returns `alm_inv = (a * 2^{k})^{-1} mod n`, where `a` is + /// the value to invert and `k` a value such that `m < k < 2 * m`, being + /// `m` the smallest value such that 2^{m} > n. + fn almost_inverse<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + value: Value<'c, '_>, + location: Location<'c>, + ) -> Result<(Value<'c, 'a>, Value<'c, 'a>)> { + let i16 = IntegerType::new(context, 16).into(); + let value_ty = value.r#type(); + + let k0 = block.const_int_from_type(context, location, 0, value_ty)?; + let k0_i16 = block.const_int(context, location, 0, 16)?; + let prime = block.const_int_from_type(context, location, PRIME.clone(), value_ty)?; + let k1 = block.const_int_from_type(context, location, 1, value_ty)?; + let k1_i16 = block.const_int(context, location, 1, 16)?; + + let result = block.append_operation(scf::r#while( + &[prime, value, k0, k1, k0_i16], + &[value_ty, value_ty, value_ty, value_ty, i16], + { + let region = Region::new(); + let cond_block = region.append_block(Block::new(&[ + (value_ty, location), + (value_ty, location), + (value_ty, location), + (value_ty, location), + (i16, location), + ])); + let u = cond_block.arg(0)?; + let v = cond_block.arg(1)?; + let r = cond_block.arg(2)?; + let s = cond_block.arg(3)?; + + let u_is_even = { + let u_and_one = cond_block.andi(u, k1, location)?; + cond_block.cmpi(context, arith::CmpiPredicate::Eq, u_and_one, k0, location)? 
+ }; + + // if u is even then + // u = u / 2 + // s = 2 * s + // else if v is even then + // v = v / 2 + // s = 2 * s + // else if u > v then + // u = (u − v) / 2 + // r = r + s + // s = 2 * s + // else if u <= v then + // v = (v − u) / 2 + // s = r + s + // r = 2 * r + let result = cond_block.append_operation(scf::r#if( + u_is_even, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let u_even_block = region.append_block(Block::new(&[])); + + let u = u_even_block.shrui(u, k1, location)?; + let s = u_even_block.shli(s, k1, location)?; + + u_even_block.append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + { + let region = Region::new(); + let u_not_even_block = region.append_block(Block::new(&[])); + + let v_is_even = { + let v_and_one = u_not_even_block.andi(v, k1, location)?; + u_not_even_block.cmpi( + context, + arith::CmpiPredicate::Eq, + v_and_one, + k0, + location, + )? + }; + + let result = u_not_even_block.append_operation(scf::r#if( + v_is_even, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let v_even_block = region.append_block(Block::new(&[])); + + let v = v_even_block.shrui(v, k1, location)?; + let r = v_even_block.shli(r, k1, location)?; + + v_even_block + .append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + { + let region = Region::new(); + let v_not_even_block = region.append_block(Block::new(&[])); + + let is_u_gt_v = v_not_even_block.cmpi( + context, + arith::CmpiPredicate::Ugt, + u, + v, + location, + )?; + + let result = v_not_even_block.append_operation(scf::r#if( + is_u_gt_v, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let u_gt_v_block = region.append_block(Block::new(&[])); + + let u = { + let u_min_v = u_gt_v_block.subi(u, v, location)?; + u_gt_v_block.shrui(u_min_v, k1, location)? 
+ }; + let r = u_gt_v_block.addi(r, s, location)?; + let s = u_gt_v_block.shli(s, k1, location)?; + + u_gt_v_block.append_operation(scf::r#yield( + &[u, v, r, s], + location, + )); + + region + }, + { + let region = Region::new(); + let v_ge_u_block = region.append_block(Block::new(&[])); + + let v = { + let v_min_u = v_ge_u_block.subi(v, u, location)?; + v_ge_u_block.shrui(v_min_u, k1, location)? + }; + let s = v_ge_u_block.addi(r, s, location)?; + let r = v_ge_u_block.shli(r, k1, location)?; + + v_ge_u_block.append_operation(scf::r#yield( + &[u, v, r, s], + location, + )); + + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + + v_not_even_block + .append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + + u_not_even_block.append_operation(scf::r#yield(&[u, v, r, s], location)); + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + let k = cond_block.addi(cond_block.arg(4)?, k1_i16, location)?; + + let is_v_gt_zero = + cond_block.cmpi(context, arith::CmpiPredicate::Ugt, v, k0, location)?; + + cond_block.append_operation(scf::condition( + is_v_gt_zero, + &[u, v, r, s, k], + location, + )); + + region + }, + { + let region = Region::new(); + let loop_block = region.append_block(Block::new(&[ + (value_ty, location), + (value_ty, location), + (value_ty, location), + (value_ty, location), + (i16, location), + ])); + + let u = loop_block.arg(0)?; + let v = loop_block.arg(1)?; + let r = loop_block.arg(2)?; + let s = loop_block.arg(3)?; + let k = loop_block.arg(4)?; + + loop_block.append_operation(scf::r#yield(&[u, v, r, s, k], location)); + + 
region + }, + location, + )); + + let (almost_inv, k) = { + // if r >= p: + // r = r − p + // else: + // r + // return (p - r), k + let k = result.result(4)?.into(); + let r = { + let r = result.result(2)?.into(); + let r_wrapped = block.subi(r, prime, location)?; + let r_ge_prime = + block.cmpi(context, arith::CmpiPredicate::Uge, r, prime, location)?; + let r = + block.append_op_result(arith::select(r_ge_prime, r_wrapped, r, location))?; + + block.subi(prime, r, location)? + }; + + (r, k) + }; + + Ok((almost_inv, k)) + } + + /// Computes Montgomery reduction in MLIR. + pub fn monty_reduce<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + x: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let x = block.extui(x, IntegerType::new(context, 512).into(), location)?; + let mu = block.const_int(context, location, MONTY_MU_U256.clone(), 512)?; + let r_minus_1 = block.const_int(context, location, MONTY_R.clone() - 1u8, 512)?; + let k256 = block.const_int(context, location, 256, 512)?; + let modulus = block.const_int(context, location, PRIME.clone(), 512)?; + + // q = (value * mu) mod r. + let q = block.muli(x, mu, location)?; + let q = block.andi(q, r_minus_1, location)?; + // m = q * modulus. + let m = block.muli(q, modulus, location)?; + // y = (value - m) / r. + let y = block.subi(x, m, location)?; + let y = block.shrui(y, k256, location)?; + // if (m > x): + // y = y + modulus + let y_plus_mod = block.addi(y, modulus, location)?; + + let is_negative = block.cmpi(context, arith::CmpiPredicate::Ugt, m, x, location)?; + + let value = block.append_op_result(arith::select(is_negative, y_plus_mod, y, location))?; + Ok(block.trunci(value, res_ty, location)?) + } + + /// Computes the conversion to Montgomery space in MLIR.
+ pub fn monty_transform<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + x: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let r2 = block.const_int(context, location, *MONTY_R2, 257)?; + monty_mul(context, block, x, r2, res_ty, location) + } +} + +#[cfg(test)] +mod tests { + use crate::utils::{ + montgomery::{self, monty_reduction, monty_transform, MontyBytes}, + PRIME, + }; + use starknet_types_core::felt::Felt; + + #[test] + fn felt_to_bytes_raw() { + let felt = Felt::from(10); + let bytes = felt.to_monty_bytes_le(); + let felt_from_raw = montgomery::felt_from_monty_bytes(&bytes).unwrap(); + + assert_eq!(felt_from_raw, felt); + + let felt = Felt::from(-10); + let bytes = felt.to_monty_bytes_le(); + let felt_from_raw = montgomery::felt_from_monty_bytes(&bytes).unwrap(); + + assert_eq!(felt_from_raw, felt); + + let felt = Felt::from(PRIME.clone()); + let bytes = felt.to_monty_bytes_le(); + let felt_from_raw = montgomery::felt_from_monty_bytes(&bytes).unwrap(); + + assert_eq!(felt_from_raw, felt); + } + + #[test] + fn felt_to_monty_to_felt() { + let felt = Felt::from(10).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + + let felt = Felt::from(-10).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + + let felt = Felt::from(PRIME.clone()).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + } +} diff --git a/src/values.rs b/src/values.rs index c561cf6ba2..9720758e4f 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,7 +9,9 @@ use crate::{ starknet::{Secp256k1Point, Secp256r1Point}, types::TypeBuilder, 
utils::{ - felt252_bigint, get_integer_layout, layout_repeat, libc_free, libc_malloc, RangeExt, PRIME, + felt252_bigint, get_integer_layout, layout_repeat, libc_free, libc_malloc, + montgomery::{self, MontyBytes}, + RangeExt, PRIME, }, }; use bumpalo::Bump; @@ -175,7 +177,9 @@ impl Value { Self::Felt252(value) => { let ptr = arena.alloc_layout(get_integer_layout(252)).cast(); - let data = felt252_bigint(value.to_bigint()).to_bytes_le(); + // Felts are represented in Montgomery form. Due to this, + // we need to take its raw bytes. + let data = value.to_monty_bytes_le(); ptr.cast::<[u8; 32]>().as_mut().copy_from_slice(&data); ptr } @@ -431,7 +435,9 @@ impl Value { // next key must be called before next_value for (key, value) in map.iter() { - let key = key.to_bytes_le(); + // Felts are represented in Montgomery form. Due to this, + // we need to take its raw bytes. + let key = key.to_monty_bytes_le(); let value = value.to_ptr(arena, registry, &info.ty, find_dict_drop_override)?; @@ -521,8 +527,8 @@ impl Value { .alloc_layout(layout_repeat(&get_integer_layout(252), 2)?.0.pad_to_align()) .cast(); - let a = felt252_bigint(a.to_bigint()).to_bytes_le(); - let b = felt252_bigint(b.to_bigint()).to_bytes_le(); + let a = felt252_bigint(a.to_bigint()).to_monty_bytes_le(); + let b = felt252_bigint(b.to_bigint()).to_monty_bytes_le(); let data = [a, b]; ptr.cast::<[[u8; 32]; 2]>().as_mut().copy_from_slice(&data); @@ -534,10 +540,10 @@ impl Value { .alloc_layout(layout_repeat(&get_integer_layout(252), 4)?.0.pad_to_align()) .cast(); - let a = felt252_bigint(a.to_bigint()).to_bytes_le(); - let b = felt252_bigint(b.to_bigint()).to_bytes_le(); - let c = felt252_bigint(c.to_bigint()).to_bytes_le(); - let d = felt252_bigint(d.to_bigint()).to_bytes_le(); + let a = felt252_bigint(a.to_bigint()).to_monty_bytes_le(); + let b = felt252_bigint(b.to_bigint()).to_monty_bytes_le(); + let c = felt252_bigint(c.to_bigint()).to_monty_bytes_le(); + let d = 
felt252_bigint(d.to_bigint()).to_monty_bytes_le(); let data = [a, b, c, d]; ptr.cast::<[[u8; 32]; 4]>().as_mut().copy_from_slice(&data); @@ -719,7 +725,12 @@ impl Value { data[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). data[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - Self::EcPoint(Felt::from_bytes_le(&data[0]), Felt::from_bytes_le(&data[1])) + let x = montgomery::felt_from_monty_bytes(&data[0]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + let y = montgomery::felt_from_monty_bytes(&data[1]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + + Self::EcPoint(x, y) } CoreTypeConcrete::EcState(_) => { let data = ptr.cast::<[[u8; 32]; 4]>().as_mut(); @@ -729,18 +740,27 @@ impl Value { data[2][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). data[3][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - Self::EcState( - Felt::from_bytes_le(&data[0]), - Felt::from_bytes_le(&data[1]), - Felt::from_bytes_le(&data[2]), - Felt::from_bytes_le(&data[3]), - ) + let limb0 = montgomery::felt_from_monty_bytes(&data[0]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + let limb1 = montgomery::felt_from_monty_bytes(&data[1]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + let limb2 = montgomery::felt_from_monty_bytes(&data[2]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + let limb3 = montgomery::felt_from_monty_bytes(&data[3]) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; + + Self::EcState(limb0, limb1, limb2, limb3) } CoreTypeConcrete::Felt252(_) => { let data = ptr.cast::<[u8; 32]>().as_mut(); data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let data = Felt::from_bytes_le_slice(data); - Self::Felt252(data) + + // Felts are represented in Montgomery form. 
Due to this, we + // need to convert them back to their original representation. + Self::Felt252( + montgomery::felt_from_monty_bytes(data) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?, + ) } CoreTypeConcrete::Uint8(_) => Self::Uint8(*ptr.cast::().as_ref()), CoreTypeConcrete::Uint16(_) => Self::Uint16(*ptr.cast::().as_ref()), @@ -866,7 +886,10 @@ impl Value { let mut key = key; key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let key = Felt::from_bytes_le(&key); + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + let key = montgomery::felt_from_monty_bytes(&key) + .to_native_assert_error("Couldn't create felt from Montgomery bytes")?; // The dictionary items are not being dropped here. They'll be dropped along // with the dictionary (if requested using `should_drop`). output_map.insert( @@ -920,8 +943,14 @@ impl Value { // felt values let data = ptr.cast::<[u8; 32]>().as_mut(); data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let data = Felt::from_bytes_le(data); - Self::Felt252(data) + + // Felts are represented in Montgomery form. Due to this, we + // need to convert them back to their original representation. + Self::Felt252( + montgomery::felt_from_monty_bytes(data).to_native_assert_error( + "Couldn't create felt from Montgomery bytes", + )?, + ) } StarknetTypeConcrete::System(_) => { native_panic!("should be handled before") @@ -1195,6 +1224,7 @@ mod test { let registry = ProgramRegistry::::new(&program).unwrap(); + // Assert bytes of Montgomery form of 42_felt252. assert_eq!( unsafe { *Value::Felt252(Felt::from(42)) @@ -1208,9 +1238,13 @@ mod test { .cast::<[u32; 8]>() .as_ptr() }, - [42, 0, 0, 0, 0, 0, 0, 0] + [ + 4294965953, 4294967295, 4294967295, 4294967295, 4294967295, 4294967295, 4294944464, + 134217727 + ] ); + // Assert bytes of Montgomery form of Felt::MAX. 
assert_eq!( unsafe { *Value::Felt252(Felt::MAX) @@ -1221,11 +1255,10 @@ mod test { |_| todo!(), ) .unwrap() - .cast::<[u32; 8]>() + .cast::<[u8; 32]>() .as_ptr() }, - // 0x800000000000011000000000000000000000000000000000000000000000001 - 1 - [0, 0, 0, 0, 0, 0, 17, 134217728] + Felt::MAX.to_monty_bytes_le() ); assert_eq!( @@ -1493,10 +1526,13 @@ mod test { |_| todo!(), ) .unwrap() - .cast::<[[u32; 8]; 2]>() + .cast::<[[u8; 32]; 2]>() .as_ptr() }, - [[1234, 0, 0, 0, 0, 0, 0, 0], [4321, 0, 0, 0, 0, 0, 0, 0]] + [ + Felt::from(1234).to_monty_bytes_le(), + Felt::from(4321).to_monty_bytes_le() + ] ); } @@ -1523,14 +1559,14 @@ mod test { |_| todo!(), ) .unwrap() - .cast::<[[u32; 8]; 4]>() + .cast::<[[u8; 32]; 4]>() .as_ptr() }, [ - [1234, 0, 0, 0, 0, 0, 0, 0], - [4321, 0, 0, 0, 0, 0, 0, 0], - [3333, 0, 0, 0, 0, 0, 0, 0], - [4444, 0, 0, 0, 0, 0, 0, 0] + Felt::from(1234).to_monty_bytes_le(), + Felt::from(4321).to_monty_bytes_le(), + Felt::from(3333).to_monty_bytes_le(), + Felt::from(4444).to_monty_bytes_le() ] ); }