diff --git a/cvxtorch/torch_numerics/elementwise/trig.py b/cvxtorch/torch_numerics/elementwise/trig.py new file mode 100644 index 0000000..b5b87c3 --- /dev/null +++ b/cvxtorch/torch_numerics/elementwise/trig.py @@ -0,0 +1,13 @@ +import torch +from cvxpy.expressions.expression import Expression + + +class sin: + @staticmethod + def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor: + return torch.sin(values[0]) + +class cos: + @staticmethod + def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor: + return torch.cos(values[0]) diff --git a/cvxtorch/utils/exp2tch.py b/cvxtorch/utils/exp2tch.py index 0d72045..85e4acd 100644 --- a/cvxtorch/utils/exp2tch.py +++ b/cvxtorch/utils/exp2tch.py @@ -160,4 +160,18 @@ minimum: minimum_tch.torch_numeric, power: power_tch.torch_numeric, xexp: xexp_tch.torch_numeric, -} \ No newline at end of file +} + +try: + # Atoms from cvxpy-ipopt + from cvxpy.atoms.elementwise.trig import cos, sin + + from cvxtorch.torch_numerics.elementwise.trig import cos as cos_tch, sin as sin_tch + EXPR2TORCH.extend({ + sin: sin_tch.torch_numeric, + cos: cos_tch.torch_numeric, + }) +except ImportError: + pass + +