@@ -166,13 +166,13 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
166166        else :
167167            pybamm .logger .info ("Start solver set-up" )
168168
169-         self ._check_and_prepare_model_inplace (model ,  inputs ,  ics_only )
169+         self ._check_and_prepare_model_inplace (model )
170170
171171        # set default calculate sensitivities on model 
172172        if  not  hasattr (model , "calculate_sensitivities" ):
173173            model .calculate_sensitivities  =  []
174174
175-         self ._set_up_model_sensitivities_inplace (model ,  inputs )
175+         self ._set_up_model_sensitivities_inplace (model )
176176
177177        vars_for_processing  =  self ._get_vars_for_processing (model , inputs )
178178
@@ -369,7 +369,7 @@ def _wrangle_name(cls, name: str) -> str:
369369            name  =  name .replace (string , replacement )
370370        return  name 
371371
372-     def  _check_and_prepare_model_inplace (self , model ,  inputs ,  ics_only ):
372+     def  _check_and_prepare_model_inplace (self , model ):
373373        """ 
374374        Performs checks on the model and prepares it for solving. 
375375        """ 
@@ -461,7 +461,7 @@ def _get_vars_for_processing(model, inputs):
461461            return  vars_for_processing 
462462
463463    @staticmethod  
464-     def  _set_up_model_sensitivities_inplace (model ,  inputs ):
464+     def  _set_up_model_sensitivities_inplace (model ):
465465        """ 
466466        Set up model attributes related to sensitivities. 
467467        """ 
@@ -826,14 +826,9 @@ def solve(
826826        t_interp  =  self .process_t_interp (t_interp )
827827
828828        # Set up inputs 
829-         # 
830-         # Argument "inputs" can be either a list of input dicts or 
831-         # a single dict. The remaining of this function is only working 
832-         # with variable "input_list", which is a list of dictionaries. 
833-         # If "inputs" is a single dict, "inputs_list" is a list of only one dict. 
834-         inputs_list  =  inputs  if  isinstance (inputs , list ) else  [inputs ]
835-         model_inputs_list  =  [
836-             self ._set_up_model_inputs (model , inputs ) for  inputs  in  inputs_list 
829+         model_inputs_list : list [dict ] =  [
830+             self ._set_up_model_inputs (model , inputs )
831+             for  inputs  in  (inputs  if  isinstance (inputs , list ) else  [inputs ])
837832        ]
838833
839834        calculate_sensitivities_list , sensitivities_have_changed  =  (
@@ -848,18 +843,13 @@ def solve(
848843        # is passed to `_set_consistent_initialization`. 
849844        # See https://github.com/pybamm-team/PyBaMM/pull/1261 
850845        if  len (model_inputs_list ) >  1 :
851-             all_inputs_names  =  set (
852-                 itertools .chain .from_iterable (
853-                     [model_inputs .keys () for  model_inputs  in  model_inputs_list ]
854-                 )
855-             )
846+             all_inputs_names  =  {
847+                 key  for  model_inputs  in  model_inputs_list  for  key  in  model_inputs 
848+             }
856849            if  all_inputs_names :
857-                 initial_conditions_node_names  =  set (
858-                     [
859-                         it .name 
860-                         for  it  in  model .concatenated_initial_conditions .pre_order ()
861-                     ]
862-                 )
850+                 initial_conditions_node_names  =  {
851+                     it .name  for  it  in  model .concatenated_initial_conditions .pre_order ()
852+                 }
863853                if  not  initial_conditions_node_names .isdisjoint (all_inputs_names ):
864854                    raise  pybamm .SolverError (
865855                        "Input parameters cannot appear in expression " 
@@ -910,9 +900,9 @@ def solve(
910900                # If the new initial conditions are different 
911901                # and cannot be evaluated directly, set up again 
912902                self .set_up (model , model_inputs_list [0 ], t_eval , ics_only = True )
913-             self ._model_set_up [model ]["initial conditions" ]  =  ( 
914-                 model . concatenated_initial_conditions 
915-             ) 
903+             self ._model_set_up [model ][
904+                 "initial conditions" 
905+             ]  =   model . concatenated_initial_conditions 
916906        else :
917907            # Set the standard initial conditions 
918908            self ._set_initial_conditions (model , t_eval [0 ], model_inputs_list [0 ])
0 commit comments