Skip to content

Commit f28c027

Browse files
Merge pull request #761 from JuliaSymbolics/as/const-variant
[WIP] feat: add `Const` variant
2 parents 76fc4e2 + 0476657 commit f28c027

21 files changed

+564
-360
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ Base.@deprecate istree iscall
106106

107107
include("small_array.jl")
108108

109-
export istree, operation, arguments, sorted_arguments, iscall
109+
export istree, operation, arguments, sorted_arguments, iscall, unwrap_const
110110
# Sym, Term,
111111
# Add, Mul and Pow
112112
include("types.jl")

src/code.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
1010
import ..SymbolicUtils
1111
import ..SymbolicUtils.Rewriters
1212
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
13-
symtype, sorted_arguments, metadata, isterm, term, maketerm, Symbolic
13+
symtype, sorted_arguments, metadata, isterm, term, maketerm, unwrap_const,
14+
ArgsT, maybe_const
1415
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1516

1617
##== state management ==##
@@ -142,17 +143,20 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
142143
end
143144

144145
function function_to_expr(op::typeof(^), O, st)
145-
args = arguments(O)
146-
if args[2] isa Real && args[2] < 0
147-
args[1] = Term(inv, Any[args[1]])
148-
args[2] = -args[2]
149-
end
150-
if isequal(args[2], 1)
151-
return toexpr(args[1], st)
146+
base, exp = arguments(O)
147+
base = unwrap_const(base)
148+
exp = unwrap_const(exp)
149+
if exp isa Real && exp < 0
150+
base = Term(inv, ArgsT((base,)))
151+
if isone(-exp)
152+
return toexpr(base, st)
153+
else
154+
exp = -exp
155+
end
152156
end
153-
if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer)
157+
if get(st.rewrites, :nanmath, false) === true && !(exp isa Integer)
154158
op = NaNMath.pow
155-
return toexpr(Term(op, args), st)
159+
return toexpr(Term(op, ArgsT((maybe_const(base), maybe_const(exp)))), st)
156160
end
157161
return nothing
158162
end
@@ -203,11 +207,11 @@ end
203207
_is_tuple_of_symbolics(O) = false
204208

205209
function toexpr(O, st)
210+
O = unwrap_const(O)
211+
O = substitute_name(O, st)
206212
if issym(O)
207-
O = substitute_name(O, st)
208-
return issym(O) ? nameof(O) : toexpr(O, st)
213+
return nameof(O)
209214
end
210-
O = substitute_name(O, st)
211215

212216
if _is_array_of_symbolics(O)
213217
return issparse(O) ? toexpr(MakeSparseArray(O)) : toexpr(MakeArray(O, typeof(O)), st)
@@ -908,6 +912,7 @@ function cse!(expr::BasicSymbolic, state::CSEState)
908912
args = arguments(expr)
909913
cse_inside_expr(expr, op, args...) || return expr
910914
args = map(args) do arg
915+
arg = unwrap_const(arg)
911916
if arg isa Union{Tuple, AbstractArray} &&
912917
(_is_array_of_symbolics(arg) || _is_tuple_of_symbolics(arg))
913918
if arg isa Tuple

src/inspect.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import AbstractTrees
22

33
const inspect_metadata = Ref{Bool}(false)
4-
function AbstractTrees.nodevalue(x::Symbolic)
5-
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
6-
end
74

85
function AbstractTrees.nodevalue(x::BSImpl.Type)
96
T = nameof(MData.variant_type(x))
@@ -35,7 +32,7 @@ the expression.
3532
3633
This function is used internally for printing via AbstractTrees.
3734
"""
38-
function AbstractTrees.children(x::Symbolic)
35+
function AbstractTrees.children(x::BasicSymbolic)
3936
iscall(x) ? sorted_arguments(x) : isexpr(x) ? sorted_children(x) : ()
4037
end
4138

@@ -50,7 +47,7 @@ Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expres
5047
"""
5148
function inspect end
5249

53-
function inspect(io::IO, x::Symbolic;
50+
function inspect(io::IO, x::BasicSymbolic;
5451
hint=true,
5552
metadata=inspect_metadata[])
5653

src/matchers.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
#
88

99
function matcher(val::Any, acSets)
10+
val = unwrap_const(val)
1011
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
1112
if iscall(val)
1213
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
1314
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
14-
if any(x -> isa(x, DefSlot), parent(arguments(val)))
15+
if any(x -> isa(unwrap_const(x), DefSlot), parent(arguments(val)))
1516
return defslot_term_matcher_constructor(val, acSets)
1617
end
1718
# else return a normal term matcher
@@ -20,7 +21,7 @@ function matcher(val::Any, acSets)
2021

2122
function literal_matcher(next, data, bindings)
2223
# car data is the first element of data
23-
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
24+
islist(data) && isequal(unwrap_const(car(data)), val) ? next(bindings, 1) : nothing
2425
end
2526
end
2627

@@ -35,7 +36,7 @@ function matcher(slot::Slot, acSets)
3536
return next(bindings, 1)
3637
end
3738
# elseif the first element of data matches the slot predicate, add it to bindings and call next
38-
elseif slot.predicate(car(data))
39+
elseif slot.predicate(unwrap_const(car(data)))
3940
rest = car(data)
4041
binds = assoc(bindings, slot.name, rest)
4142
next(binds, 1)
@@ -93,7 +94,7 @@ function matcher(segment::Segment, acSets)
9394
for i=length(data):-1:0
9495
subexpr = take_n(data, i)
9596

96-
!segment.predicate(subexpr) && continue
97+
!segment.predicate(unwrap_const(subexpr)) && continue
9798
res = success(assoc(bindings, segment.name, subexpr), i)
9899
res !== nothing && break
99100
end
@@ -104,7 +105,7 @@ function matcher(segment::Segment, acSets)
104105
end
105106

106107
function term_matcher_constructor(term, acSets)
107-
matchers = vcat([matcher(operation(term), acSets)], map(x -> matcher(x, acSets), parent(arguments(term))))
108+
matchers = vcat([matcher(operation(term), acSets)], map(x -> matcher(unwrap_const(x), acSets), parent(arguments(term))))
108109

109110
function loop(term, bindings′, matchers′) # Get it to compile faster
110111
if !islist(matchers′)
@@ -262,11 +263,11 @@ end
262263
# in the normal_matcher and in defslot_matcher and other_part_matcher
263264
function defslot_term_matcher_constructor(term, acSets)
264265
a = parent(arguments(term))
265-
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
266-
defslot = a[defslot_index]
266+
defslot_index = findfirst(x -> isa(unwrap_const(x), DefSlot), a) # find the defslot in the term
267+
defslot = unwrap_const(a[defslot_index])
267268
defslot_matcher = matcher(defslot, acSets)
268269
if length(a) == 2
269-
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
270+
other_part_matcher = matcher(unwrap_const(a[defslot_index == 1 ? 2 : 1]), acSets)
270271
else
271272
# if we hare here the operation is a multiplication or sum of n>2 terms
272273
# (because ^ cannot have more than 2 terms).

src/methods.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,17 @@ promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
134134

135135
error_f_symbolic(f, T) = error("$f is not defined for $T.")
136136

137-
function Base.rem2pi(x::Symbolic, mode::Base.RoundingMode)
137+
function Base.rem2pi(x::BasicSymbolic, mode::Base.RoundingMode)
138138
T = symtype(x)
139139
T <: Number ? term(rem2pi, x, mode) : error_f_symbolic(rem2pi, T)
140140
end
141141

142142
# Specially handle inv and literal pow
143-
function Base.inv(x::Symbolic)
143+
function Base.inv(x::BasicSymbolic)
144144
T = symtype(x)
145145
T <: Number ? Base.:^(x, -1) : error_f_symbolic(rem2pi, T)
146146
end
147-
function Base.literal_pow(::typeof(^), x::Symbolic, ::Val{p}) where {p}
147+
function Base.literal_pow(::typeof(^), x::BasicSymbolic, ::Val{p}) where {p}
148148
T = symtype(x)
149149
T <: Number ? Base.:^(x, p) : error_f_symbolic(^, T)
150150
end
@@ -159,20 +159,20 @@ for f in monadic
159159
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralRealImpl}) = LiteralReal
160160
end
161161

162-
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
163-
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
162+
Base.:*(a::AbstractArray, b::BasicSymbolic{<:Number}) = map(x->x*b, a)
163+
Base.:*(a::BasicSymbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
164164

165165
for f in [identity, one, zero, *, +, -]
166166
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
167167
end
168168

169169
promote_symtype(::typeof(Base.real), T::Type{<:Number}) = Real
170-
Base.real(s::Symbolic{<:Number}) = islike(s, Real) ? s : term(real, s)
170+
Base.real(s::BasicSymbolic{<:Number}) = islike(s, Real) ? s : term(real, s)
171171
promote_symtype(::typeof(Base.conj), T::Type{<:Number}) = T
172-
Base.conj(s::Symbolic{<:Number}) = islike(s, Real) ? s : term(conj, s)
172+
Base.conj(s::BasicSymbolic{<:Number}) = islike(s, Real) ? s : term(conj, s)
173173
promote_symtype(::typeof(Base.imag), T::Type{<:Number}) = Real
174-
Base.imag(s::Symbolic{<:Number}) = islike(s, Real) ? zero(symtype(s)) : term(imag, s)
175-
Base.adjoint(s::Symbolic{<:Number}) = conj(s)
174+
Base.imag(s::BasicSymbolic{<:Number}) = islike(s, Real) ? zero(symtype(s)) : term(imag, s)
175+
Base.adjoint(s::BasicSymbolic{<:Number}) = conj(s)
176176

177177

178178
## Booleans
@@ -186,29 +186,29 @@ for (f, Domain) in [(==) => Number, (!=) => Number,
186186
xor => Bool]
187187
@eval begin
188188
promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool
189-
(::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool)
190-
(::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool)
191-
(::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool)
189+
(::$(typeof(f)))(a::BasicSymbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool)
190+
(::$(typeof(f)))(a::BasicSymbolic{<:$Domain}, b::BasicSymbolic{<:$Domain}) = term($f, a, b, type=Bool)
191+
(::$(typeof(f)))(a::$Domain, b::BasicSymbolic{<:$Domain}) = term($f, a, b, type=Bool)
192192
end
193193
end
194194

195195
for f in [!, ~]
196196
@eval begin
197197
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
198-
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
198+
(::$(typeof(f)))(s::BasicSymbolic{Bool}) = Term{Bool}(!, [s])
199199
end
200200
end
201201

202202

203203
# An ifelse node
204-
function Base.ifelse(_if::Symbolic{Bool}, _then, _else)
204+
function Base.ifelse(_if::BasicSymbolic{Bool}, _then, _else)
205205
Term{Union{symtype(_then), symtype(_else)}}(ifelse, Any[_if, _then, _else])
206206
end
207207
promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T, S}
208208

209209
# Array-like operations
210-
Base.size(x::Symbolic{<:Number}) = ()
211-
Base.length(x::Symbolic{<:Number}) = 1
212-
Base.ndims(x::Symbolic{T}) where {T} = Base.ndims(T)
213-
Base.ndims(::Type{<:Symbolic{T}}) where {T} = Base.ndims(T)
214-
Base.broadcastable(x::Symbolic{T}) where {T<:Number} = Ref(x)
210+
Base.size(x::BasicSymbolic{<:Number}) = ()
211+
Base.length(x::BasicSymbolic{<:Number}) = 1
212+
Base.ndims(x::BasicSymbolic{T}) where {T} = Base.ndims(T)
213+
Base.ndims(::Type{<:BasicSymbolic{T}}) where {T} = Base.ndims(T)
214+
Base.broadcastable(x::BasicSymbolic{T}) where {T<:Number} = Ref(x)

src/ordering.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
<(a::Real, b::Complex) = true
66
<(a::Complex, b::Real) = false
77

8-
<(a::Symbolic, b::Number) = false
9-
<(a::Number, b::Symbolic) = true
8+
<(a::BasicSymbolic, b::Number) = false
9+
<(a::Number, b::BasicSymbolic) = true
1010

1111
<(a::Function, b::Function) = nameof(a) <nameof(b)
1212

@@ -89,6 +89,8 @@ end
8989

9090
function _get_degrees(::typeof(^), expr, degs_cache)
9191
base_expr, pow_expr = arguments(expr)
92+
base_expr = unwrap_const(base_expr)
93+
pow_expr = unwrap_const(pow_expr)
9294
if pow_expr isa Real
9395
@inbounds degs = map(_get_degrees(base_expr, degs_cache)) do (base, pow)
9496
(base => pow * pow_expr)
@@ -107,6 +109,8 @@ end
107109

108110
function _get_degrees(::typeof(/), expr, degs_cache)
109111
nom_expr, denom_expr = arguments(expr)
112+
nom_expr = unwrap_const(nom_expr)
113+
denom_expr = unwrap_const(denom_expr)
110114
if denom_expr isa Number # constant denominator
111115
return _get_degrees(nom_expr, degs_cache)
112116
elseif nom_expr isa Number # constant nominator
@@ -157,11 +161,14 @@ function <ₑ(a::Tuple, b::Tuple)
157161
end
158162

159163
function <(a::BasicSymbolic, b::BasicSymbolic)
164+
if isconst(a) || isconst(b)
165+
return <(unwrap_const(a), unwrap_const(b))
166+
end
160167
da, db = get_degrees(a), get_degrees(b)
161168
fw = monomial_lt(da, db)
162169
bw = monomial_lt(db, da)
163170
if fw === bw && !isequal(a, b)
164-
if _arglen(a) == _arglen(b)
171+
if _arglen(a) == _arglen(b) != 0
165172
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
166173
else
167174
return _arglen(a) < _arglen(b)

0 commit comments

Comments
 (0)