Skip to content

Commit 26f3760

Browse files
committed
changes so far
1 parent 94f022a commit 26f3760

File tree

4 files changed

+56
-15
lines changed

4 files changed

+56
-15
lines changed

Cargo.lock

Lines changed: 2 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ cuda-runtime-sys = "0.3.0-alpha.1"
275275
ignored = ["cargo-openvm"]
276276

277277
# # The local openvm also needs to have stark-backend patched so all types match.
278-
# [patch."https://github.com/powdr-labs/stark-backend.git"]
279-
# openvm-stark-backend = { path = "../stark-backend/crates/stark-backend", default-features = false }
280-
# openvm-stark-sdk = { path = "../stark-backend/crates/stark-sdk", default-features = false }
281-
# openvm-cuda-backend = { path = "../stark-backend/crates/cuda-backend", default-features = false }
282-
# openvm-cuda-builder = { path = "../stark-backend/crates/cuda-builder", default-features = false }
283-
# openvm-cuda-common = { path = "../stark-backend/crates/cuda-common", default-features = false }
278+
[patch."https://github.com/powdr-labs/stark-backend.git"]
279+
openvm-stark-backend = { path = "../stark-backend/crates/stark-backend", default-features = false }
280+
openvm-stark-sdk = { path = "../stark-backend/crates/stark-sdk", default-features = false }
281+
openvm-cuda-backend = { path = "../stark-backend/crates/cuda-backend", default-features = false }
282+
openvm-cuda-builder = { path = "../stark-backend/crates/cuda-builder", default-features = false }
283+
openvm-cuda-common = { path = "../stark-backend/crates/cuda-common", default-features = false }

extensions/rv32im/circuit/cuda/src/alu.cu

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ __global__ void alu_tracegen(
3434
// Fp *d_apc_trace,
3535
uint32_t *subs, // same length as dummy width
3636
uint32_t calls_per_apc_row, // 1 for non-apc
37-
size_t width // dummy width
37+
size_t width // dummy width or apc width
3838
// uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
3939
) {
4040
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,7 +56,12 @@ __global__ void alu_tracegen(
5656
core.fill_trace_row_new(row.slice_from(COL_INDEX(Rv32BaseAluCols, core), number_of_gaps_in(sub, sizeof(Rv32BaseAluCols<uint8_t>))), rec.core, sub);
5757
} else {
5858
// TODO: use APC width if APC
59-
row.fill_zero(0, sizeof(Rv32BaseAluCols<uint8_t>));
59+
// this is now a hack because calls_per_apc_row can still be 1 even if we are in an APC
60+
if (calls_per_apc_row == 1) {
61+
row.fill_zero(0, sizeof(Rv32BaseAluCols<uint8_t>));
62+
} else {
63+
row.fill_zero(0, width);
64+
}
6065
}
6166
}
6267

@@ -77,7 +82,7 @@ extern "C" int _alu_tracegen(
7782
) {
7883
assert((height & (height - 1)) == 0);
7984
assert(height >= d_records.len());
80-
assert(width == sizeof(Rv32BaseAluCols<uint8_t>));
85+
// assert(width == sizeof(Rv32BaseAluCols<uint8_t>)); // this is no longer true for APC
8186
auto [grid, block] = kernel_launch_params(height);
8287
alu_tracegen<<<grid, block>>>(
8388
d_trace,
@@ -91,7 +96,7 @@ extern "C" int _alu_tracegen(
9196
// Fp *d_apc_trace,
9297
subs, // same length as dummy width
9398
calls_per_apc_row, // 1 for non-apc
94-
width // dummy width
99+
width // dummy width or apc width
95100
// uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
96101
);
97102
return CHECK_KERNEL();

extensions/rv32im/circuit/src/base_alu/cuda.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use openvm_cuda_backend::{
99
base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
1010
};
1111
use openvm_cuda_common::copy::MemCopyH2D;
12+
use openvm_cuda_common::d_buffer::DeviceBuffer;
1213
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
1314

1415
use crate::{
@@ -27,6 +28,44 @@ pub struct Rv32BaseAluChipGpu {
2728
}
2829

2930
impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu {
31+
fn generate_proving_ctx_new(&self, arena: DenseRecordArena, d_trace: &DeviceBuffer<F>, d_subs: &DeviceBuffer<u32>, calls_per_apc_row: u32) {
32+
const RECORD_SIZE: usize = size_of::<(
33+
Rv32BaseAluAdapterRecord,
34+
BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>,
35+
)>();
36+
let records = arena.allocated();
37+
if records.is_empty() {
38+
return;
39+
// return get_empty_air_proving_ctx::<GpuBackend>();
40+
}
41+
debug_assert_eq!(records.len() % RECORD_SIZE, 0);
42+
43+
let trace_width = BaseAluCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width()
44+
+ Rv32BaseAluAdapterCols::<F>::width();
45+
let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
46+
47+
let d_records = records.to_device().unwrap();
48+
49+
unsafe {
50+
tracegen(
51+
d_trace, // replaced with apc trace
52+
trace_height,
53+
&d_records,
54+
&self.range_checker.count,
55+
self.range_checker.count.len(),
56+
&self.bitwise_lookup.count,
57+
RV32_CELL_BITS,
58+
self.timestamp_max_bits as u32,
59+
// d_apc_trace.buffer(),
60+
d_subs, // same length as dummy width
61+
calls_per_apc_row,
62+
)
63+
.unwrap();
64+
}
65+
66+
// AirProvingContext::simple_no_pis(d_trace)
67+
}
68+
3069
fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
3170
const RECORD_SIZE: usize = size_of::<(
3271
Rv32BaseAluAdapterRecord,

0 commit comments

Comments
 (0)