Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/plugins/ucx/ucx_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,8 @@ nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent

remoteConnMap.insert({remote_agent, conn});

performConnectionEstablishment(remote_agent, conn);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handle possible failure


return NIXL_SUCCESS;
}

Expand Down Expand Up @@ -1817,3 +1819,33 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std
}
return NIXL_SUCCESS;
}

void
nixlUcxEngine::performConnectionEstablishment(
const std::string &remote_agent,
const std::shared_ptr<nixlUcxConnection> &conn) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const std::shared_ptr<nixlUcxConnection> &conn) const {
const nixlUcxConnection &conn) const {

https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#f7-for-general-use-take-t-or-t-arguments-rather-than-smart-pointers

Or pass just endpoints.

NIXL_DEBUG << "Establishing connection with " << remote_agent;

// Flush all endpoints to ensure connection establishment
// and avoid UCS_ERR_NOT_CONNECTED errors during data transfers
Comment on lines +1829 to +1830
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, this looks like moving a workaround from the user level into the UCX backend instead of fixing the UCP API. A UCP EP should not return NOT_CONNECTED, so that no level has to block. Instead, the request should stay pending until completion, like any other operation posted on a UCP EP.

for (size_t i = 0; i < conn->eps.size(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a flush without any previous request could complete without the endpoint being connected. If that is confirmed, we could first send a dummy op and then start the flush.

nixlUcxReq req;
nixl_status_t ret = conn->eps[i]->flushEp(req);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a UCX engine API, I think it would be safer to start all the flush operations first and then progress all the flush requests in a second loop, in case there is some inter-dependency between them.


if (ret == NIXL_IN_PROG) {
nixlUcxWorker *worker = getWorker(i).get();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nixlUcxWorker & ?

do {
ret = worker->test(req);
} while (ret == NIXL_IN_PROG);

worker->reqRelease(req);
}

if (ret != NIXL_SUCCESS) {
NIXL_WARN << "Failed to flush endpoint " << i << " for " << remote_agent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return error

<< ", status: " << ret;
}
}

NIXL_DEBUG << "Connection establishment completed for " << remote_agent;
}
4 changes: 4 additions & 0 deletions src/plugins/ucx/ucx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ class nixlUcxEngine : public nixlBackendEngine {
ucx_connection_ptr_t
getConnection(const std::string &remote_agent) const;

/**
 * Flushes every endpoint held by @p conn so that the connection to
 * @p remote_agent is established before data transfers are posted.
 *
 * For each endpoint a flush request is issued; while the flush reports
 * NIXL_IN_PROG the matching worker is polled until it completes. A flush
 * that does not end in NIXL_SUCCESS is only logged as a warning: the
 * function returns void, so callers cannot observe the failure.
 *
 * @param remote_agent Name of the remote agent (used for log messages).
 * @param conn         Connection whose endpoints (conn->eps) are flushed.
 */
void
performConnectionEstablishment(const std::string &remote_agent,
const std::shared_ptr<nixlUcxConnection> &conn) const;

/* UCX data */
std::unique_ptr<nixlUcxContext> uc;
std::vector<std::unique_ptr<nixlUcxWorker>> uws;
Expand Down
15 changes: 0 additions & 15 deletions test/gtest/device_api/single_write_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,6 @@ protected:
agent.registerMem(reg_list);
}

void
completeWireup(size_t from_agent, size_t to_agent) {
nixl_notifs_t notifs;
nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";

do {
nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} while (notifs.size() == 0);
}

void
exchangeMD(size_t from_agent, size_t to_agent) {
for (size_t i = 0; i < agents.size(); i++) {
Expand All @@ -222,8 +209,6 @@ protected:
EXPECT_EQ(remote_agent_name, getAgentName(i));
}
}

completeWireup(from_agent, to_agent);
}

void
Expand Down
14 changes: 0 additions & 14 deletions test/gtest/device_api/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,6 @@ void DeviceApiTestBase::registerMem(nixlAgent &agent, const std::vector<MemBuffe
agent.registerMem(reg_list);
}

void DeviceApiTestBase::completeWireup(size_t from_agent, size_t to_agent) {
nixl_notifs_t notifs;
nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), NOTIF_MSG);
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";

do {
nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} while (notifs.size() == 0);
}

void DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
for (size_t i = 0; i < agents.size(); i++) {
nixl_blob_t md;
Expand All @@ -126,8 +114,6 @@ void DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
EXPECT_EQ(remote_agent_name, getAgentName(i));
}
}

completeWireup(from_agent, to_agent);
}

void DeviceApiTestBase::invalidateMD() {
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/device_api/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ protected:
nixlDescList<Desc> makeDescList(const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);

void registerMem(nixlAgent &agent, const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);
void completeWireup(size_t from_agent, size_t to_agent);

void exchangeMD(size_t from_agent, size_t to_agent);
void invalidateMD();

Expand Down
Loading