1616#include "src/xnnpack/config-types.h"
1717#include "src/xnnpack/config.h"
1818#include "src/xnnpack/datatype.h"
19- #include "src/xnnpack/hardware-config.h"
2019#include "src/xnnpack/log.h"
2120#include "src/xnnpack/math.h"
2221#include "src/xnnpack/microparams.h"
@@ -265,6 +264,15 @@ enum xnn_status xnn_create_binary_elementwise_nd(
265264 return xnn_status_success ;
266265}
267266
267+ static size_t get_tile_size (xnn_operator_t op ) {
268+ // Assume a default width (unrolling factor) of 32.
269+ const size_t element_tile = op -> binary_elementwise_config -> element_tile
270+ ? op -> binary_elementwise_config -> element_tile
271+ : 32 ;
272+ return round_up (16 * 1024 ,
273+ element_tile << op -> binary_elementwise .log2_element_size );
274+ }
275+
268276enum xnn_status xnn_reshape_binary_elementwise_nd (xnn_operator_t op ,
269277 size_t num_input1_dims ,
270278 const size_t * input1_shape ,
@@ -415,7 +423,6 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
415423 y_stride *= compressed_output_shape [i ];
416424 }
417425
418- const size_t element_tile = op -> binary_elementwise_config -> element_tile ;
419426 if (compressed_output_shape [5 ] == 1 ) {
420427 if (compressed_output_shape [4 ] == 1 ) {
421428 if (compressed_output_shape [3 ] == 1 ) {
@@ -434,25 +441,15 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
434441 xnn_compute_elementwise_binary_1d_tile ;
435442 op -> compute [0 ].range [0 ] = compressed_output_shape [0 ]
436443 << log2_element_size ;
437- size_t bytes_per_tile =
438- xnn_init_hardware_config ()-> l1_data_cache_bytes ;
439- if (!bytes_per_tile ) {
440- bytes_per_tile = 1 << 15 ; // Default to 32k if cachesize unknown.
441- }
442- const size_t num_tiles =
443- divide_round_up (compressed_output_shape [0 ] << log2_element_size ,
444- bytes_per_tile );
445- const size_t tile_size = round_up (
446- (compressed_output_shape [0 ] << log2_element_size ) / num_tiles ,
447- element_tile << log2_element_size );
448- op -> compute [0 ].tile [0 ] = tile_size ;
444+ op -> compute [0 ].tile [0 ] = get_tile_size (op );
449445 } else {
450446 op -> compute [0 ].type = xnn_parallelization_type_1d_tile_1d_dynamic ;
451447 op -> compute [0 ].task_1d_tile_1d_dynamic =
452448 (pthreadpool_task_1d_tile_1d_dynamic_t )
453449 xnn_compute_elementwise_binary_1d ;
454450 op -> compute [0 ].range [0 ] = compressed_output_shape [1 ];
455- op -> compute [0 ].tile [0 ] = 1 ;
451+ op -> compute [0 ].tile [0 ] = divide_round_up (
452+ get_tile_size (op ), op -> context .elementwise_binary .elements );
456453 }
457454 } else {
458455 op -> compute [0 ].type = xnn_parallelization_type_2d_tile_1d_dynamic ;
@@ -461,7 +458,8 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
461458 xnn_compute_elementwise_binary_2d ;
462459 op -> compute [0 ].range [0 ] = compressed_output_shape [2 ];
463460 op -> compute [0 ].range [1 ] = compressed_output_shape [1 ];
464- op -> compute [0 ].tile [0 ] = 1 ;
461+ op -> compute [0 ].tile [0 ] = divide_round_up (
462+ get_tile_size (op ), op -> context .elementwise_binary .elements );
465463 }
466464 } else {
467465 op -> compute [0 ].type = xnn_parallelization_type_3d_tile_2d_dynamic ;
0 commit comments