From 709e7014843bb5178a74e69b6cafa8eab69bae69 Mon Sep 17 00:00:00 2001 From: Ethan Meitz <54505069+ejmeitz@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:31:39 -0400 Subject: [PATCH 1/2] pass through args in binop --- src/cupynumeric/ndarray.cc | 7 ++++++- src/cupynumeric/ndarray.h | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cupynumeric/ndarray.cc b/src/cupynumeric/ndarray.cc index cce215e3ac..f15ab79a94 100644 --- a/src/cupynumeric/ndarray.cc +++ b/src/cupynumeric/ndarray.cc @@ -459,7 +459,8 @@ void NDArray::trilu(NDArray rhs, int32_t k, bool lower) void NDArray::dot(NDArray rhs1, NDArray rhs2) { dot_MM(rhs1.get_store(), rhs2.get_store()); } -void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2) +void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2, + const std::vector& extra_args /*= {}*/) { if (rhs1.type() != rhs2.type()) { throw std::invalid_argument("Operands must have the same type"); @@ -482,6 +483,10 @@ void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2) auto p_rhs2 = task.add_input(rhs2_store); task.add_scalar_arg(legate::Scalar(op_code)); + for (auto&& arg : extra_args) { + task.add_scalar_arg(arg); + } + task.add_constraint(align(p_lhs, p_rhs1)); task.add_constraint(align(p_rhs1, p_rhs2)); diff --git a/src/cupynumeric/ndarray.h b/src/cupynumeric/ndarray.h index 0433816f57..d68d780a87 100644 --- a/src/cupynumeric/ndarray.h +++ b/src/cupynumeric/ndarray.h @@ -71,7 +71,7 @@ class NDArray { public: void random(int32_t gen_code); void fill(const Scalar& value); - void binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2); + void binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2, const std::vector& extra_args = {}); void binary_reduction(int32_t op_code, NDArray rhs1, NDArray rhs2); void unary_op(int32_t op_code, NDArray input, const std::vector& extra_args = {}); void unary_reduction(int32_t op_code, NDArray input); From d990aa56df60c861379f6d1adf5905499342e215 Mon Sep 17 00:00:00 2001 From: Ethan Meitz <54505069+ejmeitz@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:34:08 -0400 Subject: [PATCH 2/2] also do it in operators.cc --- src/cupynumeric/operators.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cupynumeric/operators.cc b/src/cupynumeric/operators.cc index a68a3c663e..c55c1f4439 100644 --- a/src/cupynumeric/operators.cc +++ b/src/cupynumeric/operators.cc @@ -52,14 +52,15 @@ NDArray unary_reduction(UnaryRedCode op_code, NDArray input) return out; } -NDArray binary_op(BinaryOpCode op_code, NDArray rhs1, NDArray rhs2, std::optional out) +NDArray binary_op(BinaryOpCode op_code, NDArray rhs1, NDArray rhs2, std::optional out, + const std::vector& extra_args = {}) { auto runtime = CuPyNumericRuntime::get_runtime(); if (!out.has_value()) { auto out_shape = broadcast_shapes({rhs1, rhs2}); out = runtime->create_array(out_shape, rhs1.type()); } - out->binary_op(static_cast(op_code), std::move(rhs1), std::move(rhs2)); + out->binary_op(static_cast(op_code), std::move(rhs1), std::move(rhs2), extra_args); return out.value(); }