Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 23 additions & 83 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,8 +1512,6 @@ def process(
function evaluators created will have this base name
use_jacobian: bool, optional
whether to return Jacobian functions
return_jacp_stacked: bool, optional
returns Jacobian function wrt stacked parameters instead of jacp

Returns
-------
Expand All @@ -1522,28 +1520,18 @@ def process(
:class:`casadi.Function`
evaluator for the function $f(y, t, p)$ given by `symbol`

jac: :class:`pybamm.EvaluatorPython` or
:class:`pybamm.EvaluatorJaxJacobian` or
:class:`casadi.Function`
jac: same type as func
evaluator for the Jacobian $\\frac{\\partial f}{\\partial y}$
of the function given by `symbol`

jacp: :class:`pybamm.EvaluatorPython` or
:class:`pybamm.EvaluatorJaxSensitivities` or
:class:`casadi.Function`
evaluator for the parameter sensitivities
$\frac{\\partial f}{\\partial p}$
of the function given by `symbol`
jacp: same type as func
evaluator for parameter sensitivities
$\\frac{\\partial f}{\\partial p}$ (always stacked)

jac_action: :class:`pybamm.EvaluatorPython` or
:class:`pybamm.EvaluatorJax` or
:class:`casadi.Function`
evaluator for product of the Jacobian with a vector $v$,
i.e. $\\frac{\\partial f}{\\partial y} * v$
jac_action: same type as func
evaluator for Jacobian-vector product $\\frac{\\partial f}{\\partial y} * v$
"""

def report(string):
# don't log event conversion
if "event" not in string:
pybamm.logger.verbose(string)

Expand All @@ -1562,103 +1550,55 @@ def report(string):
f"to parameters {model.calculate_sensitivities} using jax"
)
jacp = func.get_sensitivities()
if use_jacobian:
report(f"Calculating jacobian for {name} using jax")
jac = func.get_jacobian()
jac_action = func.get_jacobian_action()
else:
jac = None
jac_action = None
jac = func.get_jacobian() if use_jacobian else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the single if statement is less complex then multiple, plus you are removing the report, so I'd disagree with this change

jac_action = func.get_jacobian_action() if use_jacobian else None

elif model.convert_to_format != "casadi":
y = vars_for_processing["y"]
jacobian = vars_for_processing["jacobian"]

if model.calculate_sensitivities:
raise pybamm.SolverError( # pragma: no cover
raise pybamm.SolverError(
"Sensitivies are no longer supported for the python "
"evaluator. Please use `convert_to_format = 'casadi'`, or `jax` "
"to calculate sensitivities."
)

else:
jacp = None

if use_jacobian:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same with this one

report(f"Calculating jacobian for {name}")
jac = jacobian.jac(symbol, y)
report(f"Converting jacobian for {name} to python")
jac = pybamm.EvaluatorPython(jac)
# cannot do jacobian action efficiently for now
jac_action = None
else:
jac = None
jac_action = None

report(f"Converting {name} to python")
jacp = None
jac = jacobian.jac(symbol, y) if use_jacobian else None
jac = pybamm.EvaluatorPython(jac) if jac else None
jac_action = None
func = pybamm.EvaluatorPython(symbol)
report(f"Converting {name} to python")

else:
t_casadi = vars_for_processing["t_casadi"]
y_casadi = vars_for_processing["y_casadi"]
p_casadi = vars_for_processing["p_casadi"]
p_casadi_stacked = vars_for_processing["p_casadi_stacked"]

# Process with CasADi
report(f"Converting {name} to CasADi")
casadi_expression = symbol.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
# Add sensitivity vectors to the rhs and algebraic equations
jacp = None

if model.calculate_sensitivities:
report(
f"Calculating sensitivities for {name} with respect "
f"to parameters {model.calculate_sensitivities} using "
"CasADi"
report(f"Calculating sensitivities for {name} using CasADi")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its useful to have the printout of model.calculate_sensitivities here, so keep this

jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing this if statement is only a small part of this change, you will also need to change the behaviour of the solvers that depend on this functionality (I believe, but I'll run the tests so we can check)

jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_casadi, p_casadi_stacked],
[
casadi.densify(
casadi.jacobian(casadi_expression, p_casadi[pname])
)
for pname in model.calculate_sensitivities
],
)

if use_jacobian:
report(f"Calculating jacobian for {name} using CasADi")
jac_casadi = casadi.jacobian(casadi_expression, y_casadi)
jac = casadi.Function(
name + "_jac",
[t_casadi, y_casadi, p_casadi_stacked],
[jac_casadi],
)

v = casadi.MX.sym(
"v",
model.len_rhs_and_alg,
)
jac_action_casadi = casadi.densify(
casadi.jtimes(casadi_expression, y_casadi, v)
f"{name}_jac", [t_casadi, y_casadi, p_casadi_stacked], [jac_casadi]
)
v = casadi.MX.sym("v", model.len_rhs_and_alg)
jac_action = casadi.Function(
name + "_jac_action",
f"{name}_jac_action",
[t_casadi, y_casadi, p_casadi_stacked, v],
[jac_action_casadi],
[casadi.densify(casadi.jtimes(casadi_expression, y_casadi, v))],
)
else:
jac = None
Expand Down
Loading