diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp
index 57b808f54..4dbc4fd78 100644
--- a/src/plugins/ucx/ucx_backend.cpp
+++ b/src/plugins/ucx/ucx_backend.cpp
@@ -1248,6 +1248,8 @@ nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent
 
     remoteConnMap.insert({remote_agent, conn});
 
+    performConnectionEstablishment(remote_agent, conn);
+
     return NIXL_SUCCESS;
 }
 
@@ -1817,3 +1819,46 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std
     }
     return NIXL_SUCCESS;
 }
+
+/**
+ * Best-effort eager wireup: flushes every endpoint of @p conn and waits
+ * until each flush request is no longer in progress.
+ *
+ * A failed flush is logged as a warning and is not reported to the caller.
+ *
+ * @param remote_agent Name of the remote agent; used in log messages only.
+ * @param conn         Connection whose endpoints (conn->eps) are flushed.
+ */
+void
+nixlUcxEngine::performConnectionEstablishment(
+    const std::string &remote_agent,
+    const ucx_connection_ptr_t &conn) const {
+    NIXL_DEBUG << "Establishing connection with " << remote_agent;
+
+    // Flush all endpoints to ensure connection establishment
+    // and avoid UCS_ERR_NOT_CONNECTED errors during data transfers
+    for (size_t i = 0; i < conn->eps.size(); ++i) {
+        // Value-initialized so the handle never holds an indeterminate value.
+        nixlUcxReq req{};
+        nixl_status_t ret = conn->eps[i]->flushEp(req);
+
+        if (ret == NIXL_IN_PROG) {
+            // Endpoint i is polled through worker i.
+            // NOTE(review): this wait has no timeout; confirm that blocking
+            // loadRemoteConnInfo() on an unresponsive peer is acceptable.
+            nixlUcxWorker *worker = getWorker(i).get();
+            do {
+                ret = worker->test(req);
+            } while (ret == NIXL_IN_PROG);
+
+            worker->reqRelease(req);
+        }
+
+        if (ret != NIXL_SUCCESS) {
+            NIXL_WARN << "Failed to flush endpoint " << i << " for " << remote_agent
+                      << ", status: " << ret;
+        }
+    }
+
+    NIXL_DEBUG << "Connection establishment completed for " << remote_agent;
+}
diff --git a/src/plugins/ucx/ucx_backend.h b/src/plugins/ucx/ucx_backend.h
index 51d5ec423..569b38a89 100644
--- a/src/plugins/ucx/ucx_backend.h
+++ b/src/plugins/ucx/ucx_backend.h
@@ -287,6 +287,10 @@ class nixlUcxEngine : public nixlBackendEngine {
     ucx_connection_ptr_t
     getConnection(const std::string &remote_agent) const;
 
+    void
+    performConnectionEstablishment(const std::string &remote_agent,
+                                   const ucx_connection_ptr_t &conn) const;
+
     /* UCX data */
     std::unique_ptr<nixlUcxContext> uc;
     std::vector<std::unique_ptr<nixlUcxWorker>> uws;
diff --git a/test/gtest/device_api/single_write_test.cu b/test/gtest/device_api/single_write_test.cu
index 086d1547f..33b711671 100644
--- a/test/gtest/device_api/single_write_test.cu
+++ b/test/gtest/device_api/single_write_test.cu
@@ -194,19 +194,6 @@ protected:
         agent.registerMem(reg_list);
     }
 
-    void
-    completeWireup(size_t from_agent, size_t to_agent) {
-        nixl_notifs_t notifs;
-        nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
-        ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";
-
-        do {
-            nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
-            ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
-            std::this_thread::sleep_for(std::chrono::milliseconds(10));
-        } while (notifs.size() == 0);
-    }
-
     void
     exchangeMD(size_t from_agent, size_t to_agent) {
         for (size_t i = 0; i < agents.size(); i++) {
@@ -222,8 +209,6 @@ protected:
                 EXPECT_EQ(remote_agent_name, getAgentName(i));
             }
         }
-
-        completeWireup(from_agent, to_agent);
     }
 
     void
diff --git a/test/gtest/device_api/utils.cu b/test/gtest/device_api/utils.cu
index acbc60313..bb1513b6e 100644
--- a/test/gtest/device_api/utils.cu
+++ b/test/gtest/device_api/utils.cu
@@ -100,18 +100,6 @@ void DeviceApiTestBase::registerMem(nixlAgent &agent, const std::vector makeDescList(const std::vector &buffers, nixl_mem_t mem_type); void registerMem(nixlAgent &agent, const std::vector &buffers, nixl_mem_t mem_type); - void completeWireup(size_t from_agent, size_t to_agent); + void exchangeMD(size_t from_agent, size_t to_agent); void invalidateMD();