Skip to content

Commit 03f112c

Browse files
committed
Update optim file
1 parent 1034d4f commit 03f112c

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

optim_benchmark.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@
7272
def compute_loss(params, *, solver, target_vel, config):
7373
# material = init_simple({"E": params, "density": 1, "id": -1})
7474
material = init_linear_elastic(
75-
{"youngs_modulus": params, "density": 1, "poisson_ratio": 0, "id": -1}
75+
{
76+
"youngs_modulus": params["ym"],
77+
"density": 1,
78+
"poisson_ratio": 0,
79+
"id": -1,
80+
}
7681
)
77-
# breakpoint()
7882
particles_ = [
7983
init_particle_state(
8084
config.parsed_config["particles"][0].loc,
@@ -102,11 +106,11 @@ def compute_loss(params, *, solver, target_vel, config):
102106

103107
def optax_adam(params, niter, mpm, target_vel, config):
104108
# Initialize parameters of the model + optimizer.
105-
start_learning_rate = 1
109+
start_learning_rate = 4
106110
optimizer = optax.adam(start_learning_rate)
107111
opt_state = optimizer.init(params)
108112

109-
param_list = []
113+
param_list = {"ym": [], "pr": []}
110114
loss_list = []
111115
# A simple update loop.
112116
t = tqdm(range(niter), desc=f"E: {params}")
@@ -115,16 +119,23 @@ def optax_adam(params, niter, mpm, target_vel, config):
115119
lo, grads = jax.value_and_grad(partial_f, argnums=0)(params)
116120
updates, opt_state = optimizer.update(grads, opt_state)
117121
params = optax.apply_updates(params, updates)
118-
t.set_description(f"YM: {params}")
119-
param_list.append(params)
122+
t.set_description(f"YM: {params['ym']:.2f}")
123+
param_list["ym"].append(params["ym"])
124+
# param_list["pr"].append(params["pr"])
120125
loss_list.append(lo)
121126
return param_list, loss_list
122127

123128

124-
params = 900.5
129+
# params = {"pr": 0.4}
130+
params = {"ym": 1101.0}
125131
# material = init_simple({"E": params, "density": 1, "id": -1})
126132
material = init_linear_elastic(
127-
{"youngs_modulus": params, "density": 1, "poisson_ratio": 0, "id": -1}
133+
{
134+
"youngs_modulus": params["ym"],
135+
"density": 1,
136+
"poisson_ratio": 0,
137+
"id": -1,
138+
}
128139
)
129140
particles = [
130141
init_particle_state(
@@ -142,11 +153,11 @@ def optax_adam(params, niter, mpm, target_vel, config):
142153
}
143154
)
144155
param_list, loss_list = optax_adam(
145-
params, 100, solver, true_vel, config
156+
params, 200, solver, true_vel, config
146157
) # ADAM optimizer
147158

148159
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
149-
ax[0].plot(param_list, "ko", markersize=2, label="E")
160+
ax[0].plot(param_list["ym"], "ko", markersize=2, label="E")
150161
ax[0].grid()
151162
ax[0].legend()
152163
ax[1].plot(loss_list, "ko", markersize=2, label="Loss")

0 commit comments

Comments
 (0)