Skip to content

Commit 0f2e190

Browse files
committed
Pass particles as arg to compute stress
1 parent 0dbf19f commit 0f2e190

File tree

5 files changed

+46
-14
lines changed

5 files changed

+46
-14
lines changed

diffmpm/materials/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class _Material(abc.ABC):
66
"""Base material class."""
77

88
_props: Tuple[str, ...]
9+
properties: dict
910

1011
def __init__(self, material_properties):
1112
"""Initialize material properties.
@@ -35,7 +36,7 @@ def __repr__(self):
3536
...
3637

3738
@abc.abstractmethod
38-
def compute_stress(self):
39+
def compute_stress(self, particles):
3940
"""Compute stress for the material."""
4041
...
4142

diffmpm/materials/linear_elastic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class LinearElastic(_Material):
99
"""Linear Elastic Material."""
1010

1111
_props = ("density", "youngs_modulus", "poisson_ratio")
12+
state_vars = ()
1213

1314
def __init__(self, material_properties):
1415
"""Create a Linear Elastic material.
@@ -63,7 +64,7 @@ def _compute_elastic_tensor(self):
6364
]
6465
)
6566

66-
def compute_stress(self, dstrain):
67+
def compute_stress(self, particles):
6768
"""Compute material stress."""
68-
dstress = self.de @ dstrain
69+
dstress = self.de @ particles.dstrain
6970
return dstress

diffmpm/materials/simple.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
@register_pytree_node_class
77
class SimpleMaterial(_Material):
88
_props = ("E", "density")
9+
state_vars = ()
910

1011
def __init__(self, material_properties):
1112
self.validate_props(material_properties)
@@ -14,5 +15,5 @@ def __init__(self, material_properties):
1415
def __repr__(self):
1516
return f"SimpleMaterial(props={self.properties})"
1617

17-
def compute_stress(self, dstrain):
18-
return dstrain * self.properties["E"]
18+
def compute_stress(self, particles):
19+
return particles.dstrain * self.properties["E"]

diffmpm/particle.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def __init__(
6969
self.reference_loc = jnp.zeros_like(self.loc)
7070
self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1))
7171
self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1))
72+
self.state_vars = {}
73+
if self.material.state_vars:
74+
self.state_vars = self.material.initialize_state_variables(
75+
self.loc.shape[0]
76+
)
7277
else:
7378
(
7479
self.mass,
@@ -87,6 +92,7 @@ def __init__(
8792
self.reference_loc,
8893
self.dvolumetric_strain,
8994
self.volumetric_strain_centroid,
95+
self.state_vars,
9096
) = data # type: ignore
9197
self.initialized = True
9298

@@ -112,6 +118,7 @@ def tree_flatten(self):
112118
self.reference_loc,
113119
self.dvolumetric_strain,
114120
self.volumetric_strain_centroid,
121+
self.state_vars,
115122
)
116123
aux_data = (self.material,)
117124
return (children, aux_data)
@@ -319,7 +326,7 @@ def compute_stress(self, *args):
319326
particles. The stress calculated by the material is then
320327
added to the particles current stress values.
321328
"""
322-
self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain))
329+
self.stress = self.stress.at[:].add(self.material.compute_stress(self))
323330

324331
def update_volume(self, *args):
325332
"""Update volume based on central strain rate."""

tests/test_material.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,48 @@
11
import jax.numpy as jnp
22
import pytest
33
from diffmpm.materials import LinearElastic, SimpleMaterial
4+
from diffmpm.particle import Particles
45

5-
material_dstrain_stress_targets = [
6+
particles_dstrain_stress_targets = [
67
(
7-
SimpleMaterial({"E": 10, "density": 1}),
8+
Particles(
9+
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
10+
SimpleMaterial({"E": 10, "density": 1}),
11+
jnp.array([0]),
12+
),
813
jnp.ones((1, 6, 1)),
914
jnp.ones((1, 6, 1)) * 10,
1015
),
1116
(
12-
LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}),
17+
Particles(
18+
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
19+
LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}),
20+
jnp.array([0]),
21+
),
1322
jnp.ones((1, 6, 1)),
1423
jnp.array([-10, -10, -10, 2.5, 2.5, 2.5]).reshape(1, 6, 1),
1524
),
1625
(
17-
LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}),
26+
Particles(
27+
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
28+
LinearElastic(
29+
{"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}
30+
),
31+
jnp.array([0]),
32+
),
1833
jnp.array([0.001, 0.0005, 0, 0, 0, 0]).reshape(1, 6, 1),
1934
jnp.array([1.63461538461538e4, 12500, 0.86538461538462e4, 0, 0, 0]).reshape(
2035
1, 6, 1
2136
),
2237
),
2338
(
24-
LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}),
39+
Particles(
40+
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
41+
LinearElastic(
42+
{"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}
43+
),
44+
jnp.array([0]),
45+
),
2546
jnp.array([0.001, 0.0005, 0, 0.00001, 0, 0]).reshape(1, 6, 1),
2647
jnp.array(
2748
[1.63461538461538e4, 12500, 0.86538461538462e4, 3.84615384615385e01, 0, 0]
@@ -30,7 +51,8 @@
3051
]
3152

3253

33-
@pytest.mark.parametrize("material, dstrain, target", material_dstrain_stress_targets)
34-
def test_compute_stress(material, dstrain, target):
35-
stress = material.compute_stress(dstrain)
54+
@pytest.mark.parametrize("particles, dstrain, target", particles_dstrain_stress_targets)
55+
def test_compute_stress(particles, dstrain, target):
56+
particles.dstrain = dstrain
57+
stress = particles.material.compute_stress(particles)
3658
assert jnp.allclose(stress, target)

0 commit comments

Comments
 (0)