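Currently, `StatefulTrainer` cannot use [GradientTransformationExtraArgs](https://optax.readthedocs.io/en/latest/api/transformations.html#optax.GradientTransformationExtraArgs), whose `update` method accepts additional keyword arguments (for example, the `value`, `grad`, and `value_fn` arguments that line-search optimizers such as `optax.lbfgs` require). It would be fairly easy to add support with an extra keyword argument whose contents are passed through to the optimizer's `update` call.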