diff --git a/pypesto/petab/importer.py b/pypesto/petab/importer.py index 6160247dd..b36e5b488 100644 --- a/pypesto/petab/importer.py +++ b/pypesto/petab/importer.py @@ -466,6 +466,21 @@ def create_objective( edatas = self.create_edatas( model=model, simulation_conditions=simulation_conditions ) + else: + simulation_conditions = pd.DataFrame( + [ + { + PREEQUILIBRATION_CONDITION_ID: edata.id.split("+")[0] + if "+" in edata.id + else "", + # why is this not CONDITION_SEP ¯\_(ツ)_/¯? + SIMULATION_CONDITION_ID: edata.id.split("+")[1] + if "+" in edata.id + else edata.id, + } + for edata in edatas + ] + ) parameter_mapping = ( amici.petab.parameter_mapping.create_parameter_mapping( diff --git a/test/petab/test_amici_objective.py b/test/petab/test_amici_objective.py index 00c399f6e..4b59de78e 100644 --- a/test/petab/test_amici_objective.py +++ b/test/petab/test_amici_objective.py @@ -130,3 +130,38 @@ def test_preeq_guesses(): # assert that resetting works problem.objective.initialize() assert obj.steadystate_guesses["fval"] == np.inf + + +def test_edatas(): + """ + Test whether optimization with preequilibration guesses works, asserts + that steadystate guesses are written and checks that gradient is still + correct with guesses set. + """ + model_name = "Brannmark_JBC2010" + importer = pypesto.petab.PetabImporter.from_yaml( + os.path.join(models.MODELS_DIR, model_name, model_name + ".yaml") + ) + pars = np.asarray(importer.petab_problem.x_nominal_free_scaled) + + full_objective = importer.create_objective() + + edatas = importer.create_edatas() + + full_result = full_objective(pars, return_dict=True) + for pm_full, edata in zip( + full_objective.parameter_mapping.parameter_mappings, + edatas, + strict=True, + ): + sub_objective = importer.create_objective(edatas=[edata]) + pm_sub = sub_objective.parameter_mapping.parameter_mappings[0] + for var, val in vars(pm_full).items(): + assert getattr(pm_sub, var) == val, var + + sub_result = sub_objective(pars, return_dict=True) + assert sub_result[C.FVAL] == sum( + -rdata["llh"] + for rdata in full_result[C.RDATAS] + if rdata.id == sub_result[C.RDATAS][0].id + )