Skip to content

Commit 6b44728

Browse files
committed
fix global-constructors
1 parent 1f885b9 commit 6b44728

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

extension/training/module/training_module.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,34 @@ namespace extension {
1313
namespace training {
1414

1515
namespace {
16-
std::string gradients_method_prefix = "__et_training_gradients_index_";
17-
std::string parameters_method_prefix = "__et_training_parameters_index_";
18-
std::string fqn_method_prefix = "__et_training_fqn_";
16+
17+
std::string make_parameters_method_name(const std::string& method_name) {
18+
return "__et_training_parameters_index_" + method_name;
19+
}
20+
21+
std::string make_gradients_method_name(const std::string& method_name) {
22+
return "__et_training_gradients_index_" + method_name;
23+
}
24+
25+
std::string make_fqn_method_name(const std::string& method_name) {
26+
return "__et_training_fqn_" + method_name;
27+
}
28+
1929
} // namespace
2030

2131
runtime::Result<std::vector<runtime::EValue>>
2232
TrainingModule::execute_forward_backward(
2333
const std::string& method_name,
2434
const std::vector<runtime::EValue>& input) {
2535
// Find where the user outputs end.
26-
const std::string gradients_method_name =
27-
gradients_method_prefix + method_name;
36+
const std::string gradients_method_name = make_gradients_method_name(method_name);
2837
auto res = executorch::extension::Module::execute(gradients_method_name);
2938
if (!res.ok()) {
3039
return res.error();
3140
}
3241
uint64_t grad_start = res.get()[0].toInt();
3342

34-
const std::string parameters_method_name =
35-
parameters_method_prefix + method_name;
43+
const std::string parameters_method_name = make_parameters_method_name(method_name);
3644
// get params start.
3745
auto param_res =
3846
executorch::extension::Module::execute(parameters_method_name);
@@ -66,7 +74,7 @@ TrainingModule::execute_forward_backward(
6674
auto& gradients_map = method_named_gradients_.at(method_name);
6775

6876
// Get names if we havent seen this method before.
69-
const std::string fqn_method_name = fqn_method_prefix + method_name;
77+
const std::string fqn_method_name = make_fqn_method_name(method_name);
7078
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
7179
if (!fqn_res.ok()) {
7280
return fqn_res.error();
@@ -92,9 +100,9 @@ TrainingModule::named_parameters(const std::string& method_name) {
92100
// If we haven't seen this method before, populate the dict.
93101
if (method_named_parameters_.find(method_name) ==
94102
method_named_parameters_.end()) {
95-
const std::string fqn_method_name = fqn_method_prefix + method_name;
103+
const std::string fqn_method_name = make_fqn_method_name(method_name);
96104
const std::string parameters_method_name =
97-
parameters_method_prefix + method_name;
105+
make_parameters_method_name(method_name);
98106

99107
method_named_parameters_.insert({method_name, {}});
100108

0 commit comments

Comments
 (0)