@@ -985,6 +985,27 @@ REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
985985 return (void *)unsafe;
986986}
987987
988+ REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer (PjRtClient *client, void *data,
989+ uint64_t ptype, size_t dim,
990+ const int64_t *cshape,
991+ PjRtDevice *device) {
992+ auto primtype = (xla::PrimitiveType)ptype;
993+ absl::Span<const int64_t > shape (cshape, dim);
994+ PjRtClient::HostBufferSemantics semantics =
995+ PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall ;
996+ // xla::Layout layout(col_major(dim));
997+ // auto buffer = xla::MyValueOrThrow(client->BufferFromHostBuffer(data,
998+ // primtype, shape, /*byte_strides*/{}, semantics, /*ondone*/{}, device,
999+ // &layout));
1000+ const xla::Layout *layout = nullptr ;
1001+ auto buffer = MyValueOrThrow (client->BufferFromHostBuffer (
1002+ data, primtype, shape, /* byte_strides*/ {}, semantics, /* ondone*/ {},
1003+ *device->default_memory_space (), layout));
1004+ auto bres = buffer.release ();
1005+ return bres;
1006+ }
1007+
1008+
9881009REACTANT_ABI void CopyToBuffer (PjRtClient *client, PjRtBuffer *buffer,
9891010 void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
9901011 if (buffer->IsOnCpu ()) {
@@ -998,7 +1019,7 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
9981019
9991020 auto pid = client->platform_id ();
10001021 if (pid == xla::TpuId ()) {
1001- auto dims = argB ->on_device_shape ().dimensions ();
1022+ auto dims = buffer ->on_device_shape ().dimensions ();
10021023 auto buf2 = ArrayFromHostBuffer (client, data, buffer->element_type (), dims.size (), dims.data (), buffer->device ());
10031024 *bufferP = buf2;
10041025 PjRtBufferFree ((PjRtBuffer *)buffer);
@@ -1089,26 +1110,6 @@ REACTANT_ABI PjRtBuffer *UninitPJRTBuffer(PjRtClient *client,
10891110 return xbuffer.release ();
10901111}
10911112
1092- REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer (PjRtClient *client, void *data,
1093- uint64_t ptype, size_t dim,
1094- int64_t *cshape,
1095- PjRtDevice *device) {
1096- auto primtype = (xla::PrimitiveType)ptype;
1097- absl::Span<const int64_t > shape (cshape, dim);
1098- PjRtClient::HostBufferSemantics semantics =
1099- PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall ;
1100- // xla::Layout layout(col_major(dim));
1101- // auto buffer = xla::MyValueOrThrow(client->BufferFromHostBuffer(data,
1102- // primtype, shape, /*byte_strides*/{}, semantics, /*ondone*/{}, device,
1103- // &layout));
1104- const xla::Layout *layout = nullptr ;
1105- auto buffer = MyValueOrThrow (client->BufferFromHostBuffer (
1106- data, primtype, shape, /* byte_strides*/ {}, semantics, /* ondone*/ {},
1107- *device->default_memory_space (), layout));
1108- auto bres = buffer.release ();
1109- return bres;
1110- }
1111-
11121113REACTANT_ABI uint8_t BufferOnCPU (PjRtBuffer *buffer) {
11131114 return buffer->IsOnCpu ();
11141115}
0 commit comments