Skip to content

Commit dd46504

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] Implement select_at_dim_as_symint"
## Context

The SDPA custom op accepts the `input_pos` (i.e. cache position) argument as a symbolic integer. The value of the symbolic integer is obtained by selecting the first element of a cache position input tensor and converting it to a symint via `local_scalar_dense`.

Currently, ET-VK handles this in a hacky manner:

1. The select + `local_scalar_dense` op pattern is removed, and the cache pos tensor is passed directly into the custom SDPA ops.
2. Single-element tensors whose users are all select + `local_scalar_dense` are interpreted as symints instead of tensors.

Unfortunately, this technique will not work for the Hugging Face implementation of transformer models, since the cache pos input tensor does not have just a single element but is expected to be a vector of integer cache positions corresponding to all cache positions that will be updated.

## Changes

Introduce a custom op to capture the select + `local_scalar_dense` op pattern, which is the proper way to handle this pattern. Note that a custom op is needed because this op needs to access the staging buffer data of the input tensor, whereas `select` would typically be executed via a compute shader. This is because the `input_pos` value is needed to configure the sizes of the attention weight tensors participating in the custom SDPA op, so the value must be set before any command buffers are dispatched.

As a consequence of this change, the previous handling of select + `local_scalar_dense` can also be removed.

Differential Revision: [D86340340](https://our.internmc.facebook.com/intern/diff/D86340340/)

[ghstack-poisoned]
2 parents d871e1b + 2b02316 commit dd46504

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1938
-520
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
    """Look up the quant/dequant op pair that belongs to ``op``.

    Returns a ``(quant_op, dequant_op)`` tuple when ``op`` is one of
    ``DQ_OPS``; the quant op is the ``Q_OPS`` entry at the same position, i.e.
    the quantize op of the same kind (per-channel / per-tensor). Returns
    ``None`` when ``op`` is not a dequant op.
    """
    if op not in DQ_OPS:
        return None
    # The input of a dequant node can be a placeholder, so the quant op is
    # found by position in the op tables rather than read from the graph.
    matching_q_op = Q_OPS[DQ_OPS.index(op)]
    return matching_q_op, op
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known, we can compute quant params such that dq(q_max_value) == 1/N.
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 172 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
326326
ET_LOG(Error, "Ethos-U invocation failed error (%d)", result);
327327
return Error::InvalidProgram;
328328
}
329-
int tensor_dim = 0, io_dim = 0;
329+
size_t tensor_bytes_total = 0;
330+
size_t io_bytes_total = 0;
330331
// Write outputs from scratch into EValue pointers
331332
for (int i = 0; i < handles.outputs->count; i++) {
332333
int tensor_count = 1, io_count = 1;
@@ -338,23 +339,39 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
338339
calculate_dimensions(
339340
tensor_out, &handles.outputs->io[i], &tensor_count, &io_count);
340341

341-
// At times the topological order of the outputs may change.
342-
// Lets instead ensure that the sum of dimensions match.
343-
tensor_dim = tensor_dim + tensor_count;
344-
io_dim = io_dim + io_count;
342+
size_t tensor_bytes = tensor_out.nbytes();
343+
size_t io_bytes = static_cast<size_t>(io_count) *
344+
static_cast<size_t>(handles.outputs->io[i].elem_size);
345+
346+
if (tensor_bytes != io_bytes) {
347+
Error status = copy_with_layout_adjustment(
348+
handles.outputs->io[i], i, output_addr, tensor_out, tensor_bytes);
349+
if (status != Error::Ok) {
350+
return status;
351+
}
352+
io_bytes_total += tensor_bytes;
353+
} else {
354+
EXECUTORCH_PROF_SCOPE(
355+
event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
345356

346-
EXECUTORCH_PROF_SCOPE(
347-
event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
357+
memcpy(
358+
tensor_out.mutable_data_ptr<char>(),
359+
static_cast<const char*>(output_addr),
360+
tensor_bytes);
361+
io_bytes_total += io_bytes;
362+
}
348363

349-
memcpy(
350-
tensor_out.mutable_data_ptr<char>(),
351-
static_cast<const char*>(output_addr),
352-
tensor_out.nbytes());
364+
// At times the topological order of the outputs may change.
365+
// Let's instead ensure that the sums of output bytes match.
366+
tensor_bytes_total += tensor_bytes;
353367
}
354-
if (tensor_dim != io_dim) {
368+
if (tensor_bytes_total != io_bytes_total) {
355369
ET_LOG(Error, "Total output tensor sizes do not match");
356370
ET_LOG(
357-
Error, "Program expects size of %d but got %d", tensor_dim, io_dim);
371+
Error,
372+
"Program expects %zu bytes but got %zu",
373+
io_bytes_total,
374+
tensor_bytes_total);
358375
return Error::InvalidProgram;
359376
}
360377
return Error::Ok;
@@ -365,6 +382,147 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
365382
}
366383

367384
private:
385+
// Copies Vela output into the ExecuTorch tensor, adjusting for padding or
386+
// packed layouts produced by the delegate.
387+
Error copy_with_layout_adjustment(
388+
const VelaIO& output_io,
389+
int output_index,
390+
const char* src,
391+
executorch::aten::Tensor& tensor_out,
392+
size_t tensor_bytes) const {
393+
const int elem_size = output_io.elem_size;
394+
if (elem_size == 0) {
395+
ET_LOG(
396+
Error, "Ethos-U output %d reports zero element size", output_index);
397+
return Error::InvalidProgram;
398+
}
399+
400+
size_t chunk_count = 1;
401+
for (int dim = 0; dim < shapeDim - 1; ++dim) {
402+
const int vela_dim = output_io.shape[dim];
403+
chunk_count *= static_cast<size_t>(vela_dim == 0 ? 1 : vela_dim);
404+
}
405+
const int last_dim = output_io.shape[shapeDim - 1];
406+
const size_t vela_chunk_elems =
407+
static_cast<size_t>(last_dim == 0 ? 1 : last_dim);
408+
const size_t vela_chunk_size =
409+
vela_chunk_elems * static_cast<size_t>(elem_size);
410+
411+
if (tensor_bytes % chunk_count != 0) {
412+
ET_LOG(
413+
Error,
414+
"Ethos-U output %d tensor bytes %zu not divisible by chunk count %zu",
415+
output_index,
416+
tensor_bytes,
417+
chunk_count);
418+
return Error::InvalidProgram;
419+
}
420+
421+
const size_t chunk_size = tensor_bytes / chunk_count;
422+
423+
// If Vela writes fewer bytes than the tensor expects we may need to
424+
// expand 4-bit data to 8-bit. Ethos-U outputs may be
425+
// packed 4-bit values but ExecuTorch tensors are at least 8-bit.
426+
if (vela_chunk_size < chunk_size) {
427+
if (chunk_size % vela_chunk_size != 0) {
428+
ET_LOG(
429+
Error,
430+
"Ethos-U output %d chunk bytes %zu not divisible by vela chunk bytes %zu",
431+
output_index,
432+
chunk_size,
433+
vela_chunk_size);
434+
return Error::InvalidProgram;
435+
}
436+
437+
const size_t expand_factor = chunk_size / vela_chunk_size;
438+
if (expand_factor == 2 && elem_size == 1 &&
439+
tensor_out.scalar_type() == ScalarType::Char) {
440+
return unpack_chunks_4bit_to_int8(
441+
reinterpret_cast<const uint8_t*>(src),
442+
tensor_out.mutable_data_ptr<int8_t>(),
443+
chunk_count,
444+
chunk_size,
445+
vela_chunk_size);
446+
}
447+
448+
ET_LOG(
449+
Error,
450+
"Ethos-U output %d expansion factor %zu with element size %d not supported",
451+
output_index,
452+
expand_factor,
453+
elem_size);
454+
return Error::InvalidProgram;
455+
}
456+
457+
return strip_delegate_padding(
458+
src,
459+
tensor_out.mutable_data_ptr<char>(),
460+
chunk_count,
461+
chunk_size,
462+
vela_chunk_size);
463+
}
464+
465+
// Unpacks `chunk_count` chunks of packed 4-bit values into int8 values.
//
// Chunk `i` is read from `src + i * src_chunk_size` (`src_chunk_size` bytes)
// and written to `dest + i * dest_chunk_size`. Every source byte produces two
// destination bytes, so `dest_chunk_size` must be at least
// `2 * src_chunk_size`.
//
// Returns Error::InvalidState if either buffer is null, Error::Ok otherwise.
Error unpack_chunks_4bit_to_int8(
    const uint8_t* src,
    int8_t* dest,
    size_t chunk_count,
    size_t dest_chunk_size,
    size_t src_chunk_size) const {
  // Fail with an error status instead of dereferencing a null buffer.
  if (src == nullptr || dest == nullptr) {
    ET_LOG(Error, "Ethos-U 4-bit unpack received null buffer");
    return Error::InvalidState;
  }
  for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) {
    unpack_single_chunk_4bit_to_int8(
        src + chunk_idx * src_chunk_size,
        dest + chunk_idx * dest_chunk_size,
        src_chunk_size);
  }
  return Error::Ok;
}
480+
481+
void unpack_single_chunk_4bit_to_int8(
482+
const uint8_t* src,
483+
int8_t* dest,
484+
size_t chunk_size) const {
485+
for (size_t byte_idx = 0; byte_idx < chunk_size; ++byte_idx) {
486+
const uint8_t packed = src[byte_idx];
487+
int8_t low = static_cast<int8_t>(packed & 0x0F);
488+
int8_t high = static_cast<int8_t>((packed >> 4) & 0x0F);
489+
if (low >= 8) {
490+
low -= 16;
491+
}
492+
if (high >= 8) {
493+
high -= 16;
494+
}
495+
dest[2 * byte_idx] = low;
496+
dest[2 * byte_idx + 1] = high;
497+
}
498+
}
499+
500+
// Copies `chunk_count` chunks from `src` to `dest`, dropping trailing bytes.
//
// Source chunks are `src_chunk_size` bytes apart; only the first
// `dest_chunk_size` bytes of each are copied, and they are written
// back-to-back into `dest`.
//
// Returns Error::InvalidProgram if `dest_chunk_size` exceeds
// `src_chunk_size`, Error::InvalidState if either buffer is null, and
// Error::Ok once all chunks are copied.
Error strip_delegate_padding(
    const char* src,
    char* dest,
    size_t chunk_count,
    size_t dest_chunk_size,
    size_t src_chunk_size) const {
  if (dest_chunk_size > src_chunk_size) {
    ET_LOG(
        Error,
        "dest chunk size %zu must not exceed src chunk size %zu",
        dest_chunk_size,
        src_chunk_size);
    return Error::InvalidProgram;
  }
  if (!src || !dest) {
    ET_LOG(Error, "Ethos-U padded copy received null buffer");
    return Error::InvalidState;
  }
  for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) {
    memcpy(
        dest + chunk_idx * dest_chunk_size,
        src + chunk_idx * src_chunk_size,
        dest_chunk_size);
  }
  return Error::Ok;
}
525+
368526
void calculate_dimensions(
369527
const executorch::aten::Tensor tensor,
370528
VelaIO* io,
@@ -389,4 +547,4 @@ static auto registered = register_backend(backend_id);
389547

390548
} // namespace arm
391549
} // namespace backends
392-
} // namespace executorch
550+
} // namespace executorch

backends/arm/test/models/test_lstm_arm.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def test_lstm_tosa_FP():
5151
exir_op=[],
5252
use_to_edge_transform_and_lower=True,
5353
)
54-
pipeline.change_args("run_method_and_compare_outputs", get_test_inputs(), atol=3e-1)
54+
pipeline.change_args(
55+
"run_method_and_compare_outputs", inputs=get_test_inputs(), atol=3e-1
56+
)
5557
pipeline.run()
5658

5759

@@ -64,7 +66,10 @@ def test_lstm_tosa_INT():
6466
use_to_edge_transform_and_lower=True,
6567
)
6668
pipeline.change_args(
67-
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
69+
"run_method_and_compare_outputs",
70+
inputs=get_test_inputs(),
71+
atol=3e-1,
72+
qtol=1.0,
6873
)
6974
pipeline.run()
7075

@@ -79,7 +84,10 @@ def test_lstm_u55_INT():
7984
use_to_edge_transform_and_lower=True,
8085
)
8186
pipeline.change_args(
82-
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
87+
"run_method_and_compare_outputs",
88+
inputs=get_test_inputs(),
89+
atol=3e-1,
90+
qtol=1.0,
8391
)
8492
pipeline.run()
8593

@@ -94,7 +102,10 @@ def test_lstm_u85_INT():
94102
use_to_edge_transform_and_lower=True,
95103
)
96104
pipeline.change_args(
97-
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
105+
"run_method_and_compare_outputs",
106+
inputs=get_test_inputs(),
107+
atol=3e-1,
108+
qtol=1.0,
98109
)
99110
pipeline.run()
100111

0 commit comments

Comments
 (0)