1313
1414#include "ompi_config.h"
1515#include "coll_ucc.h"
16+ #include "coll_ucc_common.h"
1617#include "coll_ucc_dtypes.h"
1718#include "ompi/mca/coll/base/coll_tags.h"
1819#include "ompi/mca/pml/pml.h"
@@ -219,7 +220,8 @@ static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen,
219220}
220221
221222
222- static int mca_coll_ucc_init_ctx () {
223+ static int mca_coll_ucc_init_ctx (ompi_communicator_t * comm )
224+ {
223225 mca_coll_ucc_component_t * cm = & mca_coll_ucc_component ;
224226 char str_buf [256 ];
225227 ompi_attribute_fn_ptr_union_t del_fn ;
@@ -270,9 +272,9 @@ static int mca_coll_ucc_init_ctx() {
270272 ctx_params .oob .allgather = oob_allgather ;
271273 ctx_params .oob .req_test = oob_allgather_test ;
272274 ctx_params .oob .req_free = oob_allgather_free ;
273- ctx_params .oob .coll_info = (void * )MPI_COMM_WORLD ;
274- ctx_params .oob .n_oob_eps = ompi_comm_size (& ompi_mpi_comm_world . comm );
275- ctx_params .oob .oob_ep = ompi_comm_rank (& ompi_mpi_comm_world . comm );
275+ ctx_params .oob .coll_info = (void * )comm ;
276+ ctx_params .oob .n_oob_eps = ompi_comm_size (comm );
277+ ctx_params .oob .oob_ep = ompi_comm_rank (comm );
276278 if (UCC_OK != ucc_context_config_read (cm -> ucc_lib , NULL , & ctx_config )) {
277279 UCC_ERROR ("UCC context config read failed" );
278280 goto cleanup_lib ;
@@ -329,7 +331,7 @@ static int mca_coll_ucc_init_ctx() {
329331 return OMPI_ERROR ;
330332}
331333
332- uint64_t rank_map_cb (uint64_t ep , void * cb_ctx )
334+ static uint64_t rank_map_cb (uint64_t ep , void * cb_ctx )
333335{
334336 struct ompi_communicator_t * comm = cb_ctx ;
335337
@@ -433,8 +435,7 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
433435 ucc_team_params_t team_params = {
434436 .mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
435437 UCC_TEAM_PARAM_FIELD_EP |
436- UCC_TEAM_PARAM_FIELD_EP_RANGE |
437- UCC_TEAM_PARAM_FIELD_ID ,
438+ UCC_TEAM_PARAM_FIELD_EP_RANGE ,
438439 .ep_map = {
439440 .type = (comm == & ompi_mpi_comm_world .comm ) ?
440441 UCC_EP_MAP_FULL : UCC_EP_MAP_CB ,
@@ -443,9 +444,12 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
443444 .cb .cb_ctx = (void * )comm
444445 },
445446 .ep = ompi_comm_rank (comm ),
446- .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG ,
447- .id = ompi_comm_get_local_cid (comm )
447+ .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG
448448 };
449+ if (OMPI_COMM_IS_GLOBAL_INDEX (comm )) {
450+ team_params .mask |= UCC_TEAM_PARAM_FIELD_ID ;
451+ team_params .id = ompi_comm_get_local_cid (comm );
452+ }
449453 UCC_VERBOSE (2 , "creating ucc_team for comm %p, comm_id %llu, comm_size %d" ,
450454 (void * )comm , (long long unsigned )team_params .id ,
451455 ompi_comm_size (comm ));
@@ -555,7 +559,7 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
555559 }
556560
557561 if (!cm -> libucc_initialized ) {
558- if (OMPI_SUCCESS != mca_coll_ucc_init_ctx ()) {
562+ if (OMPI_SUCCESS != mca_coll_ucc_init_ctx (comm )) {
559563 cm -> ucc_enable = 0 ;
560564 return NULL ;
561565 }
0 commit comments