diff --git a/diffmpm/material.py b/diffmpm/material.py index 2dd8487..ca43252 100644 --- a/diffmpm/material.py +++ b/diffmpm/material.py @@ -1,6 +1,7 @@ from jax.tree_util import register_pytree_node_class import abc import jax.numpy as jnp +from jax import vmap, lax, jit class Material(abc.ABC): @@ -133,6 +134,186 @@ def compute_stress(self, dstrain): return dstrain * self.properties["E"] +@register_pytree_node_class +class Bingham(Material): + _props = ( + "density", + "youngs_modulus", + "poisson_ratio", + "tau0", + "mu", + "critical_shear_rate", + "ndim", + ) + + def __init__(self, material_properties): + """ + Create a Bingham material model. + + Arguments + --------- + material_properties: dict + Dictionary with material properties. For Bingham + material, 'density','youngs_modulus','poisson_ratio','tau0','mu', + 'ndim and 'critical_shear_rate' are required keys. + + Methods + ------- + initialise_state_variables + Initialises the state variables for the Bingham material + compute_stress + computes the stress for the Bingham material particles + + """ + self.validate_props(material_properties) + self.ndim = material_properties["ndim"] + youngs_modulus = material_properties["youngs_modulus"] + poisson_ratio = material_properties["poisson_ratio"] + self.state_variables=["pressure"] + # Calculate the bulk modulus + bulk_modulus = youngs_modulus / (3.0 * (1.0 - 2.0 * poisson_ratio)) + compressibility_multiplier_ = 1.0 + # Special Material Properties + if material_properties.get("incompressible", False): + compressibility_multiplier_ = 0.0 + self.properties = { + **material_properties, + "bulk_modulus": bulk_modulus, + "compressibility_multiplier": compressibility_multiplier_, + } + + def __repr__(self): + return f"Bingham(props={self.properties})" + + # Initialise history variables + def initialise_state_variables(particles): + state_vars = {} + state_vars["pressure"] = jnp.zeros((particles.loc.shape[0])) + return state_vars + + # Compute the pressure + def __thermodynamic_pressure(self, volumetric_strain): + return -self.properties["bulk_modulus"] * volumetric_strain + + # Compute the stress + def compute_stress(self, dstrain, particles, state_vars:dict): + """ + Computes the stress for the Bingham material. + + Parameters + ---------- + dstrain: array_like + The strain rate tensor for the particles + particles: diffmpm.particles.Particles + state_vars: dict {str: jnp.ndarray} + dictionary containig the string as the name of the + property and the jnp.ndarray shape (nparticles, 1) as the + values of the property at each particle. + + Returns + ------- + updated_stress: jnp.ndarray + The updated stress for the particles expected shape (nparticles, 6,1) + """ + shear_rate_threshold = 1e-15 + # dirac delta in Voigt notation + dirac_delta = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape((6, 1)) + dirac_delta = lax.cond( + self.ndim == 1, + lambda x: x.at[0, 0].set(1.0), + lambda x: x.at[0:2, 0].set(1.0), + dirac_delta, + ) + # Set threshold for minimum critical shear rate + self.properties["critical_shear_rate"] = lax.select( + self.properties["critical_shear_rate"] < shear_rate_threshold, + shear_rate_threshold, + self.properties["critical_shear_rate"], + ) + + @jit + def __compute_stress_per_particle( + particle_strain_rate, + self, + state_vars_pressure, + dvolumetric_strain_per_particle, + dirac_delta, + ): + strain_r = particle_strain_rate + # Convert strain rate to rate of deformation tensor + strain_r = strain_r.at[-3:].multiply(0.5) + + # Rate of shear = sqrt(2 * Strain_r_ij * Strain_r_ij) + # Strain_r is in Voigt notation so the above formula reduces in the voigt notation to + # sqrt(Strain_r_0^2 + Strain_r_1^2 + Strain_r_2^2 + 2*Strain_r_3^2 + # + 2*Strain_r_4^2 + 2*Strain_r_5^2) + # When shear rate> critical_shear_rate^2 then the material is yielding + + shear_rate = jnp.sqrt( + 2.0 * (strain_r.T @ (strain_r) + strain_r[-3:].T @ strain_r[-3:]) + ).squeeze() + + # Apparent_viscosity maps shear rate to shear stress + # Check if shear rate is 0 + + apparent_viscosity_true = 2.0 * ( + (self.properties["tau0"] / shear_rate) + self.properties["mu"] + ) + condition = (shear_rate * shear_rate) > ( + self.properties["critical_shear_rate"] + * self.properties["critical_shear_rate"] + ) + apparent_viscosity = lax.select(condition, apparent_viscosity_true, 0.0) + + # Compute volumetric tau + + tau = apparent_viscosity * strain_r + # von Mises criterion + # yield condition trace of the invariant > tau0^2 + # and trace can be found using the first 3 numbers of tau + # as tau is in voigt notation + + trace_invariant = 0.5 * jnp.dot(tau[:3, 0], tau[:3, 0]) + tau = lax.cond( + trace_invariant < (self.properties["tau0"] * self.properties["tau0"]), + lambda x: x.at[:].set(0), + lambda x: x, + tau, + ) + # update pressure + state_vars_pressure += self.properties[ + "compressibility_multiplier" + ] * self.__thermodynamic_pressure(dvolumetric_strain_per_particle) + + # Update volumetric and deviatoric stress + # thermodynamic pressure is from material point + # stress = -thermodynamic_pressure I + tau, where I is identity matrix or + # direc_delta in Voigt notation + + updated_stress_per_particle = ( + -(state_vars_pressure) + * dirac_delta + * self.properties["compressibility_multiplier"] + + tau + ) + return updated_stress_per_particle, state_vars_pressure + + # using vmap to vectorise the function compute stress per particle + # for all the particles using the first dimension of the strain rate + # and the dvolumetric_strain and state_vars pressure + updated_stress, state_vars["pressure"] = vmap( + __compute_stress_per_particle, in_axes=(0, None, 0, 0, None) + )( + particles.strain_rate, + self, + state_vars["pressure"], + particles.dvolumetric_strain, + dirac_delta, + ) + + return updated_stress + + if __name__ == "__main__": from diffmpm.utils import _show_example diff --git a/tests/test_bingham.py b/tests/test_bingham.py new file mode 100644 index 0000000..0fb9e42 --- /dev/null +++ b/tests/test_bingham.py @@ -0,0 +1,206 @@ +import pytest +import jax.numpy as jnp +from diffmpm.material import Bingham +from diffmpm.particle import Particles +from diffmpm.element import Quadrilateral4Node +from diffmpm.constraint import Constraint +from diffmpm.node import Nodes + +particles_element_targets = [ + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + ( + Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 771.8, + "mu": 0.0451, + "critical_shear_rate": 0.2, + "ndim": 2, + } + ) + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape((6, 1)), + ), + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + ( + Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 771.8, + "mu": 0.0451, + "critical_shear_rate": 0.2, + "ndim": 2, + } + ) + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array([-52083.3333333333, -52083.3333333333, 0.0, 0.0, 0.0, 0.0]).reshape( + (6, 1) + ), + ), + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + ( + Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 200.0, + "mu": 200.0, + "critical_shear_rate": 0.2, + "ndim": 2, + } + ) + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 2.0)), (0, Constraint(1, 3.0))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [-5208520.35574006, -5208613.86694342, 0.0, -233.778008402801, 0.0, 0.0] + ).reshape((6, 1)), + ), + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + ( + Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 200.0, + "mu": 200.0, + "critical_shear_rate": 0.2, + "ndim": 2, + "incompressible": True, + } + ) + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 2.0)), (0, Constraint(1, 3.0))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [-187.0224067222, -280.5336100834, 0.0, -233.778008402801, 0.0, 0.0] + ).reshape((6, 1)), + ), +] + +@pytest.mark.parametrize( + "particles, element, target", + particles_element_targets, +) +def test_compute_stress(particles, element, target): + particles.update_natural_coords(element) + if element.constraints: + element.apply_boundary_constraints() + particles.compute_strain(element, 1.0) + stress = particles.material.compute_stress(None, particles, {"pressure": jnp.zeros(1)}) + assert jnp.allclose(stress, target) + + +def test_key_not_present_in_material_properties(): + with pytest.raises(KeyError): + material = Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 771.8, + "critical_shear_rate": 0.2, + "ndim": 2, + } + ) + + +@pytest.mark.parametrize( + "particles, element, target, state_vars", + [ + (Particles( + jnp.array([[0.5, 0.5, 0.5, 0.5]]).reshape(2, 1, 2), + ( + Bingham( + { + "density": 1000, + "youngs_modulus": 1.0e7, + "poisson_ratio": 0.3, + "tau0": 200.0, + "mu": 200.0, + "critical_shear_rate": 0.2, + "ndim": 2, + } + ) + ), + jnp.array([0,0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 2.0)), (0, Constraint(1, 3.0))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [ + -5208520.35574006, + -5208613.86694342, + 0.0, + -233.778008402801, + 0.0, + 0.0, + -5208520.35574006, + -5208613.86694342, + 0.0, + -233.778008402801, + 0.0, + 0.0, + ] + ).reshape((2, 6, 1)), + {"pressure": jnp.zeros((2, 1))}), + ], +) +def test_compute_stress_two_particles(particles, state_vars, element, target): + particles.update_natural_coords(element) + if element.constraints: + element.apply_boundary_constraints() + particles.compute_strain(element, 1.0) + stress = particles.material.compute_stress(None, particles, state_vars) + assert jnp.allclose(stress, target)