diff --git a/extension/training/module/training_module.cpp b/extension/training/module/training_module.cpp index 51140c14e32..4dbaaf3fcfb 100644 --- a/extension/training/module/training_module.cpp +++ b/extension/training/module/training_module.cpp @@ -13,9 +13,19 @@ namespace extension { namespace training { namespace { -std::string gradients_method_prefix = "__et_training_gradients_index_"; -std::string parameters_method_prefix = "__et_training_parameters_index_"; -std::string fqn_method_prefix = "__et_training_fqn_"; + +std::string make_parameters_method_name(const std::string& method_name) { + return "__et_training_parameters_index_" + method_name; +} + +std::string make_gradients_method_name(const std::string& method_name) { + return "__et_training_gradients_index_" + method_name; +} + +std::string make_fqn_method_name(const std::string& method_name) { + return "__et_training_fqn_" + method_name; +} + } // namespace runtime::Result> @@ -24,7 +34,7 @@ TrainingModule::execute_forward_backward( const std::vector& input) { // Find where the user outputs end. const std::string gradients_method_name = - gradients_method_prefix + method_name; + make_gradients_method_name(method_name); auto res = executorch::extension::Module::execute(gradients_method_name); if (!res.ok()) { return res.error(); @@ -32,7 +42,7 @@ TrainingModule::execute_forward_backward( uint64_t grad_start = res.get()[0].toInt(); const std::string parameters_method_name = - parameters_method_prefix + method_name; + make_parameters_method_name(method_name); // get params start. auto param_res = executorch::extension::Module::execute(parameters_method_name); @@ -66,7 +76,7 @@ TrainingModule::execute_forward_backward( auto& gradients_map = method_named_gradients_.at(method_name); // Get names if we havent seen this method before. - const std::string fqn_method_name = fqn_method_prefix + method_name; + const std::string fqn_method_name = make_fqn_method_name(method_name); auto fqn_res = executorch::extension::Module::execute(fqn_method_name); if (!fqn_res.ok()) { return fqn_res.error(); @@ -92,9 +102,9 @@ TrainingModule::named_parameters(const std::string& method_name) { // If we haven't seen this method before, populate the dict. if (method_named_parameters_.find(method_name) == method_named_parameters_.end()) { - const std::string fqn_method_name = fqn_method_prefix + method_name; + const std::string fqn_method_name = make_fqn_method_name(method_name); const std::string parameters_method_name = - parameters_method_prefix + method_name; + make_parameters_method_name(method_name); method_named_parameters_.insert({method_name, {}});