Skip to content

Commit dd0117a

Browse files
committed
also reverse
1 parent 9c80cf8 commit dd0117a

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,7 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10201020
auto pid = client->platform_id();
10211021
if (pid == xla::TpuId()) {
10221022
auto dims = buffer->on_device_shape().dimensions();
1023+
// TODO: note this assume that we want to copy the entire buffer size.
10231024
auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
10241025
*bufferP = buf2;
10251026
PjRtBufferFree((PjRtBuffer *)buffer);
@@ -1063,6 +1064,14 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10631064

10641065
REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
10651066
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
1067+
1068+
auto pid = client->platform_id();
1069+
if (pid == xla::TpuId()) {
1070+
// TODO: note this assume that we want to copy the entire buffer size.
1071+
BufferToHost(buffer, data);
1072+
return;
1073+
}
1074+
10661075
auto future = buffer->CopyRawToHost(data, offset, size);
10671076
future.Await();
10681077
#if 0

0 commit comments

Comments
 (0)