@@ -13,26 +13,34 @@ namespace extension {
13
13
namespace training {
14
14
15
15
namespace {
16
- std::string gradients_method_prefix = " __et_training_gradients_index_" ;
17
- std::string parameters_method_prefix = " __et_training_parameters_index_" ;
18
- std::string fqn_method_prefix = " __et_training_fqn_" ;
16
+
17
+ std::string make_parameters_method_name (const std::string& method_name) {
18
+ return " __et_training_parameters_index_" + method_name;
19
+ }
20
+
21
+ std::string make_gradients_method_name (const std::string& method_name) {
22
+ return " __et_training_gradients_index_" + method_name;
23
+ }
24
+
25
+ std::string make_fqn_method_name (const std::string& method_name) {
26
+ return " __et_training_fqn_" + method_name;
27
+ }
28
+
19
29
} // namespace
20
30
21
31
runtime::Result<std::vector<runtime::EValue>>
22
32
TrainingModule::execute_forward_backward (
23
33
const std::string& method_name,
24
34
const std::vector<runtime::EValue>& input) {
25
35
// Find where the user outputs end.
26
- const std::string gradients_method_name =
27
- gradients_method_prefix + method_name;
36
+ const std::string gradients_method_name = make_gradients_method_name (method_name);
28
37
auto res = executorch::extension::Module::execute (gradients_method_name);
29
38
if (!res.ok ()) {
30
39
return res.error ();
31
40
}
32
41
uint64_t grad_start = res.get ()[0 ].toInt ();
33
42
34
- const std::string parameters_method_name =
35
- parameters_method_prefix + method_name;
43
+ const std::string parameters_method_name = make_parameters_method_name (method_name);
36
44
// get params start.
37
45
auto param_res =
38
46
executorch::extension::Module::execute (parameters_method_name);
@@ -66,7 +74,7 @@ TrainingModule::execute_forward_backward(
66
74
auto & gradients_map = method_named_gradients_.at (method_name);
67
75
68
76
// Get names if we havent seen this method before.
69
- const std::string fqn_method_name = fqn_method_prefix + method_name;
77
+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
70
78
auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
71
79
if (!fqn_res.ok ()) {
72
80
return fqn_res.error ();
@@ -92,9 +100,9 @@ TrainingModule::named_parameters(const std::string& method_name) {
92
100
// If we haven't seen this method before, populate the dict.
93
101
if (method_named_parameters_.find (method_name) ==
94
102
method_named_parameters_.end ()) {
95
- const std::string fqn_method_name = fqn_method_prefix + method_name;
103
+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
96
104
const std::string parameters_method_name =
97
- parameters_method_prefix + method_name;
105
+ make_parameters_method_name ( method_name) ;
98
106
99
107
method_named_parameters_.insert ({method_name, {}});
100
108
0 commit comments