@@ -467,6 +467,42 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
467
467
.def (" connect" , [](PyTransferServer& self, const std::string& address) {
468
468
return self.Connect (address);
469
469
});
470
+ m.def (" _make_error_array" , [](jax::nb_class_ptr<jax::PyClient> py_client,
471
+ nb::object py_aval, std::string message) {
472
+ auto * ifrt_client =
473
+ llvm::dyn_cast_or_null<xla::ifrt::PjRtClient>(py_client->ifrt_client ());
474
+ if (ifrt_client == nullptr ) {
475
+ xla::ThrowIfError (absl::InvalidArgumentError (
476
+ " _pull_flat only supported on pjrt-ifrt clients." ));
477
+ }
478
+ auto aval = xla::ValueOrThrow (ArraySpecFromShapeDtypeStruct (py_aval));
479
+ auto traceback = jax::Traceback::Get ();
480
+ xla::ifrt::PjRtArray::PjRtBuffers buffers;
481
+ auto prim_type = xla::ValueOrThrow (xla::ifrt::ToPrimitiveType (aval.dtype ));
482
+ auto shards = xla::ValueOrThrow (aval.sharding ->Disassemble (
483
+ aval.shape , xla::ifrt::SingleDeviceShardSemantics::kAddressableShards ));
484
+ buffers.reserve (shards.size ());
485
+ for (auto & shard : shards) {
486
+ auto * mem_space =
487
+ xla::ValueOrThrow (MemorySpaceFromSharding (*shard.second ));
488
+ xla::PjRtClient::ShapeSpec shape_spec = {
489
+ prim_type, xla::DimensionVector (shard.first .dims ().begin (),
490
+ shard.first .dims ().end ())};
491
+ auto atm = xla::ValueOrThrow (
492
+ py_client->pjrt_client ()->CreateBuffersForAsyncHostToDevice (
493
+ {shape_spec}, std::nullopt, mem_space));
494
+
495
+ atm->SetBufferError (0 , absl::InternalError (message));
496
+ buffers.push_back (atm->RetrieveBuffer (0 ));
497
+ }
498
+ auto arr = xla::ValueOrThrow (xla::ifrt::PjRtArray::Create (
499
+ ifrt_client, aval.dtype , aval.shape , aval.sharding , std::move (buffers),
500
+ aval.layout ));
501
+ return jax::PyArray::MakeFromIfrtArrayAndSharding (
502
+ py_client, traceback, std::move (arr), py_aval.attr (" sharding" ), false ,
503
+ true ,
504
+ /* skip_checks=*/ false );
505
+ });
470
506
471
507
m.def (
472
508
" start_transfer_server" ,
0 commit comments