Skip to content

Commit bf9d731

Browse files
committed
Merge remote-tracking branch 'origin/master' into f16_igemm
2 parents 923b7f9 + 3dfc2f5 commit bf9d731

File tree

6 files changed

+1626
-1091
lines changed

6 files changed

+1626
-1091
lines changed

src/configs/unary-elementwise-config.c

Lines changed: 425 additions & 0 deletions
Large diffs are not rendered by default.

src/operators/binary-elementwise-nd.c

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
#include "src/xnnpack/config-types.h"
1717
#include "src/xnnpack/config.h"
1818
#include "src/xnnpack/datatype.h"
19-
#include "src/xnnpack/hardware-config.h"
2019
#include "src/xnnpack/log.h"
2120
#include "src/xnnpack/math.h"
2221
#include "src/xnnpack/microparams.h"
@@ -265,6 +264,15 @@ enum xnn_status xnn_create_binary_elementwise_nd(
265264
return xnn_status_success;
266265
}
267266

267+
// Returns the tile size derived from a fixed 16 * 1024 target and the
// operator's element tile.
//
// The element tile is read from `op->binary_elementwise_config`, shifted left
// by `op->binary_elementwise.log2_element_size`, and handed to `round_up`
// together with the 16 * 1024 target.
// NOTE(review): presumably `round_up` rounds the target up to a multiple of
// that shifted width, so the result covers whole element tiles -- confirm
// against the definition of `round_up`.
static size_t get_tile_size(xnn_operator_t op) {
  // A config reporting an element tile of 0 falls back to a default width
  // (unrolling factor) of 32.
  size_t elements_per_tile = op->binary_elementwise_config->element_tile;
  if (elements_per_tile == 0) {
    elements_per_tile = 32;
  }

  const size_t tile_width =
      elements_per_tile << op->binary_elementwise.log2_element_size;
  return round_up(16 * 1024, tile_width);
}
275+
268276
enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
269277
size_t num_input1_dims,
270278
const size_t* input1_shape,
@@ -415,7 +423,6 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
415423
y_stride *= compressed_output_shape[i];
416424
}
417425

418-
const size_t element_tile = op->binary_elementwise_config->element_tile;
419426
if (compressed_output_shape[5] == 1) {
420427
if (compressed_output_shape[4] == 1) {
421428
if (compressed_output_shape[3] == 1) {
@@ -434,25 +441,15 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
434441
xnn_compute_elementwise_binary_1d_tile;
435442
op->compute[0].range[0] = compressed_output_shape[0]
436443
<< log2_element_size;
437-
size_t bytes_per_tile =
438-
xnn_init_hardware_config()->l1_data_cache_bytes;
439-
if (!bytes_per_tile) {
440-
bytes_per_tile = 1 << 15; // Default to 32k if cachesize unknown.
441-
}
442-
const size_t num_tiles =
443-
divide_round_up(compressed_output_shape[0] << log2_element_size,
444-
bytes_per_tile);
445-
const size_t tile_size = round_up(
446-
(compressed_output_shape[0] << log2_element_size) / num_tiles,
447-
element_tile << log2_element_size);
448-
op->compute[0].tile[0] = tile_size;
444+
op->compute[0].tile[0] = get_tile_size(op);
449445
} else {
450446
op->compute[0].type = xnn_parallelization_type_1d_tile_1d_dynamic;
451447
op->compute[0].task_1d_tile_1d_dynamic =
452448
(pthreadpool_task_1d_tile_1d_dynamic_t)
453449
xnn_compute_elementwise_binary_1d;
454450
op->compute[0].range[0] = compressed_output_shape[1];
455-
op->compute[0].tile[0] = 1;
451+
op->compute[0].tile[0] = divide_round_up(
452+
get_tile_size(op), op->context.elementwise_binary.elements);
456453
}
457454
} else {
458455
op->compute[0].type = xnn_parallelization_type_2d_tile_1d_dynamic;
@@ -461,7 +458,8 @@ enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
461458
xnn_compute_elementwise_binary_2d;
462459
op->compute[0].range[0] = compressed_output_shape[2];
463460
op->compute[0].range[1] = compressed_output_shape[1];
464-
op->compute[0].tile[0] = 1;
461+
op->compute[0].tile[0] = divide_round_up(
462+
get_tile_size(op), op->context.elementwise_binary.elements);
465463
}
466464
} else {
467465
op->compute[0].type = xnn_parallelization_type_3d_tile_2d_dynamic;

0 commit comments

Comments
 (0)