diff --git a/src/api/cpp/nixl.h b/src/api/cpp/nixl.h index 956ad4fe7..01c35ead5 100644 --- a/src/api/cpp/nixl.h +++ b/src/api/cpp/nixl.h @@ -105,6 +105,15 @@ class nixlAgent { createBackend (const nixl_backend_t &type, const nixl_b_params_t ¶ms, nixlBackendH* &backend); + /** + * @brief Get the type of a backend + * + * @param backend Backend handle + * @param type [out] Backend type + * @return nixl_status_t Error code if call was not successful + */ + nixl_status_t + getBackendType(const nixlBackendH *backend, nixl_backend_t &type) const; /** * @brief Register a memory/storage with NIXL. If a list of backends hints is provided * (via extra_params), the registration is limited to the specified backends. diff --git a/src/bindings/rust/src/agent.rs b/src/bindings/rust/src/agent.rs index 00ad05303..83b83109e 100644 --- a/src/bindings/rust/src/agent.rs +++ b/src/bindings/rust/src/agent.rs @@ -931,23 +931,45 @@ impl Agent { /// * `req` - Transfer request handle after `post_xfer_req` /// /// # Returns - /// A handle to the backend used for the transfer + /// Name of the backend used for the transfer /// /// # Errors /// Returns a NixlError if the operation fails - pub fn query_xfer_backend(&self, req: &XferRequest) -> Result { - let mut backend = std::ptr::null_mut(); + pub fn query_xfer_backend(&self, req: &XferRequest) -> Result { + let mut backend_type_ptr: *mut std::ffi::c_void = std::ptr::null_mut(); + let mut backend_type_size: usize = 0; + let inner_guard = self.inner.write().unwrap(); let status = unsafe { - nixl_capi_query_xfer_backend( + nixl_capi_query_xfer_backend_type( inner_guard.handle.as_ptr(), req.handle(), - &mut backend + &mut backend_type_ptr, + &mut backend_type_size, ) }; match status { NIXL_CAPI_SUCCESS => { - Ok(Backend{ inner: NonNull::new(backend).ok_or(NixlError::FailedToCreateBackend)? }) + if backend_type_ptr.is_null() || backend_type_size == 0 { + return Err(NixlError::BackendError); + } + + // Extract the backend type string from the binary data + let backend_type = unsafe { + let slice = std::slice::from_raw_parts( + backend_type_ptr as *const u8, + backend_type_size + ); + // Handle potential embedded nulls and convert to String + String::from_utf8_lossy(slice).to_string() + }; + + // Verify this backend type exists in our backends map + if inner_guard.backends.contains_key(&backend_type) { + Ok(backend_type) + } else { + Err(NixlError::BackendError) + } } NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), _ => Err(NixlError::BackendError), diff --git a/src/bindings/rust/src/lib.rs b/src/bindings/rust/src/lib.rs index 21c556081..ea14313c2 100644 --- a/src/bindings/rust/src/lib.rs +++ b/src/bindings/rust/src/lib.rs @@ -71,8 +71,8 @@ use bindings::{ nixl_capi_query_resp_list_size, nixl_capi_query_resp_list_has_value, nixl_capi_query_resp_list_get_params, nixl_capi_prep_xfer_dlist, nixl_capi_release_xfer_dlist_handle, nixl_capi_make_xfer_req, nixl_capi_get_local_partial_md, - nixl_capi_send_local_partial_md, nixl_capi_query_xfer_backend, nixl_capi_opt_args_set_ip_addr, - nixl_capi_opt_args_set_port + nixl_capi_send_local_partial_md, nixl_capi_opt_args_set_ip_addr, + nixl_capi_opt_args_set_port, nixl_capi_query_xfer_backend_type }; // Re-export status codes diff --git a/src/bindings/rust/tests/tests.rs b/src/bindings/rust/tests/tests.rs index a6e60d3de..c233816dd 100644 --- a/src/bindings/rust/tests/tests.rs +++ b/src/bindings/rust/tests/tests.rs @@ -1324,10 +1324,10 @@ fn test_query_xfer_backend_success() { None ).expect("Failed to create transfer request"); // Query which backend will be used for this transfer - let result: Result = agent1.query_xfer_backend(&xfer_req); + let result: Result = agent1.query_xfer_backend(&xfer_req); assert!(result.is_ok(), "query_xfer_backend failed with error: {:?}", result.err()); - let backend = result.unwrap(); - println!("Transfer will use backend: {:?}", backend); + let backend_name = result.unwrap(); + println!("Transfer will use backend: {}", backend_name); } } #[test] diff --git a/src/bindings/rust/wrapper.cpp b/src/bindings/rust/wrapper.cpp index 14b7ba4a2..2c495536d 100644 --- a/src/bindings/rust/wrapper.cpp +++ b/src/bindings/rust/wrapper.cpp @@ -1507,6 +1507,32 @@ nixl_capi_query_xfer_backend(nixl_capi_agent_t agent, } } +nixl_capi_status_t +nixl_capi_query_xfer_backend_type(nixl_capi_agent_t agent, + nixl_capi_xfer_req_t req_hndl, + void **backend_type, + size_t *backend_type_size) { + if (!agent || !req_hndl || !backend_type || !backend_type_size) { + return NIXL_CAPI_ERROR_INVALID_PARAM; + } + + nixl_capi_backend_t backend_handle; + auto ret = nixl_capi_query_xfer_backend(agent, req_hndl, &backend_handle); + if (ret != NIXL_CAPI_SUCCESS) { + return ret; + } + + static thread_local nixl_backend_t type; + auto nixl_ret = agent->inner->getBackendType(backend_handle->backend, type); + if (nixl_ret != NIXL_SUCCESS) { + return NIXL_CAPI_ERROR_BACKEND; + } + + *backend_type = type.data(); + *backend_type_size = type.size(); + return NIXL_CAPI_SUCCESS; +} + nixl_capi_status_t nixl_capi_destroy_xfer_req(nixl_capi_xfer_req_t req) { diff --git a/src/bindings/rust/wrapper.h b/src/bindings/rust/wrapper.h index 8bac7169a..2e265259e 100644 --- a/src/bindings/rust/wrapper.h +++ b/src/bindings/rust/wrapper.h @@ -243,6 +243,12 @@ nixl_capi_query_xfer_backend(nixl_capi_agent_t agent, nixl_capi_xfer_req_t req_hndl, nixl_capi_backend_t *backend); +nixl_capi_status_t +nixl_capi_query_xfer_backend_type(nixl_capi_agent_t agent, + nixl_capi_xfer_req_t req_hndl, + void **backend_type, + size_t *backend_type_size); + nixl_capi_status_t nixl_capi_release_xfer_req(nixl_capi_agent_t agent, nixl_capi_xfer_req_t req); nixl_capi_status_t nixl_capi_destroy_xfer_req(nixl_capi_xfer_req_t req); diff --git a/src/core/nixl_agent.cpp b/src/core/nixl_agent.cpp index 43a7b0458..a80d83163 100644 --- a/src/core/nixl_agent.cpp +++ b/src/core/nixl_agent.cpp @@ -396,6 +396,22 @@ nixlAgent::createBackend(const nixl_backend_t &type, return NIXL_ERR_BACKEND; } +nixl_status_t +nixlAgent::getBackendType(const nixlBackendH *backend, nixl_backend_t &type) const { + if (!backend) { + NIXL_ERROR_FUNC << "backend handle is not provided"; + return NIXL_ERR_INVALID_PARAM; + } + + try { + type = backend->getType(); + return NIXL_SUCCESS; + } + catch (...) { + return NIXL_ERR_BACKEND; + } +} + nixl_status_t nixlAgent::queryMem(const nixl_reg_dlist_t &descs, std::vector &resp,