diff --git a/crates/circuits/primitives/cuda/include/primitives/histogram.cuh b/crates/circuits/primitives/cuda/include/primitives/histogram.cuh index a6809df581..7853236457 100644 --- a/crates/circuits/primitives/cuda/include/primitives/histogram.cuh +++ b/crates/circuits/primitives/cuda/include/primitives/histogram.cuh @@ -80,6 +80,33 @@ struct VariableRangeChecker { } #ifdef CUDA_DEBUG assert(bits_remaining == 0 && x == 0); +#endif + } + + __device__ __forceinline__ void decompose_new( + uint32_t x, + size_t bits, + RowSliceNew limbs, + const size_t limbs_len + ) { + size_t range_max_bits = max_bits(); +#ifdef CUDA_DEBUG + assert(limbs_len >= d_div_ceil(bits, range_max_bits)); +#endif + uint32_t mask = (1 << range_max_bits) - 1; + size_t bits_remaining = bits; +#pragma unroll + for (int i = 0; i < limbs_len; i++) { + uint32_t limb_u32 = x & mask; + limbs.write_new(i, limb_u32); + if (!limbs.is_apc) { + add_count(limb_u32, min(bits_remaining, range_max_bits)); + } + x >>= range_max_bits; + bits_remaining -= min(bits_remaining, range_max_bits); + } +#ifdef CUDA_DEBUG + assert(bits_remaining == 0 && x == 0); #endif } }; diff --git a/crates/circuits/primitives/cuda/include/primitives/less_than.cuh b/crates/circuits/primitives/cuda/include/primitives/less_than.cuh index 6d24759fdc..1c4cf90cf9 100644 --- a/crates/circuits/primitives/cuda/include/primitives/less_than.cuh +++ b/crates/circuits/primitives/cuda/include/primitives/less_than.cuh @@ -40,6 +40,17 @@ __device__ __forceinline__ void generate_subrow( ) { rc.decompose(y - x - 1, max_bits, lower_decomp, lower_decomp_len); } + +__device__ __forceinline__ void generate_subrow_new( + VariableRangeChecker &rc, + const uint32_t max_bits, + uint32_t x, + uint32_t y, + const size_t lower_decomp_len, + RowSliceNew lower_decomp +) { + rc.decompose_new(y - x - 1, max_bits, lower_decomp, lower_decomp_len); +} } // namespace AssertLessThan namespace IsLessThan { diff --git a/crates/circuits/primitives/cuda/include/primitives/row_print_buffer.cuh b/crates/circuits/primitives/cuda/include/primitives/row_print_buffer.cuh new file mode 100644 index 0000000000..64a8e25908 --- /dev/null +++ b/crates/circuits/primitives/cuda/include/primitives/row_print_buffer.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include + +// Utility buffer to print a single APC row atomically from device code. +struct RowPrintBuffer { + static constexpr int kCapacity = 8192; + char data[kCapacity]; + int len; + + __device__ __forceinline__ void reset() { len = 0; } + + __device__ __forceinline__ void append_char(char c) { + if (len < kCapacity - 1) { + data[len++] = c; + } + } + + __device__ __forceinline__ void append_literal(const char *literal) { + for (const char *ptr = literal; *ptr != '\0'; ++ptr) { + append_char(*ptr); + } + } + + __device__ __forceinline__ void append_uint(unsigned long long value) { + char tmp[32]; + int tmp_len = 0; + + if (value == 0) { + tmp[tmp_len++] = '0'; + } else { + while (value > 0 && tmp_len < static_cast(sizeof(tmp))) { + tmp[tmp_len++] = static_cast('0' + (value % 10)); + value /= 10; + } + } + + for (int i = tmp_len - 1; i >= 0; --i) { + append_char(tmp[i]); + } + } + + __device__ __forceinline__ void flush() { + data[len] = '\0'; + printf("%s", data); + } + + // Execute `fn` with this buffer after clearing it, then flush. + // `fn` must be a device callable accepting `RowPrintBuffer &`. 
+ template + __device__ __forceinline__ void write_with(Fn fn) { + reset(); + fn(*this); + flush(); + } +}; diff --git a/crates/circuits/primitives/cuda/include/primitives/trace_access.h b/crates/circuits/primitives/cuda/include/primitives/trace_access.h index 8507db1138..7ca2b5c5a4 100644 --- a/crates/circuits/primitives/cuda/include/primitives/trace_access.h +++ b/crates/circuits/primitives/cuda/include/primitives/trace_access.h @@ -1,7 +1,112 @@ #pragma once #include "fp.h" +#include "primitives/row_print_buffer.cuh" #include +#include +#include + + +__device__ __forceinline__ size_t number_of_gaps_in(const uint32_t *sub, size_t start, size_t len); + +/// A RowSlice is a contiguous section of a row in col-based trace. +struct RowSliceNew { + Fp *ptr; + size_t stride; + size_t optimized_offset; + size_t dummy_offset; + uint32_t *subs; + bool is_apc; + + + __device__ RowSliceNew(Fp *ptr, size_t stride, size_t optimized_offset, size_t dummy_offset, uint32_t *subs, bool is_apc) : ptr(ptr), stride(stride), optimized_offset(optimized_offset), dummy_offset(dummy_offset), subs(subs), is_apc(is_apc) {} + + __device__ __forceinline__ Fp &operator[](size_t column_index) const { + // While implementing tracegen for SHA256, we encountered what we believe to be an nvcc + // compiler bug. Occasionally, at various non-zero PTXAS optimization levels the compiler + // tries to replace this multiplication with a series of SHL, ADD, and AND instructions + // that we believe erroneously adds ~2^49 to the final address via an improper carry + // propagation. To read more, see https://github.com/stephenh-axiom-xyz/cuda-illegal. + return ptr[column_index * stride]; + } + + __device__ static RowSliceNew null() { return RowSliceNew(nullptr, 0, 0, 0, nullptr, false); } + + __device__ bool is_valid() const { return ptr != nullptr; } + + template + __device__ __forceinline__ void write(size_t column_index, T value) const { + ptr[column_index * stride] = value; + } + + template + __device__ __forceinline__ void write_new(size_t column_index, T value) const { + if (is_apc) { + const uint32_t apc_idx = subs[dummy_offset + column_index]; + if (apc_idx != UINT32_MAX) { + ptr[(apc_idx - optimized_offset) * stride] = value; + } + } else { + ptr[column_index * stride] = value; + } + } + + template + __device__ __forceinline__ void write_array(size_t column_index, size_t length, const T *values) + const { +#pragma unroll + for (size_t i = 0; i < length; i++) { + ptr[(column_index + i) * stride] = values[i]; + } + } + + template + __device__ __forceinline__ void write_array_new(size_t column_index, size_t length, const T *values) + const { + if (is_apc) { + #pragma unroll + for (size_t i = 0; i < length; i++) { + const uint32_t apc_idx = subs[dummy_offset + column_index + i]; + if (apc_idx != UINT32_MAX) { + ptr[(apc_idx - optimized_offset) * stride] = values[i]; + } + } + } else { + #pragma unroll + for (size_t i = 0; i < length; i++) { + ptr[(column_index + i) * stride] = values[i]; + } + } + } + + template + __device__ __forceinline__ void write_bits(size_t column_index, const T value) const { +#pragma unroll + for (size_t i = 0; i < sizeof(T) * 8; i++) { + ptr[(column_index + i) * stride] = (value >> i) & 1; + } + } + + __device__ __forceinline__ void fill_zero(size_t column_index_from, size_t length) const { +#pragma unroll + for (size_t i = 0, c = column_index_from; i < length; i++, c++) { + ptr[c * stride] = 0; + } + } + + __device__ __forceinline__ RowSliceNew slice_from(size_t column_index) const { + if (is_apc) { + 
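+            // `subs` maps each dummy-trace column to its post-optimization (APC)
+            // column, with UINT32_MAX marking columns the optimizer removed. The gap
+            // count below is the number of removed columns among the first
+            // `column_index` columns, so the data pointer and `optimized_offset`
+            // advance only by the surviving columns while `dummy_offset` advances by
+            // the full `column_index`.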
uint32_t gap = number_of_gaps_in(subs, dummy_offset, column_index);
+            return RowSliceNew(ptr + (column_index - gap) * stride, stride, optimized_offset + column_index - gap, dummy_offset + column_index, subs, is_apc);
+        } else {
+            return RowSliceNew(ptr + column_index * stride, stride, 0, 0, nullptr, false);
+        }
+    }
+
+    __device__ __forceinline__ RowSliceNew shift_row(size_t n) const {
+        return RowSliceNew(ptr + n, stride, optimized_offset, dummy_offset, subs, is_apc);
+    }
+};

/// A RowSlice is a contiguous section of a row in col-based trace.
struct RowSlice {
@@ -61,6 +166,16 @@ struct RowSlice {
    }
};
+template <typename T>
+__device__ __forceinline__ unsigned long long to_debug_uint(T value) {
+    using Base = std::remove_cv_t<std::remove_reference_t<T>>;
+    if constexpr (std::is_same_v<Base, Fp>) {
+        return static_cast<unsigned long long>(value.asRaw());
+    } else {
+        return static_cast<unsigned long long>(value);
+    }
+}
+
/// Compute the 0-based column index of member `FIELD` within struct template `STRUCT`,
/// by instantiating it as `STRUCT<uint8_t>` so that offsetof yields the element index.
#define COL_INDEX(STRUCT, FIELD) (offsetof(STRUCT<uint8_t>, FIELD))
@@ -71,10 +186,17 @@ struct RowSlice {
/// Write a single value into `FIELD` of struct `STRUCT` at a given row.
#define COL_WRITE_VALUE(ROW, STRUCT, FIELD, VALUE) (ROW).write(COL_INDEX(STRUCT, FIELD), VALUE)

+/// Write a single value into `FIELD` of struct `STRUCT` at a given row.
+#define COL_WRITE_VALUE_NEW(ROW, STRUCT, FIELD, VALUE) (ROW).write_new(COL_INDEX(STRUCT, FIELD), VALUE)
+
/// Write an array of values into the fixed‐length `FIELD` array of `STRUCT` for one row.
#define COL_WRITE_ARRAY(ROW, STRUCT, FIELD, VALUES) \
    (ROW).write_array(COL_INDEX(STRUCT, FIELD), COL_ARRAY_LEN(STRUCT, FIELD), VALUES)

+/// Write an array of values into the fixed‐length `FIELD` array of `STRUCT` for one row.
+#define COL_WRITE_ARRAY_NEW(ROW, STRUCT, FIELD, VALUES) \
+    (ROW).write_array_new(COL_INDEX(STRUCT, FIELD), COL_ARRAY_LEN(STRUCT, FIELD), VALUES)
+
/// Write a single value bits into `FIELD` of struct `STRUCT` at a given row.
#define COL_WRITE_BITS(ROW, STRUCT, FIELD, VALUE) (ROW).write_bits(COL_INDEX(STRUCT, FIELD), VALUE) @@ -83,3 +205,14 @@ struct RowSlice { (ROW).fill_zero( \ COL_INDEX(STRUCT, FIELD), sizeof(static_cast *>(nullptr)->FIELD) \ ) + +__device__ __forceinline__ size_t number_of_gaps_in(const uint32_t *sub, size_t start, size_t len) { + size_t gaps = 0; +#pragma unroll + for (size_t i = start; i < start + len; ++i) { + if (sub[i] == UINT32_MAX) { + ++gaps; + } + } + return gaps; +} diff --git a/crates/circuits/primitives/cuda/src/range_tuple.cu b/crates/circuits/primitives/cuda/src/range_tuple.cu index 020b42e4d4..2a3baf71f8 100644 --- a/crates/circuits/primitives/cuda/src/range_tuple.cu +++ b/crates/circuits/primitives/cuda/src/range_tuple.cu @@ -1,5 +1,8 @@ +#include + #include "fp.h" #include "launcher.cuh" +#include "primitives/row_print_buffer.cuh" __global__ void range_tuple_checker_tracegen( const uint32_t *count, diff --git a/crates/vm/cuda/include/system/memory/controller.cuh b/crates/vm/cuda/include/system/memory/controller.cuh index 96697b8eaf..830681eac0 100644 --- a/crates/vm/cuda/include/system/memory/controller.cuh +++ b/crates/vm/cuda/include/system/memory/controller.cuh @@ -23,6 +23,18 @@ struct MemoryAuxColsFactory { COL_WRITE_VALUE(row, MemoryBaseAuxCols, prev_timestamp, prev_timestamp); } + __device__ void fill_new(RowSliceNew row, uint32_t prev_timestamp, uint32_t timestamp) { + AssertLessThan::generate_subrow_new( + range_checker, + timestamp_max_bits, + prev_timestamp, + timestamp, + AUX_LEN, + row.slice_from(COL_INDEX(MemoryBaseAuxCols, timestamp_lt_aux)) + ); + COL_WRITE_VALUE_NEW(row, MemoryBaseAuxCols, prev_timestamp, prev_timestamp); + } + __device__ void fill_zero(RowSlice row) { row.fill_zero(0, sizeof(MemoryBaseAuxCols)); } diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 6e7cf75190..3193069732 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -40,7 +40,7 @@ test-case.workspace = true openvm-cuda-builder = { workspace = true, optional = true } [features] -default = ["parallel", "jemalloc"] +default = ["parallel", "jemalloc", "cuda"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils", "dep:openvm-stark-sdk"] tco = ["openvm-circuit/tco"] diff --git a/extensions/rv32im/circuit/cuda/include/primitives/row_print_buffer.cuh b/extensions/rv32im/circuit/cuda/include/primitives/row_print_buffer.cuh new file mode 100644 index 0000000000..64a8e25908 --- /dev/null +++ b/extensions/rv32im/circuit/cuda/include/primitives/row_print_buffer.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include + +// Utility buffer to print a single APC row atomically from device code. 
+struct RowPrintBuffer { + static constexpr int kCapacity = 8192; + char data[kCapacity]; + int len; + + __device__ __forceinline__ void reset() { len = 0; } + + __device__ __forceinline__ void append_char(char c) { + if (len < kCapacity - 1) { + data[len++] = c; + } + } + + __device__ __forceinline__ void append_literal(const char *literal) { + for (const char *ptr = literal; *ptr != '\0'; ++ptr) { + append_char(*ptr); + } + } + + __device__ __forceinline__ void append_uint(unsigned long long value) { + char tmp[32]; + int tmp_len = 0; + + if (value == 0) { + tmp[tmp_len++] = '0'; + } else { + while (value > 0 && tmp_len < static_cast(sizeof(tmp))) { + tmp[tmp_len++] = static_cast('0' + (value % 10)); + value /= 10; + } + } + + for (int i = tmp_len - 1; i >= 0; --i) { + append_char(tmp[i]); + } + } + + __device__ __forceinline__ void flush() { + data[len] = '\0'; + printf("%s", data); + } + + // Execute `fn` with this buffer after clearing it, then flush. + // `fn` must be a device callable accepting `RowPrintBuffer &`. + template + __device__ __forceinline__ void write_with(Fn fn) { + reset(); + fn(*this); + flush(); + } +}; diff --git a/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh b/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh index 9348ddcdd6..5f3402e3ac 100644 --- a/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh +++ b/extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh @@ -2,6 +2,7 @@ #include "primitives/execution.h" #include "primitives/trace_access.h" +#include "primitives/row_print_buffer.cuh" #include "system/memory/controller.cuh" #include "system/memory/offline_checker.cuh" @@ -81,4 +82,49 @@ struct Rv32BaseAluAdapter { record.from_timestamp + 2 ); } -}; \ No newline at end of file + + __device__ void fill_trace_row_new(RowSliceNew row, Rv32BaseAluAdapterRecord record) { + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, from_state.pc, record.from_pc); + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, from_state.timestamp, record.from_timestamp); + + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, rd_ptr, record.rd_ptr); + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, rs1_ptr, record.rs1_ptr); + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, rs2, record.rs2); + COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, rs2_as, record.rs2_as); + + // Read auxiliary for rs1 + mem_helper.fill_new( + row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[0])), + record.reads_aux[0].prev_timestamp, + record.from_timestamp + ); + + // rs2: register read when rs2_as == RV32_REGISTER_AS (== 1), otherwise immediate. 
+ if (record.rs2_as != 0) { + mem_helper.fill_new( + row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1])), + record.reads_aux[1].prev_timestamp, + record.from_timestamp + 1 + ); + } else { + RowSliceNew rs2_aux = row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1])); +#pragma unroll + for (size_t i = 0; i < sizeof(MemoryReadAuxCols); i++) { + rs2_aux.write_new(i, 0); + } + uint32_t mask = (1u << RV32_CELL_BITS) - 1u; + if (!rs2_aux.is_apc) { + bitwise_lookup.add_range(record.rs2 & mask, (record.rs2 >> RV32_CELL_BITS) & mask); + } + } + + COL_WRITE_ARRAY_NEW( + row, Rv32BaseAluAdapterCols, writes_aux.prev_data, record.writes_aux.prev_data + ); + mem_helper.fill_new( + row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, writes_aux)), + record.writes_aux.prev_timestamp, + record.from_timestamp + 2 + ); + } +}; diff --git a/extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh b/extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh index d6d6976ec3..53e92a8881 100644 --- a/extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh +++ b/extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh @@ -139,4 +139,55 @@ template struct BaseAluCore { } } } + + __device__ void fill_trace_row_new(RowSliceNew row, BaseAluCoreRecord record) { + uint8_t a[NUM_LIMBS]; + uint8_t carry_buf[NUM_LIMBS]; + + switch (record.local_opcode) { + case 0: + run_add(record.b, record.c, a, carry_buf); + break; + case 1: + run_sub(record.b, record.c, a, carry_buf); + break; + case 2: + run_xor(record.b, record.c, a); + break; + case 3: + run_or(record.b, record.c, a); + break; + case 4: + run_and(record.b, record.c, a); + break; + default: +#pragma unroll + for (size_t i = 0; i < NUM_LIMBS; i++) { + a[i] = 0; + carry_buf[i] = 0; + } + } + + // TODO: we just optionally write here but we can also optionally compute things like `run_add` above + COL_WRITE_ARRAY_NEW(row, Cols, a, a); + COL_WRITE_ARRAY_NEW(row, Cols, b, record.b); + COL_WRITE_ARRAY_NEW(row, Cols, c, record.c); + + if (!row.is_apc) { + COL_WRITE_VALUE_NEW(row, Cols, opcode_add_flag, record.local_opcode == 0); + COL_WRITE_VALUE_NEW(row, Cols, opcode_sub_flag, record.local_opcode == 1); + COL_WRITE_VALUE_NEW(row, Cols, opcode_xor_flag, record.local_opcode == 2); + COL_WRITE_VALUE_NEW(row, Cols, opcode_or_flag, record.local_opcode == 3); + COL_WRITE_VALUE_NEW(row, Cols, opcode_and_flag, record.local_opcode == 4); + #pragma unroll + + for (size_t i = 0; i < NUM_LIMBS; i++) { + if (record.local_opcode == 0 || record.local_opcode == 1) { + bitwise_lookup.add_xor(a[i], a[i]); + } else { + bitwise_lookup.add_xor(record.b[i], record.c[i]); + } + } + } + } }; \ No newline at end of file diff --git a/extensions/rv32im/circuit/cuda/src/alu.cu b/extensions/rv32im/circuit/cuda/src/alu.cu index c64585204c..bed4395d48 100644 --- a/extensions/rv32im/circuit/cuda/src/alu.cu +++ b/extensions/rv32im/circuit/cuda/src/alu.cu @@ -5,6 +5,8 @@ #include "rv32im/adapters/alu.cuh" #include "rv32im/cores/alu.cuh" +#include + using namespace riscv; // Concrete type aliases for 32-bit @@ -23,17 +25,32 @@ struct Rv32BaseAluRecord { }; __global__ void alu_tracegen( - Fp *d_trace, + Fp *d_trace, // can be apc trace size_t height, DeviceBufferConstView d_records, uint32_t *d_range_checker_ptr, size_t range_checker_bins, uint32_t *d_bitwise_lookup_ptr, size_t bitwise_num_bits, - uint32_t timestamp_max_bits + uint32_t timestamp_max_bits, + uint32_t *subs, + uint32_t *d_opt_widths, + uint32_t *d_post_opt_offsets, + size_t apc_width, // 0 for non-apc + uint32_t 
calls_per_apc_row // 1 for non-apc ) { uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - RowSlice row(d_trace + idx, height); + // d_post_opt_offsets is always 0 for non APC case + bool is_apc = apc_width != 0; + RowSliceNew row( + is_apc ? d_trace + idx / calls_per_apc_row + d_post_opt_offsets[idx % calls_per_apc_row] * height : d_trace + idx, + height, + is_apc ? d_post_opt_offsets[idx % calls_per_apc_row] : 0, + is_apc ? sizeof(Rv32BaseAluCols) * (idx % calls_per_apc_row): 0, // this way we don't need to pass over d_pre_opt_offsets + subs, + is_apc + ); // we need to slice to the correct APC row, but if non-APC it's dividing by 1 and therefore the same idx + if (idx < d_records.len()) { auto const &rec = d_records[idx]; @@ -42,12 +59,20 @@ __global__ void alu_tracegen( BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits), timestamp_max_bits ); - adapter.fill_trace_row(row, rec.adapter); + adapter.fill_trace_row_new(row, rec.adapter); Rv32BaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits)); - core.fill_trace_row(row.slice_from(COL_INDEX(Rv32BaseAluCols, core)), rec.core); + core.fill_trace_row_new(row.slice_from(COL_INDEX(Rv32BaseAluCols, core)), rec.core); } else { - row.fill_zero(0, sizeof(Rv32BaseAluCols)); + if (!is_apc) { + // non-apc case + row.fill_zero(0, sizeof(Rv32BaseAluCols)); + } else if (idx < height * calls_per_apc_row) { + // apc case, but we need to limit idx to smaller than the # of dummy instruction runs + // because `kernel_launch_params` rounds to the next MAX_THREADS number of runs + // which can write beyond what we desire + row.fill_zero(0, d_opt_widths[idx % calls_per_apc_row]); + } } } @@ -60,21 +85,38 @@ extern "C" int _alu_tracegen( size_t range_checker_bins, uint32_t *d_bitwise_lookup_ptr, size_t bitwise_num_bits, - uint32_t timestamp_max_bits + uint32_t timestamp_max_bits, + uint32_t *subs, + uint32_t *d_opt_widths, + uint32_t *d_post_opt_offsets, + size_t apc_height, // 0 for non-apc + size_t apc_width, // 0 for non-apc + uint32_t calls_per_apc_row // 1 for non-apc ) { assert((height & (height - 1)) == 0); + assert((apc_height & (apc_height - 1)) == 0); assert(height >= d_records.len()); - assert(width == sizeof(Rv32BaseAluCols)); - auto [grid, block] = kernel_launch_params(height); + bool is_apc = apc_width != 0; + if (!is_apc) { // only check for non-apc + assert(width == sizeof(Rv32BaseAluCols)); + } + size_t threads = is_apc ? (apc_height * calls_per_apc_row) : height; + auto [grid, block] = kernel_launch_params(threads); alu_tracegen<<>>( d_trace, - height, + is_apc ? 
apc_height : height, d_records, d_range_checker_ptr, range_checker_bins, d_bitwise_lookup_ptr, bitwise_num_bits, - timestamp_max_bits + timestamp_max_bits, + subs, + d_opt_widths, + d_post_opt_offsets, + apc_width, // 0 for non-apc + calls_per_apc_row // 1 for non-apc ); + return CHECK_KERNEL(); -} \ No newline at end of file +} diff --git a/extensions/rv32im/circuit/src/base_alu/cuda.rs b/extensions/rv32im/circuit/src/base_alu/cuda.rs index f1cf443200..8f8d3354a5 100644 --- a/extensions/rv32im/circuit/src/base_alu/cuda.rs +++ b/extensions/rv32im/circuit/src/base_alu/cuda.rs @@ -9,6 +9,7 @@ use openvm_cuda_backend::{ base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, }; use openvm_cuda_common::copy::MemCopyH2D; +use openvm_cuda_common::d_buffer::DeviceBuffer; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; use crate::{ @@ -27,6 +28,42 @@ pub struct Rv32BaseAluChipGpu { } impl Chip for Rv32BaseAluChipGpu { + fn generate_proving_ctx_new(&self, arena: DenseRecordArena, d_trace: &DeviceBuffer, d_subs: &DeviceBuffer, d_opt_widths: &DeviceBuffer, d_post_opt_offsets: &DeviceBuffer, calls_per_apc_row: u32, apc_height: usize, apc_width: usize) { + const RECORD_SIZE: usize = size_of::<( + Rv32BaseAluAdapterRecord, + BaseAluCoreRecord, + )>(); + let records = arena.allocated(); + if records.is_empty() { + return; + } + debug_assert_eq!(records.len() % RECORD_SIZE, 0); + + let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); + + let d_records = records.to_device().unwrap(); + + unsafe { + tracegen( + d_trace, // APC trace + trace_height, + &d_records, + &self.range_checker.count, + self.range_checker.count.len(), + &self.bitwise_lookup.count, + RV32_CELL_BITS, + self.timestamp_max_bits as u32, + d_subs, + d_opt_widths, + d_post_opt_offsets, + apc_height, + apc_width, + calls_per_apc_row, + ) + .unwrap(); + } + } + fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext { const RECORD_SIZE: usize = size_of::<( Rv32BaseAluAdapterRecord, @@ -55,6 +92,12 @@ impl Chip for Rv32BaseAluChipGpu { &self.bitwise_lookup.count, RV32_CELL_BITS, self.timestamp_max_bits as u32, + &DeviceBuffer::new(), // nullptr + &DeviceBuffer::new(), // nullptr + &DeviceBuffer::new(), // nullptr + 0, // apc_height: not used in this path so set to 0 + 0, // apc_width: not used in this path so set to 0 + 1, // calls_per_apc_row: 1 for non-apc ) .unwrap(); } diff --git a/extensions/rv32im/circuit/src/cuda_abi.rs b/extensions/rv32im/circuit/src/cuda_abi.rs index 90f733dc36..b7b909fafe 100644 --- a/extensions/rv32im/circuit/src/cuda_abi.rs +++ b/extensions/rv32im/circuit/src/cuda_abi.rs @@ -322,6 +322,12 @@ pub mod alu_cuda { d_bitwise_lookup: *mut u32, bitwise_num_bits: usize, timestamp_max_bits: u32, + d_subs: *mut u32, + d_opt_widths: *mut u32, + d_post_opt_offsets: *mut u32, + apc_height: usize, + apc_width: usize, + calls_per_apc_row: u32, // 1 for non-apc ) -> i32; } @@ -334,8 +340,16 @@ pub mod alu_cuda { d_bitwise_lookup: &DeviceBuffer, bitwise_num_bits: usize, timestamp_max_bits: u32, + d_subs: &DeviceBuffer, + d_opt_widths: &DeviceBuffer, + d_post_opt_offsets: &DeviceBuffer, + apc_height: usize, + apc_width: usize, + calls_per_apc_row: u32 ) -> Result<(), CudaError> { - let width = d_trace.len() / height; + // `width` is non sensical for APC, as we would have APC trace divided by dummy height + // It's used in an assertion for non-APC + let width = d_trace.len() / height; CudaError::from_result(_alu_tracegen( 
d_trace.as_mut_ptr(), height, @@ -346,6 +360,12 @@ pub mod alu_cuda { d_bitwise_lookup.as_mut_ptr() as *mut u32, bitwise_num_bits, timestamp_max_bits, + d_subs.as_mut_ptr() as *mut u32, + d_opt_widths.as_mut_ptr() as *mut u32, + d_post_opt_offsets.as_mut_ptr() as *mut u32, + apc_height, + apc_width, + calls_per_apc_row, )) } }
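
For reviewers, the sketch below is a minimal host-side C++ model of the column-mapping convention that `RowSliceNew::write_new` and `slice_from` rely on. The `subs` table, offsets, and trace shape are invented example values, not output of the actual APC optimizer; only the indexing rule mirrors the device code in this patch.

// Host-side model of the RowSliceNew column mapping (illustrative values only).
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr uint32_t kRemoved = UINT32_MAX;

// Mirrors RowSliceNew::write_new for a column-major trace of `height` rows:
// subs[dummy_offset + column] is either the absolute post-optimization column
// or kRemoved, and optimized_offset rebases it onto the current slice.
void write_new_model(std::vector<uint32_t> &trace, size_t height, size_t row_idx,
                     const std::vector<uint32_t> &subs, size_t dummy_offset,
                     size_t optimized_offset, size_t column, uint32_t value) {
    uint32_t apc_idx = subs[dummy_offset + column];
    if (apc_idx == kRemoved) {
        return; // column was optimized away, nothing to write
    }
    trace[(apc_idx - optimized_offset) * height + row_idx] = value;
}

int main() {
    // 4 dummy columns; the optimizer removed column 2, keeping 3 columns.
    std::vector<uint32_t> subs = {0, 1, kRemoved, 2};
    size_t height = 2; // rows in the optimized trace
    std::vector<uint32_t> trace(3 * height, 0);

    for (size_t col = 0; col < subs.size(); ++col) {
        write_new_model(trace, height, /*row_idx=*/0, subs,
                        /*dummy_offset=*/0, /*optimized_offset=*/0, col,
                        static_cast<uint32_t>(100 + col));
    }

    // Dummy columns 0, 1, 3 land in optimized columns 0, 1, 2; column 2 is dropped.
    for (size_t c = 0; c < 3; ++c) {
        std::printf("optimized col %zu, row 0 = %u\n", c, trace[c * height]);
    }
    return 0;
}

Running the model writes dummy columns 0, 1 and 3 into optimized columns 0, 1 and 2 and silently drops column 2, which is the behaviour the `fill_trace_row_new` paths depend on when an adapter or core writes a column the optimizer eliminated.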