diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py
index 7d51d46e4..253555490 100644
--- a/lit_tests/kernel/wave/gemm.py
+++ b/lit_tests/kernel/wave/gemm.py
@@ -1419,7 +1419,7 @@ def test_gemm_four_stage_global_to_lds():
     # Prologue
     # Verify prologue stores to shared memory
-    # CHECK: rocdl.tensor.load.to.lds
+    # CHECK: amdgpu.tensor_load_to_lds
     # CHECK: rocdl.s.wait.tensorcnt 0
     # CHECK: rocdl.s.wait.dscnt 0
@@ -1431,7 +1431,7 @@ def test_gemm_four_stage_global_to_lds():
     # CHECK: %[[LOAD_IDX2:.*]] = affine.apply #[[MAP_LOAD2:.*]]()[%thread_id_x]
     # CHECK: vector.load %[[VIEW1]][%[[LOAD_IDX1]], %[[LOAD_IDX2]]]
-    # CHECK: rocdl.tensor.load.to.lds
+    # CHECK: amdgpu.tensor_load_to_lds

     # Main Loop:
     # Verify Pipelined Loop, iter_args should contain vector values from prologue
@@ -1452,7 +1452,7 @@ def test_gemm_four_stage_global_to_lds():
     # CHECK: rocdl.s.barrier.signal id = -1
     # CHECK: rocdl.s.barrier.wait id = -1
-    # CHECK: rocdl.tensor.load.to.lds
+    # CHECK: amdgpu.tensor_load_to_lds

     # CHECK: scf.yield {{.*}}, %[[LVIEW2]], %[[LVIEW3]], %[[LVIEW0]], %[[LVIEW1]]
diff --git a/lit_tests/kernel/wave/mma.py b/lit_tests/kernel/wave/mma.py
index 4eadb7f25..833f2605f 100644
--- a/lit_tests/kernel/wave/mma.py
+++ b/lit_tests/kernel/wave/mma.py
@@ -644,6 +644,8 @@ def mma(
     # CHECK: func.func @mma
     ### global buffer is bound to %0, %1 and %2 : MK, NK, MN
+    # CHECK: %[[C0:.*]] = arith.constant 0 : index
+    # CHECK: %[[SUBSPAN0:.*]] = stream.binding.subspan
     # CHECK: %[[SUBSPAN1:.*]] = stream.binding.subspan
     # CHECK: %[[SUBSPAN2:.*]] = stream.binding.subspan
@@ -658,14 +660,8 @@ def mma(
     # CHECK: %[[VIEW0:.*]] = memref.view %[[SMEM]][{{.*}}] : memref<4608xi8, #gpu.address_space> to memref<32x36xf16, #gpu.address_space>
     # CHECK: %[[VIEW1:.*]] = memref.view %[[SMEM]][{{.*}}] : memref<4608xi8, #gpu.address_space> to memref<32x36xf16, #gpu.address_space>

-    ### get global buffer pointer
-    # CHECK: %[[INT_PTR_0:.+]] = memref.extract_aligned_pointer_as_index
-
-    ### get shared buffer pointer
-    # CHECK: %[[CAST_3:.*]] = memref.reinterpret_cast %[[VIEW1]]
-    # CHECK: %[[INT_PTR_1:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_3]]
-
-    # CHECK: %[[D0:.*]] = vector.from_elements
+    ### make DMA base
+    # CHECK: %[[DMA_BASE0:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW1]][{{.*}}]

     # Cluster mask generation
     # CHECK: %[[COND0:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
     # CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index
     # CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index

-    ### pack descriptors and invoke tensor load
-
-    # CHECK: %[[TENSOR_DESC_0:.*]] = vector.from_elements
-    # CHECK-NOT: rocdl.tensor.load.to.lds
-    # CHECK-NOT: rocdl.s.wait.tensorcnt
-    # CHECK-NOT: amdgpu.lds_barrier
-
-    ### get shared buffer pointer
-    # CHECK: %[[CAST_4:.*]] = memref.reinterpret_cast %[[VIEW0]]
-    # CHECK: %[[INT_PTR_2:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_4]]
-    # CHECK: %[[INT_PTR_2_CAST:.+]] = arith.index_cast %[[INT_PTR_2]] : index to i32
-    # CHECK: %[[INT_PTR_2_CAST_ADDED:.+]] = arith.addi %[[INT_PTR_2_CAST]], %{{.*}} : i32
+    # CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0]] globalSize [%{{.*}}, 32] globalStride [32, 1] sharedSize [16, 32]

-    ### pack descriptors and invoke tensor load
-    # CHECK: %[[D1:.*]] = vector.from_elements %{{.*}}, %[[INT_PTR_2_CAST_ADDED]], %{{.*}}, %{{.*}} : vector<4xi32>
-    # CHECK: %[[TENSOR_DESC_1:.*]] = vector.from_elements
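+    ### make DMA base and descriptor for the other input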
+    # CHECK: %[[DMA_BASE1:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW0]][{{.*}}]
+    # CHECK: %[[TENSOR_DESC_1:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE1]] globalSize [%{{.*}}, 32] globalStride [32, 1] sharedSize [16, 32]

     # Fused descriptors
     # CHECK: %[[SELECTED:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
-    # CHECK: %[[D_FUSED:.*]] = arith.select %[[SELECTED]], %[[D0]], %[[D1]] : vector<4xi32>
-    # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : vector<8xi32>
+    # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : !amdgpu.tdm_descriptor

     ### resource provider
-    # CHECK: rocdl.tensor.load.to.lds %[[D_FUSED]], %[[DESC_FUSED]], {{.*}}, {{.*}} cachepolicy {{.*}} : vector<4xi32>, vector<8xi32>
+    # CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED]]
     # CHECK: rocdl.s.wait.tensorcnt 0
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py
index cdd10bf45..381c919da 100644
--- a/tests/kernel/wave_gemm_test.py
+++ b/tests/kernel/wave_gemm_test.py
@@ -409,8 +409,8 @@ def testGemmGlobalToLDS(
     asm = gemm.asm
     assert (
-        "amdgpu.gather_to_lds" in asm or "tensor.load.to.lds" in asm
-    ), "gather_to_lds / tensor.load.to.lds not found in asm"
+        "amdgpu.gather_to_lds" in asm or "amdgpu.tensor_load_to_lds" in asm
+    ), "gather_to_lds / tensor_load_to_lds not found in asm"

     if run_bench:
         options.benchmark_results_file = perf_filename_iree
diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py
index 098ee560e..f6db4f7b2 100644
--- a/wave_lang/kernel/compiler/wave_codegen/read_write.py
+++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py
@@ -31,7 +31,6 @@
     gpu_d,
     llvm_d,
     memref_d,
-    rocdl_d,
     vector_d,
 )
 from .ir_utils import (
@@ -862,26 +861,20 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         destinations
     ), "sources and destinations must have the same number of elements."

-    # construct default descriptors
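+    # IR types used below: the element type feeds the TDM base type, while the
+    # i16/i1 vector types build the 16-lane workgroup multicast mask.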
+    i1 = IntegerType.get_signless(1)
+    i16 = IntegerType.get_signless(16)
     i32 = IntegerType.get_signless(32)
-    i48 = IntegerType.get_signless(48)
-    i57 = IntegerType.get_signless(57)
+    v1i16 = VectorType.get([1], i16)
+    v16i1 = VectorType.get([16], i1)

-    vec_type_4 = VectorType.get((4,), i32)
-    vec_type_8 = VectorType.get((8,), i32)
+    ir_type = IrType.parse(element_type.dtype.ir_type_asm())
+    dma_type = amdgpu_d.TDMBaseType.get(ir_type)

-    c0 = arith_d.constant(i32, 0)
-
-    d0_results = []
-    d1_results = []
-    d2_results = []
-    d3_results = []
+    results = []

     subs = add_emitter_subs(emitter)
     for i, (src, dst) in enumerate(zip(sources, destinations)):
-        dst_memory = get_custom(dst)
-
         symbolic_shape = _get_symbolic_shape(src)
         global_tile_index_current = {k: global_tile_index[k] for k in symbolic_shape}
         global_tile_index_current = _subs_index_dict(
@@ -896,210 +889,75 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         strides = strides_from_symbolic_shape(
             IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
         )
-        # Descriptor assumes rightmost stride 1 and expect last stride as full data size
-        strides = [strides[0] * symbolic_shape[0]] + strides[:-1]
-        strides = [gen_sympy_index(subs, s) for s in strides]

         distributed_shape_vals = [
             gen_sympy_index(subs, distributed_shape[s]) for s in symbolic_shape
         ]
-        d0 = vector_d.broadcast(vec_type_4, c0)
-        d1 = vector_d.broadcast(vec_type_8, c0)
-        d2 = vector_d.broadcast(vec_type_4, c0)
-        d3 = vector_d.broadcast(vec_type_4, c0)
-
-        # descriptor properties
-        mode = 2  # vimage
-        valid = 1
-        dim_stride_1 = arith_d.index_cast(i48, strides[0])
-        dim_stride_0 = arith_d.index_cast(i48, strides[1])
-        tile_size_1 = arith_d.index_cast(i32, distributed_shape_vals[0])
-        tile_size_0 = arith_d.index_cast(i32, distributed_shape_vals[1])
-        dim_size_1 = arith_d.index_cast(i32, local_bounds[0])
-        dim_size_0 = arith_d.index_cast(i32, local_bounds[1])
-
-        # 0: 1 byte; 1: 2 byte; 2: 4 byte; 3: 8 byte
-        descriptor_type = lambda x: int(math.log2(x.bitwidth() >> 3))
-        data_size = cast_py_value(emitter, descriptor_type(element_type), i32).ir_value
-
         global_mem = cast_py_value(emitter, src)
         shared_mem = cast_py_value(emitter, dst)
         global_value = global_mem.ir_value
         shared_value = shared_mem.ir_value
-        bytewidth = element_type.bitwidth() // 8
-        element_byte_index = arith_d.constant(IndexType.get(), bytewidth)
-
-        # calculcate global address
-        # 0. breakdown index sequence to WG & TH offsets : ele
-        # 1. uniform per wave access : ele
-        # 2. linearize global memory buffer
-        # 3. offset = X + Y * tensor dim 0 stride : ele
-        # 4. offset_byte = offset * element byte : byte
-        # 5. get global memory pointer
-        # 6. move global memory pointer by offset_byte to get global address of a tile : byte
         index, _, _ = _build_start_indices(emitter, global_tile_index_current)
-        wave_index_x = assume_index_subgroup_uniform(index[1], i32)  # k
-        wave_index_y = assume_index_subgroup_uniform(index[0], i32)  # m
-
-        stride0 = arith_d.index_cast(IndexType.get(), dim_stride_0)
-        y_offset = arith_d.muli(wave_index_y, stride0)
-        global_base_offset = arith_d.addi(wave_index_x, y_offset)
-        global_index_offset = arith_d.muli(global_base_offset, element_byte_index)
-
-        global_ptr = memref_d.extract_aligned_pointer_as_index(global_value)
-        global_byte_address = arith_d.addi(global_ptr, global_index_offset)
-
-        # calculate shared address
-        # 0. extract shared tile index from IndexSequence structure
-        # 1. calculate byte offset from tile indices and distributed shape
-        # 2. get shared memory pointer
-        # 3. move shared memory pointer by offset_byte to get shared memory address of a tile.
-        shared_buffer = _linearize_shared_mem(shared_value)
-
-        shared_strides = strides_from_symbolic_shape(
-            IndexingContext.current(),
-            dst_memory.distributed_shape,
-            allow_mixed_shapes=True,
-        )
-
         shared_tile_index_current = {k: shared_tile_index[k] for k in symbolic_shape}
         shared_tile_index_current = _subs_index_dict(
             shared_tile_index_current, {INPUT_SELECTOR: i}
         )
-        linearized_index = {
-            "linearized_idx": linearize_index(shared_tile_index_current, shared_strides)
-        }
-        # Calculate shared memory offset from tile indices
-        shared_index, _, _ = _build_start_indices(emitter, linearized_index)
-
-        shared_index_offset = arith_d.muli(shared_index[0], element_byte_index)
-        shared_byte_offset = arith_d.index_cast(i32, shared_index_offset)
-
-        shared_ptr = memref_d.extract_aligned_pointer_as_index(shared_buffer)
-        shared_ptr = arith_d.index_cast(i32, shared_ptr)
-
-        shared_ptr_base_offset = memref_d.extract_strided_metadata(shared_buffer)[1]
-        shared_ptr_base_offset = arith_d.index_cast(i32, shared_ptr_base_offset)
-
-        shared_byte_address = arith_d.addi(shared_ptr_base_offset, shared_byte_offset)
-        shared_byte_address = arith_d.addi(shared_ptr, shared_byte_address)
-
-        # assume no mapping
-        def lshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shli(value, sh)
-            return val
-
-        def rshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shrui(value, sh)
-            return val
-
-        # pack global address of a tile
-        # 1. get lower 32 bit from global value
-        global_val = arith_d.index_cast(i57, global_byte_address)  # i57
-        global_val_lower = arith_d.trunci(i32, global_val)
-        d0 = vector_d.insert(
-            global_val_lower, d0, static_position=[2], dynamic_position=[]
+        shared_index, _, _ = _build_start_indices(emitter, shared_tile_index_current)
+
+        base = amdgpu_d.make_dma_base(
+            base=dma_type,
+            global_=global_value,
+            global_indices=index,
+            lds=shared_value,
+            lds_indices=shared_index,
         )
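+        # TDM padding parameters; populated below only when the destination is padded.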
-        # 2. get rest of the upper 25 bit from global value and cast to i32
-        global_val_rest = rshift(global_val, 32)
-        global_val_upper = arith_d.trunci(i32, global_val_rest)
-        # 3. pack with image mode bit
-        mode = arith_d.constant(i32, mode)
-        image_mode = lshift(mode, 30)
-        pack = arith_d.ori(image_mode, global_val_upper)
-        d0 = vector_d.insert(pack, d0, static_position=[3], dynamic_position=[])
-
-        # insert shared addreess to descriptor 0
-        d0 = vector_d.insert(
-            shared_byte_address, d0, static_position=[1], dynamic_position=[]
-        )
-
-        # valid tensor
-        valid_tensor = arith_d.constant(i32, valid)
-        d0 = vector_d.insert(valid_tensor, d0, static_position=[0], dynamic_position=[])
-
-        # get data size val packed to i32
-        data_size_val = lshift(data_size, 16)
+        pad_interval = None
+        pad_amount = None

         original_dst = propagate_loop_carried_vars(dst)
         original_dst = get_custom(original_dst)
         if padding := original_dst.padding:
+            bytewidth = element_type.bitwidth() // 8
             unpadded_dim = int(subs_idxc(original_dst.unpadded_shape[-1])) * bytewidth
             assert (
                 unpadded_dim >= 8
             ), f"Invalid unpadded_dim for padding: {unpadded_dim} (must be at least 8 bytes)"
-            pad_enable = 1 << 20
-            pad_interval = int(math.log2((unpadded_dim // 4) - 1)) << 22
-            pad_amount = ((padding * bytewidth) // 4 - 1) << 25
-            pad_packed = pad_enable | pad_interval | pad_amount
-            data_size_val = arith_d.ori(
-                data_size_val, arith_d.constant(i32, pad_packed)
-            )
-
-        local_multicast_mask = subs_idxc(safe_subs(multicast_mask, {INPUT_SELECTOR: i}))
-
-        if local_multicast_mask:
+            DWORD_SIZE = 4
+            pad_interval = arith_d.constant(i32, unpadded_dim // DWORD_SIZE)
+            pad_amount = arith_d.constant(i32, (padding * bytewidth) // DWORD_SIZE)
+
+        workgroup_mask = None
+        if local_multicast_mask := subs_idxc(
+            safe_subs(multicast_mask, {INPUT_SELECTOR: i})
+        ):
             local_multicast_mask = sympy.simplify(local_multicast_mask)
             local_multicast_mask_val = gen_sympy_index(subs, local_multicast_mask)
-            local_multicast_mask_val = arith_d.index_cast(i32, local_multicast_mask_val)
-            data_size_val = arith_d.ori(data_size_val, local_multicast_mask_val)
-
-        d1 = vector_d.insert(
-            data_size_val, d1, static_position=[0], dynamic_position=[]
-        )
-
-        # get lower 16 bit from tensor dim 0 and pack to i32
-        tensor_dim_0_lower = lshift(dim_size_0, 16)
-        d1 = vector_d.insert(
-            tensor_dim_0_lower, d1, static_position=[1], dynamic_position=[]
-        )
-
-        # get upper 16 bit from tensor dim 0 and lower 16 bit from tensor dim 1, pack to i32
-        tensor_dim_0_upper = rshift(dim_size_0, 16)
-        tensor_dim_1_lower = lshift(dim_size_1, 16)
-        pack = arith_d.ori(tensor_dim_1_lower, tensor_dim_0_upper)
-        d1 = vector_d.insert(pack, d1, static_position=[2], dynamic_position=[])
-
-        # get upper 16 bit from tensor dim 1, packed with tile size 0
-        tensor_dim_1_upper = rshift(dim_size_1, 16)
-        tile_size_0_shift = lshift(tile_size_0, 16)
-        pack = arith_d.ori(tensor_dim_1_upper, tile_size_0_shift)
-        d1 = vector_d.insert(pack, d1, static_position=[3], dynamic_position=[])
-
-        # tile size 1 is in good form
-        d1 = vector_d.insert(tile_size_1, d1, static_position=[4], dynamic_position=[])
-
-        # truncate upper 16 bit from dim stride 0 -> i48 to i32
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0)
-        d1 = vector_d.insert(
-            dim_stride_0_trunc, d1, static_position=[5], dynamic_position=[]
+            workgroup_mask = arith_d.index_cast(i16, local_multicast_mask_val)
+            workgroup_mask = vector_d.from_elements(v1i16, [workgroup_mask])
+            workgroup_mask = vector_d.bitcast(v16i1, workgroup_mask)
+
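+        # Fold the DMA base, global bounds and strides, shared tile shape,
+        # padding and multicast mask into a single TDM descriptor.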
+        desc = amdgpu_d.make_dma_descriptor(
+            base=base,
+            global_dynamic_sizes=local_bounds,
+            global_static_sizes=[ShapedType.get_dynamic_size()] * len(local_bounds),
+            global_dynamic_strides=None,
+            global_static_strides=strides,
+            shared_dynamic_sizes=distributed_shape_vals,
+            shared_static_sizes=[ShapedType.get_dynamic_size()]
+            * len(distributed_shape_vals),
+            atomic_barrier_indices=None,
+            workgroup_mask=workgroup_mask,
+            pad_amount=pad_amount,
+            pad_interval=pad_interval,
         )
-        # get upper 16 bit from dim stride 0, get lower 16 bit from dim stride 1, packed to i32
-        dim_stride_0_upper = rshift(dim_stride_0, 32)
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0_upper)
-        dim_stride_1_lower = arith_d.trunci(i32, dim_stride_1)
-        dim_stride_1_trunc = lshift(dim_stride_1_lower, 16)
-        pack = arith_d.ori(dim_stride_0_trunc, dim_stride_1_trunc)
-        d1 = vector_d.insert(pack, d1, static_position=[6], dynamic_position=[])
-
-        # shift dim stride 1 to get upper 32 bit and pack to i32
-        dim_stride_1_sh = rshift(dim_stride_1, 16)
-        pack = arith_d.trunci(i32, dim_stride_1_sh)
-        d1 = vector_d.insert(pack, d1, static_position=[7], dynamic_position=[])
-
-        d0_results.append(d0)
-        d1_results.append(d1)
-        d2_results.append(d2)
-        d3_results.append(d3)
+        results.append(desc)

     # Select the appropriate descriptors based on input_selector
     # Build chained select operations for each descriptor
@@ -1122,14 +980,9 @@ def select_descriptor(results_list, input_selector_val):
             return selected

     input_selector_val = gen_sympy_index(subs, input_selector)
-    d0_selected = select_descriptor(d0_results, input_selector_val)
-    d1_selected = select_descriptor(d1_results, input_selector_val)
-    d2_selected = select_descriptor(d2_results, input_selector_val)
-    d3_selected = select_descriptor(d3_results, input_selector_val)
+    selected = select_descriptor(results, input_selector_val)

-    return rocdl_d.tensor_load_to_lds(
-        d0_selected, d1_selected, d2_selected, d3_selected, 0
-    )
+    return amdgpu_d.tensor_load_to_lds(selected)


 @handle_op(gather_to_lds)
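
For reference, the new pad_interval / pad_amount arguments are plain dword counts derived from the destination padding, replacing the hand-packed bit fields that the deleted code built. A small, self-contained illustration of the arithmetic (values hypothetical: an f16 row of 64 elements padded by 4 elements):

    # Hypothetical values; mirrors the pad_interval / pad_amount expressions above.
    bytewidth = 2                                      # f16
    padding = 4                                        # padding elements per row
    unpadded_dim = 64 * bytewidth                      # 128 bytes per unpadded row
    DWORD_SIZE = 4
    pad_interval = unpadded_dim // DWORD_SIZE          # unpadded row length in dwords -> 32
    pad_amount = (padding * bytewidth) // DWORD_SIZE   # padding per row in dwords -> 2
    assert (pad_interval, pad_amount) == (32, 2)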