@@ -166,13 +166,13 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
166166 else :
167167 pybamm .logger .info ("Start solver set-up" )
168168
169- self ._check_and_prepare_model_inplace (model , inputs , ics_only )
169+ self ._check_and_prepare_model_inplace (model )
170170
171171 # set default calculate sensitivities on model
172172 if not hasattr (model , "calculate_sensitivities" ):
173173 model .calculate_sensitivities = []
174174
175- self ._set_up_model_sensitivities_inplace (model , inputs )
175+ self ._set_up_model_sensitivities_inplace (model )
176176
177177 vars_for_processing = self ._get_vars_for_processing (model , inputs )
178178
@@ -369,7 +369,7 @@ def _wrangle_name(cls, name: str) -> str:
369369 name = name .replace (string , replacement )
370370 return name
371371
372- def _check_and_prepare_model_inplace (self , model , inputs , ics_only ):
372+ def _check_and_prepare_model_inplace (self , model ):
373373 """
374374 Performs checks on the model and prepares it for solving.
375375 """
@@ -461,7 +461,7 @@ def _get_vars_for_processing(model, inputs):
461461 return vars_for_processing
462462
463463 @staticmethod
464- def _set_up_model_sensitivities_inplace (model , inputs ):
464+ def _set_up_model_sensitivities_inplace (model ):
465465 """
466466 Set up model attributes related to sensitivities.
467467 """
@@ -826,14 +826,9 @@ def solve(
826826 t_interp = self .process_t_interp (t_interp )
827827
828828 # Set up inputs
829- #
830- # Argument "inputs" can be either a list of input dicts or
831- # a single dict. The remaining of this function is only working
832- # with variable "input_list", which is a list of dictionaries.
833- # If "inputs" is a single dict, "inputs_list" is a list of only one dict.
834- inputs_list = inputs if isinstance (inputs , list ) else [inputs ]
835- model_inputs_list = [
836- self ._set_up_model_inputs (model , inputs ) for inputs in inputs_list
829+ model_inputs_list : list [dict ] = [
830+ self ._set_up_model_inputs (model , inputs )
831+ for inputs in (inputs if isinstance (inputs , list ) else [inputs ])
837832 ]
838833
839834 calculate_sensitivities_list , sensitivities_have_changed = (
@@ -848,18 +843,13 @@ def solve(
848843 # is passed to `_set_consistent_initialization`.
849844 # See https://github.com/pybamm-team/PyBaMM/pull/1261
850845 if len (model_inputs_list ) > 1 :
851- all_inputs_names = set (
852- itertools .chain .from_iterable (
853- [model_inputs .keys () for model_inputs in model_inputs_list ]
854- )
855- )
846+ all_inputs_names = {
847+ key for model_inputs in model_inputs_list for key in model_inputs
848+ }
856849 if all_inputs_names :
857- initial_conditions_node_names = set (
858- [
859- it .name
860- for it in model .concatenated_initial_conditions .pre_order ()
861- ]
862- )
850+ initial_conditions_node_names = {
851+ it .name for it in model .concatenated_initial_conditions .pre_order ()
852+ }
863853 if not initial_conditions_node_names .isdisjoint (all_inputs_names ):
864854 raise pybamm .SolverError (
865855 "Input parameters cannot appear in expression "
@@ -910,9 +900,9 @@ def solve(
910900 # If the new initial conditions are different
911901 # and cannot be evaluated directly, set up again
912902 self .set_up (model , model_inputs_list [0 ], t_eval , ics_only = True )
913- self ._model_set_up [model ]["initial conditions" ] = (
914- model . concatenated_initial_conditions
915- )
903+ self ._model_set_up [model ][
904+ "initial conditions"
905+ ] = model . concatenated_initial_conditions
916906 else :
917907 # Set the standard initial conditions
918908 self ._set_initial_conditions (model , t_eval [0 ], model_inputs_list [0 ])
0 commit comments