|
59 | 59 | #include "xla/tsl/profiler/rpc/client/capture_profile.h" |
60 | 60 | #include "xla/tsl/profiler/rpc/profiler_server.h" |
61 | 61 | #include "xla/python/profiler_utils.h" |
| 62 | +#include "tsl/platform/init_main.h" |
62 | 63 |
|
63 | 64 | #include "xla/python/ifrt/hlo/hlo_program.h" |
64 | 65 | #include "llvm/ExecutionEngine/ExecutionEngine.h" |
@@ -205,7 +206,11 @@ T *unwrap_absl_statusor(absl::StatusOr<T> status, char **error_msg) { |
205 | 206 | // int xla::_LayoutProto_default_instance_; |
206 | 207 |
|
207 | 208 | extern "C" void InitializeLogs() { |
208 | | - absl::InitializeLog(); |
| 209 | + const char* binary = "julia"; |
| 210 | + int argc = 1; |
| 211 | + char* argv[] = {(char*)binary}; |
| 212 | + char** argv2 = &argv[0]; |
| 213 | + tsl::port::InitMain(binary, &argc, &argv2); |
209 | 214 | LLVMInitializeX86Target(); |
210 | 215 | LLVMInitializeX86TargetInfo(); |
211 | 216 | LLVMInitializeX86TargetMC(); |
@@ -668,7 +673,9 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client, |
668 | 673 | options.executable_build_options.set_device_assignment(device_assignment); |
669 | 674 |
|
670 | 675 | // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460 |
671 | | - xla::ExportShardyForHloRoundTrip(cmodop); |
| 676 | + auto status = xla::ExportShardyForHloRoundTrip(cmodop); |
| 677 | + if (!status.ok()) |
| 678 | + ReactantThrowError(status.ToString().c_str()); |
672 | 679 | } else { |
673 | 680 | assert(device_id >= 0); |
674 | 681 |
|
@@ -867,8 +874,6 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, |
867 | 874 | uint8_t *is_arg_donatable, |
868 | 875 | int num_results, PjRtBuffer **op_results, |
869 | 876 | uint8_t *futures, FutureType **future_results) { |
870 | | - auto client = exec->client(); |
871 | | - |
872 | 877 | // Ensure argument_handles is structured as num_mesh_ids x num_args |
873 | 878 | std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids); |
874 | 879 | int num_args = op_args_len / num_mesh_ids; |
|
0 commit comments