@@ -296,7 +296,7 @@ auto make_dot_impl(dot_type type, size_t num_k_dims) {
296296
297297// Make a kernel wrapper for packing the input of a dot kernel, i.e.
298298// interleaving `tile_k` rows at a time.
299- auto make_pack_impl (dot_type type ) {
299+ auto make_pack_impl () {
300300 return [](slinky::raw_buffer input, slinky::raw_buffer output) -> index_t {
301301 const slinky::dim& input_n = input.dim (0 );
302302 const slinky::dim& input_k = input.dim (1 );
@@ -389,24 +389,13 @@ auto make_pack_impl(dot_type type) {
389389 };
390390}
391391
392- std::optional<size_t > get_extent (const ynn_value& x, int dim) {
393- return dim < x.extents .size () ? as_constant (x.extents [dim]) : 1 ;
394- }
395-
396- void learn_shape_from_b (dot_shape& shape, size_t num_k_dims,
397- const ynn_value& b) {
398- shape.n = get_extent (b, 0 );
399- shape.k1 = get_extent (b, 1 );
400- shape.k2 = num_k_dims >= 2 ? get_extent (b, 2 ) : 1 ;
401- shape.k3 = num_k_dims >= 3 ? get_extent (b, 3 ) : 1 ;
402- }
403-
404392// Packing means transposing
405393// b(n, k, ...) => b(k%tile_k, n%nr, k/tile_k, n/tile_n, ...)
406394// where tile_n is a multiple of the kernel's tile_n, but not greater than the
407395// kernel's block_n.
408396uint32_t define_pack_b (ynn_subgraph_t subgraph, const dot_type& type,
409- size_t num_k_dims, uint32_t input_b_id) {
397+ const dot_kernel& kernel, size_t num_k_dims,
398+ uint32_t input_b_id) {
410399 const ynn_value& b = subgraph->value (input_b_id);
411400
412401 ynn_value& packed_b = subgraph->new_internal_value ();
@@ -420,10 +409,6 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
420409 slinky::expr k3 =
421410 3 < b.extents .size () && b.extents [3 ].defined () ? b.extents [3 ] : 1 ;
422411
423- dot_shape shape;
424- learn_shape_from_b (shape, num_k_dims, b);
425- dot_kernel kernel = get_dot_kernel (type, shape);
426-
427412 const index_t cache_elements = cache_size_l2 / type_size_bytes (b.type );
428413
429414 // How many blocks of N fit in the cache?
@@ -448,7 +433,7 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
448433 node.inputs = {input_b_id};
449434 node.outputs = {packed_b_id};
450435 node.op = ynn_node::pack_b{};
451- node.create = [type ](const ynn_node& node, ynn_runtime& runtime) {
436+ node.create = [](const ynn_node& node, ynn_runtime& runtime) {
452437 const ynn_runtime_value& input = runtime.value (node.inputs [0 ]);
453438 ynn_runtime_value& output = runtime.value (node.outputs [0 ]);
454439
@@ -482,9 +467,8 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type,
482467
483468 slinky::call_stmt::attributes attrs;
484469 attrs.name = " pack_b" ;
485- auto func =
486- slinky::func::make (make_pack_impl (type), {std::move (func_input)},
487- {{output.buffer , dims}}, std::move (attrs));
470+ auto func = slinky::func::make (make_pack_impl (), {std::move (func_input)},
471+ {{output.buffer , dims}}, std::move (attrs));
488472
489473 auto sched = std::make_unique<scheduling_info>();
490474
@@ -570,15 +554,20 @@ std::tuple<slinky::expr, slinky::expr> choose_split_factors(
570554 return {split_n, split_m};
571555}
572556
573- } // namespace
557+ std::optional<size_t > get_extent (const ynn_value& x, int dim) {
558+ return dim < x.extents .size () ? as_constant (x.extents [dim]) : 1 ;
559+ }
574560
575- extern " C" {
561+ void learn_shape_from_b (dot_shape& shape, size_t num_k_dims,
562+ const ynn_value& b) {
563+ shape.n = get_extent (b, 0 );
564+ shape.k1 = get_extent (b, 1 );
565+ shape.k2 = num_k_dims >= 2 ? get_extent (b, 2 ) : 1 ;
566+ shape.k3 = num_k_dims >= 3 ? get_extent (b, 3 ) : 1 ;
567+ }
576568
577- ynn_status ynn_define_dot (ynn_subgraph_t subgraph, size_t num_k_dims,
578- uint32_t input_a_id, uint32_t input_b_id,
579- uint32_t input_c_id, uint32_t * output_id,
580- uint32_t flags) {
581- const ynn_node* b_producer = subgraph->get_producer (input_b_id);
569+ ynn_status always_alias_transpose (ynn_subgraph& subgraph, uint32_t & id) {
570+ const ynn_node* b_producer = subgraph.get_producer (id);
582571 if (b_producer && std::get_if<ynn_node::static_transpose>(&b_producer->op )) {
583572 // The producer of this pack is a transpose. If it is transposing the rows
584573 // and columns of B, we can handle it with packing.
@@ -592,16 +581,26 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
592581 // We don't rewrite the existing transpose op in the (unlikely) event that
593582 // it is used elsewhere. The existing transpose op will likely be
594583 // invalidated as a dead operation.
595- input_b_id = YNN_INVALID_VALUE_ID;
596- ynn_status status = define_static_transpose (
597- subgraph, op. permutation , b_producer->inputs [0 ], &input_b_id ,
598- /* alias=*/ true );
584+ id = YNN_INVALID_VALUE_ID;
585+ ynn_status status = define_static_transpose (&subgraph, op. permutation ,
586+ b_producer->inputs [0 ], &id ,
587+ /* alias=*/ true );
599588 if (status != ynn_status_success) {
600589 return status;
601590 }
602591 }
603592 }
593+ return ynn_status_success;
594+ }
595+
596+ } // namespace
604597
598+ extern " C" {
599+
600+ ynn_status ynn_define_dot (ynn_subgraph_t subgraph, size_t num_k_dims,
601+ uint32_t input_a_id, uint32_t input_b_id,
602+ uint32_t input_c_id, uint32_t * output_id,
603+ uint32_t flags) {
605604 // Validate arguments.
606605 assert (subgraph);
607606 assert (subgraph->is_valid_value (input_a_id));
@@ -610,9 +609,13 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
610609 assert (num_k_dims <= 3 );
611610 assert (num_k_dims > 0 );
612611
612+ ynn_status status = always_alias_transpose (*subgraph, input_b_id);
613+ if (status != ynn_status_success) {
614+ return status;
615+ }
616+
613617 const ynn_value& a = subgraph->value (input_a_id);
614618 const ynn_value& b = subgraph->value (input_b_id);
615-
616619 // If any input is a float, the output should be a float.
617620 const ynn_type c_type = !type_is_integral (a.type ) || !type_is_integral (b.type )
618621 ? ynn_type_fp32
@@ -627,7 +630,12 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
627630
628631 // Insert a packing node (if necessary).
629632 dot_type type = {a.type , b.type , c.type };
630- uint32_t packed_b_id = define_pack_b (subgraph, type, num_k_dims, input_b_id);
633+ dot_shape shape;
634+ learn_shape_from_b (shape, num_k_dims, b);
635+ dot_kernel kernel = get_dot_kernel (type, shape, nullptr );
636+
637+ uint32_t packed_b_id =
638+ define_pack_b (subgraph, type, kernel, num_k_dims, input_b_id);
631639
632640 ynn_node node;
633641 // We need both the original input b (for shape inference only) and packed b.
0 commit comments