@@ -13,9 +13,19 @@ namespace extension {
13
13
namespace training {
14
14
15
15
namespace {
16
- std::string gradients_method_prefix = " __et_training_gradients_index_" ;
17
- std::string parameters_method_prefix = " __et_training_parameters_index_" ;
18
- std::string fqn_method_prefix = " __et_training_fqn_" ;
16
+
17
+ std::string make_parameters_method_name (const std::string& method_name) {
18
+ return " __et_training_parameters_index_" + method_name;
19
+ }
20
+
21
+ std::string make_gradients_method_name (const std::string& method_name) {
22
+ return " __et_training_gradients_index_" + method_name;
23
+ }
24
+
25
+ std::string make_fqn_method_name (const std::string& method_name) {
26
+ return " __et_training_fqn_" + method_name;
27
+ }
28
+
19
29
} // namespace
20
30
21
31
runtime::Result<std::vector<runtime::EValue>>
@@ -24,15 +34,15 @@ TrainingModule::execute_forward_backward(
24
34
const std::vector<runtime::EValue>& input) {
25
35
// Find where the user outputs end.
26
36
const std::string gradients_method_name =
27
- gradients_method_prefix + method_name;
37
+ make_gradients_method_name ( method_name) ;
28
38
auto res = executorch::extension::Module::execute (gradients_method_name);
29
39
if (!res.ok ()) {
30
40
return res.error ();
31
41
}
32
42
uint64_t grad_start = res.get ()[0 ].toInt ();
33
43
34
44
const std::string parameters_method_name =
35
- parameters_method_prefix + method_name;
45
+ make_parameters_method_name ( method_name) ;
36
46
// get params start.
37
47
auto param_res =
38
48
executorch::extension::Module::execute (parameters_method_name);
@@ -66,7 +76,7 @@ TrainingModule::execute_forward_backward(
66
76
auto & gradients_map = method_named_gradients_.at (method_name);
67
77
68
78
// Get names if we havent seen this method before.
69
- const std::string fqn_method_name = fqn_method_prefix + method_name;
79
+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
70
80
auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
71
81
if (!fqn_res.ok ()) {
72
82
return fqn_res.error ();
@@ -92,9 +102,9 @@ TrainingModule::named_parameters(const std::string& method_name) {
92
102
// If we haven't seen this method before, populate the dict.
93
103
if (method_named_parameters_.find (method_name) ==
94
104
method_named_parameters_.end ()) {
95
- const std::string fqn_method_name = fqn_method_prefix + method_name;
105
+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
96
106
const std::string parameters_method_name =
97
- parameters_method_prefix + method_name;
107
+ make_parameters_method_name ( method_name) ;
98
108
99
109
method_named_parameters_.insert ({method_name, {}});
100
110
0 commit comments