@@ -833,7 +833,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
833
833
// Runs for 3 iterations or 1 second and picks the best option
834
834
int pickBestTactic (MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
835
835
{
836
- auto tactics = mMoERunner .getTactics ();
836
+ auto tactics = mMoERunner .getTactics (static_cast <MoeGemmId>(gemm_to_profile) );
837
837
::nvtx3::scoped_range nvtx (tensorrt_llm::common::nvtx::nextColor (),
838
838
" Tactic Profiling GEMM " + std::to_string (static_cast <int >(gemm_to_profile)));
839
839
// We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we
@@ -925,12 +925,14 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
925
925
std::pair<int , int > setTactic (
926
926
int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
927
927
{
928
- auto tactics = mMoERunner .getTactics ();
928
+ auto tactics1 = mMoERunner .getTactics (MoeGemmId::GEMM_1);
929
+ auto tactics2 = mMoERunner .getTactics (MoeGemmId::GEMM_2);
929
930
std::vector<std::pair<std::reference_wrapper<int >, GemmToProfile>> tactics_to_profile{
930
931
{tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}};
931
932
for (auto & combo : tactics_to_profile)
932
933
{
933
934
auto & t = combo.first .get ();
935
+ auto & tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2;
934
936
if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER)
935
937
{
936
938
t = 0 ; // Unneeded tactic, set to 0
@@ -947,7 +949,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
947
949
}
948
950
}
949
951
950
- mMoERunner .setTactic (tactics [tactic_idx1], tactics [tactic_idx2]);
952
+ mMoERunner .setTactic (tactics1 [tactic_idx1], tactics2 [tactic_idx2]);
951
953
mBestTacticGemm1 = tactic_idx1;
952
954
mBestTacticGemm2 = tactic_idx2;
953
955
return {tactic_idx1, tactic_idx2};
@@ -965,7 +967,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
965
967
auto expert_weights_size
966
968
= gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size ;
967
969
968
- auto tactics = mMoERunner .getTactics ()[tactic_idx];
970
+ auto tactics = mMoERunner .getTactics (static_cast <MoeGemmId>(gemm_to_profile) )[tactic_idx];
969
971
if (static_cast <int >(gemm_to_profile) != static_cast <int >(mGemmProfilerBackend .mGemmToProfile ))
970
972
{
971
973
throw std::runtime_error (" Configuration mismatch between mGemmProfilerBackend and runMoEPermute" );
@@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
1074
1076
}
1075
1077
if (LOG_LEVEL >= INFO)
1076
1078
{
1077
- auto tactics = mMoERunner .getTactics ();
1078
- std::cout << " Selected tactic #1: " << tactic_idx1 << " /" << tactics.size () << " \n "
1079
- << tactics[tactic_idx1].toString () << std::endl;
1080
- std::cout << " Selected tactic #2: " << tactic_idx2 << " /" << tactics.size () << " \n "
1081
- << tactics[tactic_idx2].toString () << std::endl;
1079
+ auto tactics1 = mMoERunner .getTactics (MoeGemmId::GEMM_1);
1080
+ auto tactics2 = mMoERunner .getTactics (MoeGemmId::GEMM_2);
1081
+ std::cout << " Selected tactic #1: " << tactic_idx1 << " /" << tactics1.size () << " \n "
1082
+ << tactics1[tactic_idx1].toString () << std::endl;
1083
+ std::cout << " Selected tactic #2: " << tactic_idx2 << " /" << tactics2.size () << " \n "
1084
+ << tactics2[tactic_idx2].toString () << std::endl;
1082
1085
}
1083
1086
state.counters [" tactic_idx1" ] = tactic_idx1;
1084
1087
state.counters [" tactic_idx2" ] = tactic_idx2;
0 commit comments