Skip to content

Commit 9c80cf8

Browse files
committed
fix
1 parent dab67dc commit 9c80cf8

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,27 @@ REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
985985
return (void *)unsafe;
986986
}
987987

988+
REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
989+
uint64_t ptype, size_t dim,
990+
const int64_t *cshape,
991+
PjRtDevice *device) {
992+
auto primtype = (xla::PrimitiveType)ptype;
993+
absl::Span<const int64_t> shape(cshape, dim);
994+
PjRtClient::HostBufferSemantics semantics =
995+
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
996+
// xla::Layout layout(col_major(dim));
997+
// auto buffer = xla::MyValueOrThrow(client->BufferFromHostBuffer(data,
998+
// primtype, shape, /*byte_strides*/{}, semantics, /*ondone*/{}, device,
999+
// &layout));
1000+
const xla::Layout *layout = nullptr;
1001+
auto buffer = MyValueOrThrow(client->BufferFromHostBuffer(
1002+
data, primtype, shape, /*byte_strides*/ {}, semantics, /*ondone*/ {},
1003+
*device->default_memory_space(), layout));
1004+
auto bres = buffer.release();
1005+
return bres;
1006+
}
1007+
1008+
9881009
REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
9891010
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
9901011
if (buffer->IsOnCpu()) {
@@ -998,7 +1019,7 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
9981019

9991020
auto pid = client->platform_id();
10001021
if (pid == xla::TpuId()) {
1001-
auto dims = argB->on_device_shape().dimensions();
1022+
auto dims = buffer->on_device_shape().dimensions();
10021023
auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
10031024
*bufferP = buf2;
10041025
PjRtBufferFree((PjRtBuffer *)buffer);
@@ -1089,26 +1110,6 @@ REACTANT_ABI PjRtBuffer *UninitPJRTBuffer(PjRtClient *client,
10891110
return xbuffer.release();
10901111
}
10911112

1092-
REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
1093-
uint64_t ptype, size_t dim,
1094-
int64_t *cshape,
1095-
PjRtDevice *device) {
1096-
auto primtype = (xla::PrimitiveType)ptype;
1097-
absl::Span<const int64_t> shape(cshape, dim);
1098-
PjRtClient::HostBufferSemantics semantics =
1099-
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
1100-
// xla::Layout layout(col_major(dim));
1101-
// auto buffer = xla::MyValueOrThrow(client->BufferFromHostBuffer(data,
1102-
// primtype, shape, /*byte_strides*/{}, semantics, /*ondone*/{}, device,
1103-
// &layout));
1104-
const xla::Layout *layout = nullptr;
1105-
auto buffer = MyValueOrThrow(client->BufferFromHostBuffer(
1106-
data, primtype, shape, /*byte_strides*/ {}, semantics, /*ondone*/ {},
1107-
*device->default_memory_space(), layout));
1108-
auto bres = buffer.release();
1109-
return bres;
1110-
}
1111-
11121113
REACTANT_ABI uint8_t BufferOnCPU(PjRtBuffer *buffer) {
11131114
return buffer->IsOnCpu();
11141115
}

0 commit comments

Comments
 (0)