Skip to content

Commit 645aa39

Browse files
dsharletg and xnnpack-bot
authored and committed
Minor refactoring/cleanup of dot subgraph
- Choose kernel to determine packing layout in `ynn_define_dot` instead of packing helper
- Move transpose optimization logic to a helper function
- Remove unused `dot_type` parameters

The goal of this is to enable transposing A, to support using SME in subgraphs.

PiperOrigin-RevId: 822337555
1 parent 36f6110 commit 645aa39

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed

ynnpack/subgraph/dot.cc

Lines changed: 43 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,7 @@ auto make_dot_impl(dot_type type, size_t num_k_dims) {
296296

297297
// Make a kernel wrapper for packing the input of a dot kernel, i.e.
298298
// interleaving `tile_k` rows at a time.
299-
auto make_pack_impl(dot_type type) {
299+
auto make_pack_impl() {
300300
return [](slinky::raw_buffer input, slinky::raw_buffer output) -> index_t {
301301
const slinky::dim& input_n = input.dim(0);
302302
const slinky::dim& input_k = input.dim(1);
@@ -389,24 +389,13 @@ auto make_pack_impl(dot_type type) {
389389
};
390390
}
391391

392-
std::optional<size_t> get_extent(const ynn_value& x, int dim) {
393-
return dim < x.extents.size() ? as_constant(x.extents[dim]) : 1;
394-
}
395-
396-
void learn_shape_from_b(dot_shape& shape, size_t num_k_dims,
397-
const ynn_value& b) {
398-
shape.n = get_extent(b, 0);
399-
shape.k1 = get_extent(b, 1);
400-
shape.k2 = num_k_dims >= 2 ? get_extent(b, 2) : 1;
401-
shape.k3 = num_k_dims >= 3 ? get_extent(b, 3) : 1;
402-
}
403-
404392
// Packing means transposing
405393
// b(n, k, ...) => b(k%tile_k, n%nr, k/tile_k, n/tile_n, ...)
406394
// where tile_n is a multiple of the kernel's tile_n, but not greater than the
407395
// kernel's block_n.
408396
uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
409-
size_t num_k_dims, uint32_t input_b_id) {
397+
const dot_kernel& kernel, size_t num_k_dims,
398+
uint32_t input_b_id) {
410399
const ynn_value& b = subgraph->value(input_b_id);
411400

412401
ynn_value& packed_b = subgraph->new_internal_value();
@@ -420,10 +409,6 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
420409
slinky::expr k3 =
421410
3 < b.extents.size() && b.extents[3].defined() ? b.extents[3] : 1;
422411

423-
dot_shape shape;
424-
learn_shape_from_b(shape, num_k_dims, b);
425-
dot_kernel kernel = get_dot_kernel(type, shape);
426-
427412
const index_t cache_elements = cache_size_l2 / type_size_bytes(b.type);
428413

429414
// How many blocks of N fit in the cache?
@@ -448,7 +433,7 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
448433
node.inputs = {input_b_id};
449434
node.outputs = {packed_b_id};
450435
node.op = ynn_node::pack_b{};
451-
node.create = [type](const ynn_node& node, ynn_runtime& runtime) {
436+
node.create = [](const ynn_node& node, ynn_runtime& runtime) {
452437
const ynn_runtime_value& input = runtime.value(node.inputs[0]);
453438
ynn_runtime_value& output = runtime.value(node.outputs[0]);
454439

@@ -482,9 +467,8 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
482467

483468
slinky::call_stmt::attributes attrs;
484469
attrs.name = "pack_b";
485-
auto func =
486-
slinky::func::make(make_pack_impl(type), {std::move(func_input)},
487-
{{output.buffer, dims}}, std::move(attrs));
470+
auto func = slinky::func::make(make_pack_impl(), {std::move(func_input)},
471+
{{output.buffer, dims}}, std::move(attrs));
488472

489473
auto sched = std::make_unique<scheduling_info>();
490474

@@ -570,15 +554,20 @@ std::tuple<slinky::expr, slinky::expr> choose_split_factors(
570554
return {split_n, split_m};
571555
}
572556

573-
} // namespace
557+
std::optional<size_t> get_extent(const ynn_value& x, int dim) {
558+
return dim < x.extents.size() ? as_constant(x.extents[dim]) : 1;
559+
}
574560

575-
extern "C" {
561+
// Populates `shape` from the extents of `b`: n from dim 0, k1 from dim 1,
// and k2/k3 from dims 2/3 only when `num_k_dims` covers them (otherwise 1).
void learn_shape_from_b(dot_shape& shape, size_t num_k_dims,
                        const ynn_value& b) {
  shape.n = get_extent(b, 0);
  shape.k1 = get_extent(b, 1);
  shape.k2 = num_k_dims < 2 ? 1 : get_extent(b, 2);
  shape.k3 = num_k_dims < 3 ? 1 : get_extent(b, 3);
}
576568

577-
ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
578-
uint32_t input_a_id, uint32_t input_b_id,
579-
uint32_t input_c_id, uint32_t* output_id,
580-
uint32_t flags) {
581-
const ynn_node* b_producer = subgraph->get_producer(input_b_id);
569+
ynn_status always_alias_transpose(ynn_subgraph& subgraph, uint32_t& id) {
570+
const ynn_node* b_producer = subgraph.get_producer(id);
582571
if (b_producer && std::get_if<ynn_node::static_transpose>(&b_producer->op)) {
583572
// The producer of this pack is a transpose. If it is transposing the rows
584573
// and columns of B, we can handle it with packing.
@@ -592,16 +581,26 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
592581
// We don't rewrite the existing transpose op in the (unlikely) event that
593582
// it is used elsewhere. The existing transpose op will likely be
594583
// invalidated as a dead operation.
595-
input_b_id = YNN_INVALID_VALUE_ID;
596-
ynn_status status = define_static_transpose(
597-
subgraph, op.permutation, b_producer->inputs[0], &input_b_id,
598-
/*alias=*/true);
584+
id = YNN_INVALID_VALUE_ID;
585+
ynn_status status = define_static_transpose(&subgraph, op.permutation,
586+
b_producer->inputs[0], &id,
587+
/*alias=*/true);
599588
if (status != ynn_status_success) {
600589
return status;
601590
}
602591
}
603592
}
593+
return ynn_status_success;
594+
}
595+
596+
} // namespace
604597

598+
extern "C" {
599+
600+
ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
601+
uint32_t input_a_id, uint32_t input_b_id,
602+
uint32_t input_c_id, uint32_t* output_id,
603+
uint32_t flags) {
605604
// Validate arguments.
606605
assert(subgraph);
607606
assert(subgraph->is_valid_value(input_a_id));
@@ -610,9 +609,13 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
610609
assert(num_k_dims <= 3);
611610
assert(num_k_dims > 0);
612611

612+
ynn_status status = always_alias_transpose(*subgraph, input_b_id);
613+
if (status != ynn_status_success) {
614+
return status;
615+
}
616+
613617
const ynn_value& a = subgraph->value(input_a_id);
614618
const ynn_value& b = subgraph->value(input_b_id);
615-
616619
// If any input is a float, the output should be a float.
617620
const ynn_type c_type = !type_is_integral(a.type) || !type_is_integral(b.type)
618621
? ynn_type_fp32
@@ -627,7 +630,12 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
627630

628631
// Insert a packing node (if necessary).
629632
dot_type type = {a.type, b.type, c.type};
630-
uint32_t packed_b_id = define_pack_b(subgraph, type, num_k_dims, input_b_id);
633+
dot_shape shape;
634+
learn_shape_from_b(shape, num_k_dims, b);
635+
dot_kernel kernel = get_dot_kernel(type, shape, nullptr);
636+
637+
uint32_t packed_b_id =
638+
define_pack_b(subgraph, type, kernel, num_k_dims, input_b_id);
631639

632640
ynn_node node;
633641
// We need both the original input b (for shape inference only) and packed b.

0 commit comments

Comments
 (0)