Skip to content

Commit c56b8bc

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Add make_error_array to the transfer lib experimental module.
PiperOrigin-RevId: 801054479
1 parent 545bb18 commit c56b8bc

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

jax/experimental/transfer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ def await_pull(self, uuid: int, arrays: Any) -> Any:
7070
TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer)
7171

7272
start_transfer_server = _xc._xla.start_transfer_server
73+
if hasattr(_xc._xla, "_make_error_array"):
74+
75+
def make_error_array(aval, message):
76+
backend = next(iter(aval.sharding.device_set)).client
77+
return _xc._xla._make_error_array(backend, aval, str(message))

jaxlib/py_socket_transfer.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,42 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
467467
.def("connect", [](PyTransferServer& self, const std::string& address) {
468468
return self.Connect(address);
469469
});
470+
m.def("_make_error_array", [](jax::nb_class_ptr<jax::PyClient> py_client,
471+
nb::object py_aval, std::string message) {
472+
auto* ifrt_client =
473+
llvm::dyn_cast_or_null<xla::ifrt::PjRtClient>(py_client->ifrt_client());
474+
if (ifrt_client == nullptr) {
475+
xla::ThrowIfError(absl::InvalidArgumentError(
476+
"_pull_flat only supported on pjrt-ifrt clients."));
477+
}
478+
auto aval = xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval));
479+
auto traceback = jax::Traceback::Get();
480+
xla::ifrt::PjRtArray::PjRtBuffers buffers;
481+
auto prim_type = xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype));
482+
auto shards = xla::ValueOrThrow(aval.sharding->Disassemble(
483+
aval.shape, xla::ifrt::SingleDeviceShardSemantics::kAddressableShards));
484+
buffers.reserve(shards.size());
485+
for (auto& shard : shards) {
486+
auto* mem_space =
487+
xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second));
488+
xla::PjRtClient::ShapeSpec shape_spec = {
489+
prim_type, xla::DimensionVector(shard.first.dims().begin(),
490+
shard.first.dims().end())};
491+
auto atm = xla::ValueOrThrow(
492+
py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice(
493+
{shape_spec}, std::nullopt, mem_space));
494+
495+
atm->SetBufferError(0, absl::InternalError(message));
496+
buffers.push_back(atm->RetrieveBuffer(0));
497+
}
498+
auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create(
499+
ifrt_client, aval.dtype, aval.shape, aval.sharding, std::move(buffers),
500+
aval.layout));
501+
return jax::PyArray::MakeFromIfrtArrayAndSharding(
502+
py_client, traceback, std::move(arr), py_aval.attr("sharding"), false,
503+
true,
504+
/*skip_checks=*/false);
505+
});
470506

471507
m.def(
472508
"start_transfer_server",

0 commit comments

Comments
 (0)