Skip to content

Commit 94f022a

Browse files
committed
added offset support for APC and sub; currently backward compatible with non-APC but not sure if APC works
1 parent 2acdfa2 commit 94f022a

File tree

9 files changed

+104
-27
lines changed

9 files changed

+104
-27
lines changed

crates/circuits/primitives/cuda/include/primitives/histogram.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ struct VariableRangeChecker {
8686
__device__ __forceinline__ void decompose_new(
8787
uint32_t x,
8888
size_t bits,
89-
RowSlice limbs,
89+
RowSliceNew limbs,
9090
const size_t limbs_len,
9191
uint32_t *sub
9292
) {

crates/circuits/primitives/cuda/include/primitives/less_than.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ __device__ __forceinline__ void generate_subrow_new(
4747
uint32_t x,
4848
uint32_t y,
4949
const size_t lower_decomp_len,
50-
RowSlice lower_decomp,
50+
RowSliceNew lower_decomp,
5151
uint32_t *sub
5252
) {
5353
rc.decompose_new(y - x - 1, max_bits, lower_decomp, lower_decomp_len, sub);

crates/circuits/primitives/cuda/include/primitives/trace_access.h

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,79 @@
44
#include <cstddef>
55
#include <cstdint>
66

7+
8+
/// A RowSlice is a contiguous section of a row in col-based trace.
9+
struct RowSliceNew {
10+
Fp *ptr;
11+
size_t stride;
12+
size_t optimized_offset;
13+
size_t dummy_offset;
14+
15+
16+
__device__ RowSliceNew(Fp *ptr, size_t stride, size_t optimized_offset, size_t dummy_offset) : ptr(ptr), stride(stride), optimized_offset(optimized_offset), dummy_offset(dummy_offset) {}
17+
18+
__device__ __forceinline__ Fp &operator[](size_t column_index) const {
19+
// While implementing tracegen for SHA256, we encountered what we believe to be an nvcc
20+
// compiler bug. Occasionally, at various non-zero PTXAS optimization levels the compiler
21+
// tries to replace this multiplication with a series of SHL, ADD, and AND instructions
22+
// that we believe erroneously adds ~2^49 to the final address via an improper carry
23+
// propagation. To read more, see https://github.com/stephenh-axiom-xyz/cuda-illegal.
24+
return ptr[column_index * stride];
25+
}
26+
27+
__device__ static RowSliceNew null() { return RowSliceNew(nullptr, 0, 0, 0); }
28+
29+
__device__ bool is_valid() const { return ptr != nullptr; }
30+
31+
template <typename T>
32+
__device__ __forceinline__ void write(size_t column_index, T value) const {
33+
ptr[column_index * stride] = value;
34+
}
35+
36+
template <typename T>
37+
__device__ __forceinline__ void write_array(size_t column_index, size_t length, const T *values)
38+
const {
39+
#pragma unroll
40+
for (size_t i = 0; i < length; i++) {
41+
ptr[(column_index + i) * stride] = values[i];
42+
}
43+
}
44+
45+
template <typename T>
46+
__device__ __forceinline__ void write_array_new(size_t column_index, size_t length, const T *values, const uint32_t *sub)
47+
const {
48+
#pragma unroll
49+
for (size_t i = 0; i < length; i++) {
50+
if (sub[i] != UINT32_MAX) {
51+
ptr[(column_index + i) * stride] = values[i];
52+
}
53+
}
54+
}
55+
56+
template <typename T>
57+
__device__ __forceinline__ void write_bits(size_t column_index, const T value) const {
58+
#pragma unroll
59+
for (size_t i = 0; i < sizeof(T) * 8; i++) {
60+
ptr[(column_index + i) * stride] = (value >> i) & 1;
61+
}
62+
}
63+
64+
__device__ __forceinline__ void fill_zero(size_t column_index_from, size_t length) const {
65+
#pragma unroll
66+
for (size_t i = 0, c = column_index_from; i < length; i++, c++) {
67+
ptr[c * stride] = 0;
68+
}
69+
}
70+
71+
__device__ __forceinline__ RowSliceNew slice_from(size_t column_index, uint32_t gap) const {
72+
return RowSliceNew(ptr + (column_index - gap) * stride, stride, optimized_offset + column_index - gap, dummy_offset + column_index);
73+
}
74+
75+
__device__ __forceinline__ RowSliceNew shift_row(size_t n) const {
76+
return RowSliceNew(ptr + n, stride, optimized_offset, dummy_offset);
77+
}
78+
};
79+
780
/// A RowSlice is a contiguous section of a row in col-based trace.
881
struct RowSlice {
982
Fp *ptr;
@@ -87,9 +160,9 @@ struct RowSlice {
87160
#define COL_WRITE_VALUE_NEW(ROW, STRUCT, FIELD, VALUE, SUB) \
88161
do { \
89162
const size_t _col_idx = COL_INDEX(STRUCT, FIELD); \
90-
const auto _apc_idx = (SUB)[_col_idx]; \
163+
const auto _apc_idx = (SUB)[_col_idx + ROW.dummy_offset]; \
91164
if (_apc_idx != UINT32_MAX) { \
92-
(ROW).write(_apc_idx, VALUE); \
165+
(ROW).write(_apc_idx - ROW.optimized_offset, VALUE); \
93166
} \
94167
} while (0)
95168

crates/vm/cuda/include/system/memory/controller.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ struct MemoryAuxColsFactory {
2323
COL_WRITE_VALUE(row, MemoryBaseAuxCols, prev_timestamp, prev_timestamp);
2424
}
2525

26-
__device__ void fill_new(RowSlice row, uint32_t prev_timestamp, uint32_t timestamp, uint32_t *sub) {
26+
__device__ void fill_new(RowSliceNew row, uint32_t prev_timestamp, uint32_t timestamp, uint32_t *sub) {
2727
AssertLessThan::generate_subrow_new(
2828
range_checker,
2929
timestamp_max_bits,
3030
prev_timestamp,
3131
timestamp,
3232
AUX_LEN,
33-
row.slice_from(COL_INDEX(MemoryBaseAuxCols, timestamp_lt_aux) - number_of_gaps_in(sub, sizeof(MemoryBaseAuxCols<uint8_t>))),
33+
row.slice_from(COL_INDEX(MemoryBaseAuxCols, timestamp_lt_aux), number_of_gaps_in(sub, sizeof(MemoryBaseAuxCols<uint8_t>))),
3434
sub
3535
);
3636
COL_WRITE_VALUE_NEW(row, MemoryBaseAuxCols, prev_timestamp, prev_timestamp, sub);

extensions/rv32im/circuit/cuda/include/rv32im/adapters/alu.cuh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct Rv32BaseAluAdapter {
8282
);
8383
}
8484

85-
__device__ void fill_trace_row_new(RowSlice row, Rv32BaseAluAdapterRecord record, uint32_t *sub) {
85+
__device__ void fill_trace_row_new(RowSliceNew row, Rv32BaseAluAdapterRecord record, uint32_t *sub) {
8686
COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, from_state.pc, record.from_pc, sub);
8787
COL_WRITE_VALUE_NEW(row, Rv32BaseAluAdapterCols, from_state.timestamp, record.from_timestamp, sub);
8888

@@ -93,7 +93,7 @@ struct Rv32BaseAluAdapter {
9393

9494
// Read auxiliary for rs1
9595
mem_helper.fill_new(
96-
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[0]) - number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
96+
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[0]), number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
9797
record.reads_aux[0].prev_timestamp,
9898
record.from_timestamp,
9999
sub
@@ -102,13 +102,13 @@ struct Rv32BaseAluAdapter {
102102
// rs2: register read when rs2_as == RV32_REGISTER_AS (== 1), otherwise immediate.
103103
if (record.rs2_as != 0) {
104104
mem_helper.fill_new(
105-
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1]) - number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
105+
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1]), number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
106106
record.reads_aux[1].prev_timestamp,
107107
record.from_timestamp + 1,
108108
sub
109109
);
110110
} else {
111-
RowSlice rs2_aux = row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1]) - number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>)));
111+
RowSliceNew rs2_aux = row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, reads_aux[1]), number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>)));
112112
#pragma unroll
113113
for (size_t i = 0; i < sizeof(MemoryReadAuxCols<uint8_t>); i++) {
114114
rs2_aux.write(i, 0);
@@ -117,11 +117,11 @@ struct Rv32BaseAluAdapter {
117117
bitwise_lookup.add_range(record.rs2 & mask, (record.rs2 >> RV32_CELL_BITS) & mask);
118118
}
119119

120-
COL_WRITE_ARRAY(
121-
row, Rv32BaseAluAdapterCols, writes_aux.prev_data, record.writes_aux.prev_data
120+
COL_WRITE_ARRAY_NEW(
121+
row, Rv32BaseAluAdapterCols, writes_aux.prev_data, record.writes_aux.prev_data, sub
122122
);
123123
mem_helper.fill_new(
124-
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, writes_aux) - number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
124+
row.slice_from(COL_INDEX(Rv32BaseAluAdapterCols, writes_aux), number_of_gaps_in(sub, sizeof(Rv32BaseAluAdapterCols<uint8_t>))),
125125
record.writes_aux.prev_timestamp,
126126
record.from_timestamp + 2,
127127
sub

extensions/rv32im/circuit/cuda/include/rv32im/cores/alu.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ template <size_t NUM_LIMBS> struct BaseAluCore {
140140
}
141141
}
142142

143-
__device__ void fill_trace_row_new(RowSlice row, BaseAluCoreRecord<NUM_LIMBS> record, uint32_t *sub) {
143+
__device__ void fill_trace_row_new(RowSliceNew row, BaseAluCoreRecord<NUM_LIMBS> record, uint32_t *sub) {
144144
uint8_t a[NUM_LIMBS];
145145
uint8_t carry_buf[NUM_LIMBS];
146146

extensions/rv32im/circuit/cuda/src/alu.cu

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,27 @@ struct Rv32BaseAluRecord {
2323
};
2424

2525
__global__ void alu_tracegen(
26-
Fp *d_trace,
27-
size_t height,
26+
Fp *d_trace, // can be apc trace
27+
size_t height, // can be apc height
2828
DeviceBufferConstView<Rv32BaseAluRecord> d_records,
2929
uint32_t *d_range_checker_ptr,
3030
size_t range_checker_bins,
3131
uint32_t *d_bitwise_lookup_ptr,
3232
size_t bitwise_num_bits,
3333
uint32_t timestamp_max_bits,
3434
// Fp *d_apc_trace,
35-
uint32_t *subs // same length as dummy width
36-
// size_t width, // dummy width
35+
uint32_t *subs, // same length as dummy width
36+
uint32_t calls_per_apc_row, // 1 for non-apc
37+
size_t width // dummy width
3738
// uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
3839
) {
3940
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
40-
RowSlice row(d_trace + idx, height);
41+
RowSliceNew row(d_trace + idx / calls_per_apc_row, height, 0, 0); // we need to slice to the correct APC row, but if non-APC it's dividing by 1 and therefore the same idx
4142
if (idx < d_records.len()) {
4243
auto const &rec = d_records[idx];
4344
// RowSlice apc_row(d_apc_trace + apc_row_index[idx], height);
4445
// auto const sub = subs[idx * width]; // offset the subs to the corresponding dummy row
45-
uint32_t *sub = subs;
46+
uint32_t *sub = &subs[(idx % calls_per_apc_row) * width]; // dummy width
4647

4748
Rv32BaseAluAdapter adapter(
4849
VariableRangeChecker(d_range_checker_ptr, range_checker_bins),
@@ -52,7 +53,7 @@ __global__ void alu_tracegen(
5253
adapter.fill_trace_row_new(row, rec.adapter, sub);
5354

5455
Rv32BaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits));
55-
core.fill_trace_row_new(row.slice_from(COL_INDEX(Rv32BaseAluCols, core) - number_of_gaps_in(sub, sizeof(Rv32BaseAluCols<uint8_t>))), rec.core, sub);
56+
core.fill_trace_row_new(row.slice_from(COL_INDEX(Rv32BaseAluCols, core), number_of_gaps_in(sub, sizeof(Rv32BaseAluCols<uint8_t>))), rec.core, sub);
5657
} else {
5758
// TODO: use APC width if APC
5859
row.fill_zero(0, sizeof(Rv32BaseAluCols<uint8_t>));
@@ -70,7 +71,8 @@ extern "C" int _alu_tracegen(
7071
size_t bitwise_num_bits,
7172
uint32_t timestamp_max_bits,
7273
// Fp *d_apc_trace,
73-
uint32_t *subs // same length as dummy width
74+
uint32_t *subs, // same length as dummy width
75+
uint32_t calls_per_apc_row // 1 for non-apc
7476
// uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
7577
) {
7678
assert((height & (height - 1)) == 0);
@@ -87,8 +89,9 @@ extern "C" int _alu_tracegen(
8789
bitwise_num_bits,
8890
timestamp_max_bits,
8991
// Fp *d_apc_trace,
90-
subs // same length as dummy width
91-
// size_t width, // dummy width
92+
subs, // same length as dummy width
93+
calls_per_apc_row, // 1 for non-apc
94+
width // dummy width
9295
// uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
9396
);
9497
return CHECK_KERNEL();

extensions/rv32im/circuit/src/base_alu/cuda.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu {
6363
self.timestamp_max_bits as u32,
6464
// d_apc_trace.buffer(),
6565
&d_subs, // same length as dummy width
66-
// apc_row_index, // dummy row mapping to apc row same length as d_records
66+
1, // calls_per_apc_row: 1 for non-apc
6767
)
6868
.unwrap();
6969
}

extensions/rv32im/circuit/src/cuda_abi.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ pub mod alu_cuda {
324324
timestamp_max_bits: u32,
325325
// d_apc_trace: *mut F,
326326
d_subs: *mut u32,
327-
// apc_row_index: *mut u32,
327+
calls_per_apc_row: u32, // 1 for non-apc
328328
) -> i32;
329329
}
330330

@@ -339,7 +339,7 @@ pub mod alu_cuda {
339339
timestamp_max_bits: u32,
340340
// d_apc_trace: &DeviceBuffer<F>,
341341
d_subs: &DeviceBuffer<u32>,
342-
// apc_row_index: Option<Vec<u32>>,
342+
calls_per_apc_row: u32 // 1 for non-apc
343343
) -> Result<(), CudaError> {
344344
let width = d_trace.len() / height;
345345
CudaError::from_result(_alu_tracegen(
@@ -354,6 +354,7 @@ pub mod alu_cuda {
354354
timestamp_max_bits,
355355
// d_apc_trace.as_mut_ptr(),
356356
d_subs.as_mut_ptr() as *mut u32,
357+
calls_per_apc_row, // 1 for non-apc
357358
// apc_row_index,
358359
))
359360
}

0 commit comments

Comments
 (0)