7272def compute_loss (params , * , solver , target_vel , config ):
7373 # material = init_simple({"E": params, "density": 1, "id": -1})
7474 material = init_linear_elastic (
75- {"youngs_modulus" : params , "density" : 1 , "poisson_ratio" : 0 , "id" : - 1 }
75+ {
76+ "youngs_modulus" : params ["ym" ],
77+ "density" : 1 ,
78+ "poisson_ratio" : 0 ,
79+ "id" : - 1 ,
80+ }
7681 )
77- # breakpoint()
7882 particles_ = [
7983 init_particle_state (
8084 config .parsed_config ["particles" ][0 ].loc ,
@@ -102,11 +106,11 @@ def compute_loss(params, *, solver, target_vel, config):
102106
103107def optax_adam (params , niter , mpm , target_vel , config ):
104108 # Initialize parameters of the model + optimizer.
105- start_learning_rate = 1
109+ start_learning_rate = 4
106110 optimizer = optax .adam (start_learning_rate )
107111 opt_state = optimizer .init (params )
108112
109- param_list = []
113+ param_list = { "ym" : [], "pr" : []}
110114 loss_list = []
111115 # A simple update loop.
112116 t = tqdm (range (niter ), desc = f"E: { params } " )
@@ -115,16 +119,23 @@ def optax_adam(params, niter, mpm, target_vel, config):
115119 lo , grads = jax .value_and_grad (partial_f , argnums = 0 )(params )
116120 updates , opt_state = optimizer .update (grads , opt_state )
117121 params = optax .apply_updates (params , updates )
118- t .set_description (f"YM: { params } " )
119- param_list .append (params )
122+ t .set_description (f"YM: { params ['ym' ]:.2f} " )
123+ param_list ["ym" ].append (params ["ym" ])
124+ # param_list["pr"].append(params["pr"])
120125 loss_list .append (lo )
121126 return param_list , loss_list
122127
123128
124- params = 900.5
129+ # params = {"pr": 0.4}
130+ params = {"ym" : 1101.0 }
125131# material = init_simple({"E": params, "density": 1, "id": -1})
126132material = init_linear_elastic (
127- {"youngs_modulus" : params , "density" : 1 , "poisson_ratio" : 0 , "id" : - 1 }
133+ {
134+ "youngs_modulus" : params ["ym" ],
135+ "density" : 1 ,
136+ "poisson_ratio" : 0 ,
137+ "id" : - 1 ,
138+ }
128139)
129140particles = [
130141 init_particle_state (
@@ -142,11 +153,11 @@ def optax_adam(params, niter, mpm, target_vel, config):
142153 }
143154)
144155param_list , loss_list = optax_adam (
145- params , 100 , solver , true_vel , config
156+ params , 200 , solver , true_vel , config
146157) # ADAM optimizer
147158
148159fig , ax = plt .subplots (1 , 2 , figsize = (16 , 6 ))
149- ax [0 ].plot (param_list , "ko" , markersize = 2 , label = "E" )
160+ ax [0 ].plot (param_list [ "ym" ] , "ko" , markersize = 2 , label = "E" )
150161ax [0 ].grid ()
151162ax [0 ].legend ()
152163ax [1 ].plot (loss_list , "ko" , markersize = 2 , label = "Loss" )
0 commit comments