Skip to content

Commit 44291fa

Browse files
committed
Adds NLP atoms sin cos
1 parent bae2d64 commit 44291fa

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
from cvxpy.expressions.expression import Expression
3+
4+
5+
class sin:
6+
@staticmethod
7+
def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor:
8+
return torch.sin(values[0])
9+
10+
class cos:
11+
@staticmethod
12+
def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor:
13+
return torch.cos(values[0])

cvxtorch/utils/exp2tch.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,18 @@
160160
minimum: minimum_tch.torch_numeric,
161161
power: power_tch.torch_numeric,
162162
xexp: xexp_tch.torch_numeric,
163-
}
163+
}
164+
165+
try:
166+
# Atoms from cvxpy-ipopt
167+
from cvxpy.atoms.elementwise.trig import cos, sin
168+
169+
from cvxtorch.torch_numerics.elementwise.trig import cos as cos_tch, sin as sin_tch
170+
EXPR2TORCH.extend({
171+
sin: sin_tch.torch_numeric,
172+
cos: cos_tch.torch_numeric,
173+
})
174+
except ImportError:
175+
pass
176+
177+

0 commit comments

Comments
 (0)