Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/cupynumeric/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ void NDArray::trilu(NDArray rhs, int32_t k, bool lower)

void NDArray::dot(NDArray rhs1, NDArray rhs2) { dot_MM(rhs1.get_store(), rhs2.get_store()); }

void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2)
void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2,
const std::vector<legate::Scalar>& extra_args /*= {}*/)
{
if (rhs1.type() != rhs2.type()) {
throw std::invalid_argument("Operands must have the same type");
Expand All @@ -482,6 +483,10 @@ void NDArray::binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2)
auto p_rhs2 = task.add_input(rhs2_store);
task.add_scalar_arg(legate::Scalar(op_code));

for (auto&& arg : extra_args) {
task.add_scalar_arg(arg);
}

task.add_constraint(align(p_lhs, p_rhs1));
task.add_constraint(align(p_rhs1, p_rhs2));

Expand Down
2 changes: 1 addition & 1 deletion src/cupynumeric/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class NDArray {
public:
void random(int32_t gen_code);
void fill(const Scalar& value);
void binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2);
void binary_op(int32_t op_code, NDArray rhs1, NDArray rhs2, const std::vector<legate::Scalar>& extra_args = {});
void binary_reduction(int32_t op_code, NDArray rhs1, NDArray rhs2);
void unary_op(int32_t op_code, NDArray input, const std::vector<legate::Scalar>& extra_args = {});
void unary_reduction(int32_t op_code, NDArray input);
Expand Down
5 changes: 3 additions & 2 deletions src/cupynumeric/operators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ NDArray unary_reduction(UnaryRedCode op_code, NDArray input)
return out;
}

NDArray binary_op(BinaryOpCode op_code, NDArray rhs1, NDArray rhs2, std::optional<NDArray> out)
NDArray binary_op(BinaryOpCode op_code, NDArray rhs1, NDArray rhs2, std::optional<NDArray> out,
const std::vector<legate::Scalar>& extra_args = {})
{
auto runtime = CuPyNumericRuntime::get_runtime();
if (!out.has_value()) {
auto out_shape = broadcast_shapes({rhs1, rhs2});
out = runtime->create_array(out_shape, rhs1.type());
}
out->binary_op(static_cast<int32_t>(op_code), std::move(rhs1), std::move(rhs2));
out->binary_op(static_cast<int32_t>(op_code), std::move(rhs1), std::move(rhs2), extra_args);
return out.value();
}

Expand Down