Skip to content

Commit 4f2a5da

Browse files
committed
UCX/BACKEND: Add internal connection establishment
Signed-off-by: Michal Shalev <[email protected]>
1 parent d865179 commit 4f2a5da

File tree

5 files changed

+36
-30
lines changed

src/plugins/ucx/ucx_backend.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,8 @@ nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent
12481248

12491249
remoteConnMap.insert({remote_agent, conn});
12501250

1251+
performConnectionEstablishment(remote_agent, conn);
1252+
12511253
return NIXL_SUCCESS;
12521254
}
12531255

@@ -1817,3 +1819,32 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std
18171819
}
18181820
return NIXL_SUCCESS;
18191821
}
1822+
1823+
// Flushes every endpoint in conn->eps and waits for each flush to finish,
// logging (but not returning) any failure. Called from loadRemoteConnInfo
// right after the connection is inserted into remoteConnMap.
//
// remote_agent: name of the remote agent; used only in log messages here.
// conn:         connection whose endpoints (conn->eps) are flushed in order.
//
// The function returns void, so a failed flush is reported only through
// NIXL_WARN and the caller cannot observe it.
void nixlUcxEngine::performConnectionEstablishment(const std::string &remote_agent,
1824+
const std::shared_ptr<nixlUcxConnection> &conn) const
1825+
{
1826+
NIXL_DEBUG << "Establishing connection with " << remote_agent;
1827+
1828+
// Flush all endpoints to ensure connection establishment
1829+
// and avoid UCS_ERR_NOT_CONNECTED errors during data transfers
1830+
for (size_t i = 0; i < conn->eps.size(); ++i) {
1831+
nixlUcxReq req;
1832+
nixl_status_t ret = conn->eps[i]->flushEp(req);
1833+
1834+
// The flush did not complete immediately: poll the request on worker i
// until it leaves NIXL_IN_PROG, then release the request.
// NOTE(review): assumes endpoint i belongs to the worker returned by
// getWorker(i) -- confirm against how conn->eps is populated.
// NOTE(review): no timeout or back-off is visible in this loop; if test()
// keeps returning NIXL_IN_PROG this call does not return -- confirm that
// is acceptable on the loadRemoteConnInfo path.
if (ret == NIXL_IN_PROG) {
1835+
nixlUcxWorker *worker = getWorker(i).get();
1836+
do {
1837+
ret = worker->test(req);
1838+
} while (ret == NIXL_IN_PROG);
1839+
1840+
// reqRelease is reached only on the NIXL_IN_PROG path; presumably an
// immediately-completed flush leaves nothing to release -- TODO confirm.
worker->reqRelease(req);
1841+
}
1842+
1843+
// Best-effort: a failed flush is logged and the loop moves on to the
// next endpoint.
if (ret != NIXL_SUCCESS) {
1844+
NIXL_WARN << "Failed to flush endpoint " << i << " for " << remote_agent
1845+
<< ", status: " << ret;
1846+
}
1847+
}
1848+
1849+
NIXL_DEBUG << "Connection establishment completed for " << remote_agent;
1850+
}

src/plugins/ucx/ucx_backend.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ class nixlUcxEngine : public nixlBackendEngine {
287287
ucx_connection_ptr_t
288288
getConnection(const std::string &remote_agent) const;
289289

290+
// Flushes each endpoint of `conn` and waits for the flush to complete.
// Failures are logged rather than returned (the function is void).
// `remote_agent` is used for log messages.
void
291+
performConnectionEstablishment(const std::string &remote_agent,
292+
const std::shared_ptr<nixlUcxConnection> &conn) const;
293+
290294
/* UCX data */
291295
std::unique_ptr<nixlUcxContext> uc;
292296
std::vector<std::unique_ptr<nixlUcxWorker>> uws;

test/gtest/device_api/single_write_test.cu

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -194,19 +194,6 @@ protected:
194194
agent.registerMem(reg_list);
195195
}
196196

197-
void
198-
completeWireup(size_t from_agent, size_t to_agent) {
199-
nixl_notifs_t notifs;
200-
nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
201-
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";
202-
203-
do {
204-
nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
205-
ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
206-
std::this_thread::sleep_for(std::chrono::milliseconds(10));
207-
} while (notifs.size() == 0);
208-
}
209-
210197
void
211198
exchangeMD(size_t from_agent, size_t to_agent) {
212199
for (size_t i = 0; i < agents.size(); i++) {
@@ -222,8 +209,6 @@ protected:
222209
EXPECT_EQ(remote_agent_name, getAgentName(i));
223210
}
224211
}
225-
226-
completeWireup(from_agent, to_agent);
227212
}
228213

229214
void

test/gtest/device_api/utils.cu

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,6 @@ void DeviceApiTestBase::registerMem(nixlAgent &agent, const std::vector<MemBuffe
100100
agent.registerMem(reg_list);
101101
}
102102

103-
void DeviceApiTestBase::completeWireup(size_t from_agent, size_t to_agent) {
104-
nixl_notifs_t notifs;
105-
nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
106-
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";
107-
108-
do {
109-
nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
110-
ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
111-
std::this_thread::sleep_for(std::chrono::milliseconds(10));
112-
} while (notifs.size() == 0);
113-
}
114-
115103
void DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
116104
for (size_t i = 0; i < agents.size(); i++) {
117105
nixl_blob_t md;
@@ -126,8 +114,6 @@ void DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
126114
EXPECT_EQ(remote_agent_name, getAgentName(i));
127115
}
128116
}
129-
130-
completeWireup(from_agent, to_agent);
131117
}
132118

133119
void DeviceApiTestBase::invalidateMD() {

test/gtest/device_api/utils.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ protected:
139139
nixlDescList<Desc> makeDescList(const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);
140140

141141
void registerMem(nixlAgent &agent, const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);
142-
void completeWireup(size_t from_agent, size_t to_agent);
142+
143143
void exchangeMD(size_t from_agent, size_t to_agent);
144144
void invalidateMD();
145145

0 commit comments

Comments
 (0)