Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,21 @@ def create_objective(
edatas = self.create_edatas(
model=model, simulation_conditions=simulation_conditions
)
else:
simulation_conditions = pd.DataFrame(
[
{
PREEQUILIBRATION_CONDITION_ID: edata.id.split("+")[0]
if "+" in edata.id
else "",
# why is this not CONDITION_SEP ¯\_(ツ)_/¯?
SIMULATION_CONDITION_ID: edata.id.split("+")[1]
if "+" in edata.id
else edata.id,
}
for edata in edatas
]
)

parameter_mapping = (
amici.petab.parameter_mapping.create_parameter_mapping(
Expand Down
35 changes: 35 additions & 0 deletions test/petab/test_amici_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,38 @@ def test_preeq_guesses():
# assert that resetting works
problem.objective.initialize()
assert obj.steadystate_guesses["fval"] == np.inf


def test_edatas():
"""
Test whether optimization with preequilibration guesses works, asserts
that steadystate guesses are written and checks that gradient is still
correct with guesses set.
"""
model_name = "Brannmark_JBC2010"
importer = pypesto.petab.PetabImporter.from_yaml(
os.path.join(models.MODELS_DIR, model_name, model_name + ".yaml")
)
pars = np.asarray(importer.petab_problem.x_nominal_free_scaled)

full_objective = importer.create_objective()

edatas = importer.create_edatas()

full_result = full_objective(pars, return_dict=True)
for pm_full, edata in zip(
full_objective.parameter_mapping.parameter_mappings,
edatas,
strict=True,
):
sub_objective = importer.create_objective(edatas=[edata])
pm_sub = sub_objective.parameter_mapping.parameter_mappings[0]
for var, val in vars(pm_full).items():
assert getattr(pm_sub, var) == val, var

sub_result = sub_objective(pars, return_dict=True)
assert sub_result[C.FVAL] == sum(
-rdata["llh"]
for rdata in full_result[C.RDATAS]
if rdata.id == sub_result[C.RDATAS][0].id
)