Skip to content

Commit 48e9c56

Browse files
refactor: remove abstract type Symbolic
1 parent 25e5cbf commit 48e9c56

File tree

9 files changed

+56
-75
lines changed

9 files changed

+56
-75
lines changed

src/code.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
1010
import ..SymbolicUtils
1111
import ..SymbolicUtils.Rewriters
1212
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
13-
symtype, sorted_arguments, metadata, isterm, term, maketerm, Symbolic, unwrap_const,
13+
symtype, sorted_arguments, metadata, isterm, term, maketerm, unwrap_const,
1414
ArgsT, maybe_const
1515
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1616

src/inspect.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import AbstractTrees
22

33
const inspect_metadata = Ref{Bool}(false)
4-
function AbstractTrees.nodevalue(x::Symbolic)
5-
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
6-
end
74

85
function AbstractTrees.nodevalue(x::BSImpl.Type)
96
T = nameof(MData.variant_type(x))
@@ -35,7 +32,7 @@ the expression.
3532
3633
This function is used internally for printing via AbstractTrees.
3734
"""
38-
function AbstractTrees.children(x::Symbolic)
35+
function AbstractTrees.children(x::BasicSymbolic)
3936
iscall(x) ? sorted_arguments(x) : isexpr(x) ? sorted_children(x) : ()
4037
end
4138

@@ -50,7 +47,7 @@ Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expres
5047
"""
5148
function inspect end
5249

53-
function inspect(io::IO, x::Symbolic;
50+
function inspect(io::IO, x::BasicSymbolic;
5451
hint=true,
5552
metadata=inspect_metadata[])
5653

src/methods.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,17 @@ promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
134134

135135
error_f_symbolic(f, T) = error("$f is not defined for T.")
136136

137-
function Base.rem2pi(x::Symbolic, mode::Base.RoundingMode)
137+
function Base.rem2pi(x::BasicSymbolic, mode::Base.RoundingMode)
138138
T = symtype(x)
139139
T <: Number ? term(rem2pi, x, mode) : error_f_symbolic(rem2pi, T)
140140
end
141141

142142
# Specially handle inv and literal pow
143-
function Base.inv(x::Symbolic)
143+
function Base.inv(x::BasicSymbolic)
144144
T = symtype(x)
145145
T <: Number ? Base.:^(x, -1) : error_f_symbolic(rem2pi, T)
146146
end
147-
function Base.literal_pow(::typeof(^), x::Symbolic, ::Val{p}) where {p}
147+
function Base.literal_pow(::typeof(^), x::BasicSymbolic, ::Val{p}) where {p}
148148
T = symtype(x)
149149
T <: Number ? Base.:^(x, p) : error_f_symbolic(^, T)
150150
end
@@ -159,20 +159,20 @@ for f in monadic
159159
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralRealImpl}) = LiteralReal
160160
end
161161

162-
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
163-
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
162+
Base.:*(a::AbstractArray, b::BasicSymbolic{<:Number}) = map(x->x*b, a)
163+
Base.:*(a::BasicSymbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
164164

165165
for f in [identity, one, zero, *, +, -]
166166
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
167167
end
168168

169169
promote_symtype(::typeof(Base.real), T::Type{<:Number}) = Real
170-
Base.real(s::Symbolic{<:Number}) = islike(s, Real) ? s : term(real, s)
170+
Base.real(s::BasicSymbolic{<:Number}) = islike(s, Real) ? s : term(real, s)
171171
promote_symtype(::typeof(Base.conj), T::Type{<:Number}) = T
172-
Base.conj(s::Symbolic{<:Number}) = islike(s, Real) ? s : term(conj, s)
172+
Base.conj(s::BasicSymbolic{<:Number}) = islike(s, Real) ? s : term(conj, s)
173173
promote_symtype(::typeof(Base.imag), T::Type{<:Number}) = Real
174-
Base.imag(s::Symbolic{<:Number}) = islike(s, Real) ? zero(symtype(s)) : term(imag, s)
175-
Base.adjoint(s::Symbolic{<:Number}) = conj(s)
174+
Base.imag(s::BasicSymbolic{<:Number}) = islike(s, Real) ? zero(symtype(s)) : term(imag, s)
175+
Base.adjoint(s::BasicSymbolic{<:Number}) = conj(s)
176176

177177

178178
## Booleans
@@ -186,29 +186,29 @@ for (f, Domain) in [(==) => Number, (!=) => Number,
186186
xor => Bool]
187187
@eval begin
188188
promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool
189-
(::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool)
190-
(::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool)
191-
(::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool)
189+
(::$(typeof(f)))(a::BasicSymbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool)
190+
(::$(typeof(f)))(a::BasicSymbolic{<:$Domain}, b::BasicSymbolic{<:$Domain}) = term($f, a, b, type=Bool)
191+
(::$(typeof(f)))(a::$Domain, b::BasicSymbolic{<:$Domain}) = term($f, a, b, type=Bool)
192192
end
193193
end
194194

195195
for f in [!, ~]
196196
@eval begin
197197
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
198-
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
198+
(::$(typeof(f)))(s::BasicSymbolic{Bool}) = Term{Bool}(!, [s])
199199
end
200200
end
201201

202202

203203
# An ifelse node
204-
function Base.ifelse(_if::Symbolic{Bool}, _then, _else)
204+
function Base.ifelse(_if::BasicSymbolic{Bool}, _then, _else)
205205
Term{Union{symtype(_then), symtype(_else)}}(ifelse, Any[_if, _then, _else])
206206
end
207207
promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T, S}
208208

209209
# Array-like operations
210-
Base.size(x::Symbolic{<:Number}) = ()
211-
Base.length(x::Symbolic{<:Number}) = 1
212-
Base.ndims(x::Symbolic{T}) where {T} = Base.ndims(T)
213-
Base.ndims(::Type{<:Symbolic{T}}) where {T} = Base.ndims(T)
214-
Base.broadcastable(x::Symbolic{T}) where {T<:Number} = Ref(x)
210+
Base.size(x::BasicSymbolic{<:Number}) = ()
211+
Base.length(x::BasicSymbolic{<:Number}) = 1
212+
Base.ndims(x::BasicSymbolic{T}) where {T} = Base.ndims(T)
213+
Base.ndims(::Type{<:BasicSymbolic{T}}) where {T} = Base.ndims(T)
214+
Base.broadcastable(x::BasicSymbolic{T}) where {T<:Number} = Ref(x)

src/ordering.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
<(a::Real, b::Complex) = true
66
<(a::Complex, b::Real) = false
77

8-
<(a::Symbolic, b::Number) = false
9-
<(a::Number, b::Symbolic) = true
8+
<(a::BasicSymbolic, b::Number) = false
9+
<(a::Number, b::BasicSymbolic) = true
1010

1111
<(a::Function, b::Function) = nameof(a) <nameof(b)
1212

src/substitute.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ function (s::Substituter)(expr)
77
end
88

99
function _const_or_not_symbolic(x)
10-
isconst(x) || !(x isa Symbolic)
10+
isconst(x) || !(x isa BasicSymbolic)
1111
end
1212

1313
function combine_fold(::Type{T}, op, args::ArgsT, meta) where {T}
1414
@nospecialize op args meta
15-
can_fold = !(op isa Symbolic) && all(_const_or_not_symbolic, args)
15+
can_fold = !(op isa BasicSymbolic) && all(_const_or_not_symbolic, args)
1616
if can_fold
1717
if op === (+)
1818
add_worker(args)
@@ -60,14 +60,14 @@ julia> substitute(1+sqrt(y), Dict(y => 2), fold=false)
6060
end
6161

6262
"""
63-
occursin(needle::Symbolic, haystack::Symbolic)
63+
occursin(needle::BasicSymbolic, haystack::BasicSymbolic)
6464
6565
Determine whether the second argument contains the first argument. Note that
6666
this function doesn't handle associativity, commutativity, or distributivity.
6767
"""
68-
Base.occursin(needle::Symbolic, haystack::Symbolic) = _occursin(needle, haystack)
69-
Base.occursin(needle, haystack::Symbolic) = _occursin(needle, haystack)
70-
Base.occursin(needle::Symbolic, haystack) = _occursin(needle, haystack)
68+
Base.occursin(needle::BasicSymbolic, haystack::BasicSymbolic) = _occursin(needle, haystack)
69+
Base.occursin(needle, haystack::BasicSymbolic) = _occursin(needle, haystack)
70+
Base.occursin(needle::BasicSymbolic, haystack) = _occursin(needle, haystack)
7171
function _occursin(needle, haystack)
7272
isequal(unwrap_const(needle), unwrap_const(haystack)) && return true
7373
if iscall(haystack)

src/types.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#--------------------
33
#### Symbolic
44
#--------------------
5-
abstract type Symbolic{T} end
65

76
#################### SafeReal #########################
87
export SafeReal, LiteralReal
@@ -55,7 +54,7 @@ end
5554
5655
Core ADT for `BasicSymbolic`. `hash` and `isequal` compare metadata.
5756
"""
58-
@data mutable BasicSymbolicImpl{T} <: Symbolic{T} begin
57+
@data mutable BasicSymbolicImpl{T} begin
5958
struct Const
6059
const val::T
6160
id::IdentT
@@ -491,11 +490,10 @@ end
491490
### Base interface
492491
###
493492

494-
Base.isequal(::Symbolic, x) = false
495-
Base.isequal(x, ::Symbolic) = false
496-
Base.isequal(::Symbolic, ::Missing) = false
497-
Base.isequal(::Missing, ::Symbolic) = false
498-
Base.isequal(::Symbolic, ::Symbolic) = false
493+
Base.isequal(::BasicSymbolic, x) = false
494+
Base.isequal(x, ::BasicSymbolic) = false
495+
Base.isequal(::BasicSymbolic, ::Missing) = false
496+
Base.isequal(::Missing, ::BasicSymbolic) = false
499497

500498
const SCALAR_SYMTYPE_VARIANTS = [Number, Real, SafeReal, LiteralReal, Int, Float64, Bool]
501499
const ARR_VARIANTS = [Vector, Matrix]
@@ -718,8 +716,8 @@ function Base.hash(s::BSImpl.Type, h::UInt)
718716
hash_bsimpl(s, h, COMPARE_FULL[])
719717
end
720718

721-
Base.one( s::Union{Symbolic, BSImpl.Type}) = one( symtype(s))
722-
Base.zero(s::Union{Symbolic, BSImpl.Type}) = zero(symtype(s))
719+
Base.one( s::BSImpl.Type) = one( symtype(s))
720+
Base.zero(s::BSImpl.Type) = zero(symtype(s))
723721

724722

725723
Base.nameof(s::Union{BasicSymbolic, BSImpl.Type}) = issym(s) ? s.name : error("Non-Sym BasicSymbolic doesn't have a name")
@@ -975,7 +973,7 @@ struct Div{T} end
975973

976974
function Const{T}(val) where {T}
977975
val = unwrap(val)
978-
val isa Symbolic && return val
976+
val isa BasicSymbolic && return val
979977
BSImpl.Const{T}(convert(T, val))
980978
end
981979

@@ -1227,12 +1225,12 @@ end
12271225
metadata(s::BSImpl.Type) = isconst(s) ? nothing : s.metadata
12281226
metadata(s::BasicSymbolic, meta) = Setfield.@set! s.metadata = meta
12291227

1230-
function hasmetadata(s::Symbolic, ctx)
1228+
function hasmetadata(s::BasicSymbolic, ctx)
12311229
metadata(s) isa AbstractDict && haskey(metadata(s), ctx)
12321230
end
12331231

12341232
issafecanon(f, s) = true
1235-
function issafecanon(f, s::Symbolic)
1233+
function issafecanon(f, s::BasicSymbolic)
12361234
if metadata(s) === nothing || isempty(metadata(s)) || issym(s)
12371235
return true
12381236
else
@@ -1245,7 +1243,7 @@ _issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^))
12451243

12461244
issafecanon(f, ss...) = all(x->issafecanon(f, x), ss)
12471245

1248-
function getmetadata(s::Symbolic, ctx)
1246+
function getmetadata(s::BasicSymbolic, ctx)
12491247
md = metadata(s)
12501248
if md isa AbstractDict
12511249
md[ctx]
@@ -1254,7 +1252,7 @@ function getmetadata(s::Symbolic, ctx)
12541252
end
12551253
end
12561254

1257-
function getmetadata(s::Symbolic, ctx, default)
1255+
function getmetadata(s::BasicSymbolic, ctx, default)
12581256
md = metadata(s)
12591257
md isa AbstractDict ? get(md, ctx, default) : default
12601258
end
@@ -1282,7 +1280,7 @@ function assocmeta(d::Base.ImmutableDict, ctx, val)::ImmutableDict{DataType,Any}
12821280
Base.ImmutableDict{DataType, Any}(d, ctx, val)
12831281
end
12841282

1285-
function setmetadata(s::Symbolic, ctx::DataType, val)
1283+
function setmetadata(s::BasicSymbolic, ctx::DataType, val)
12861284
if s.metadata isa AbstractDict
12871285
@set s.metadata = assocmeta(s.metadata, ctx, val)
12881286
else
@@ -1521,9 +1519,9 @@ promote_symtype(f, Ts...) = Any
15211519

15221520
struct FnType{X<:Tuple,Y,Z} end
15231521

1524-
(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, SmallV{Any}(args))
1522+
(f::BasicSymbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, SmallV{Any}(args))
15251523

1526-
function (f::Symbolic)(args...)
1524+
function (f::BasicSymbolic)(args...)
15271525
error("Sym $f is not callable. " *
15281526
"Use @syms $f(var1, var2,...) to create it as a callable.")
15291527
end
@@ -1548,7 +1546,7 @@ function promote_symtype(f::BasicSymbolic{<:FnType{X,Y}}, args...) where {X, Y}
15481546
return Y
15491547
end
15501548

1551-
function Base.show(io::IO, f::Symbolic{<:FnType{X,Y}}) where {X,Y}
1549+
function Base.show(io::IO, f::BasicSymbolic{<:FnType{X,Y}}) where {X,Y}
15521550
print(io, nameof(f))
15531551
# Use `Base.unwrap_unionall` to handle `Tuple{T} where T`. This is not the
15541552
# best printing, but it's better than erroring.
@@ -1657,7 +1655,7 @@ end
16571655
###
16581656
### Arithmetic
16591657
###
1660-
const SN = Symbolic{<:Number}
1658+
const SN = BasicSymbolic{<:Number}
16611659
# integration. Constructors of `Add, Mul, Pow...` from Base (+, *, ^, ...)
16621660

16631661
add_t(a::Number,b::Number) = promote_symtype(+, symtype(a), symtype(b))

src/utils.jl

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ using Base: ImmutableDict
22

33

44
pow(x,y) = y==0 ? 1 : y<0 ? inv(x)^(-y) : x^y
5-
pow(x::Symbolic,y) = y==0 ? 1 : Base.:^(x,y)
6-
pow(x, y::Symbolic) = Base.:^(x,y)
7-
pow(x::Symbolic,y::Symbolic) = Base.:^(x,y)
5+
pow(x::BasicSymbolic,y) = y==0 ? 1 : Base.:^(x,y)
6+
pow(x, y::BasicSymbolic) = Base.:^(x,y)
7+
pow(x::BasicSymbolic,y::BasicSymbolic) = Base.:^(x,y)
88

99
# Simplification utilities
1010
function has_trig_exp(term)
@@ -19,20 +19,6 @@ function has_trig_exp(term)
1919
end
2020
end
2121

22-
function fold(t)
23-
if iscall(t)
24-
tt = map(fold, parent(arguments(t)))
25-
if !any(x->x isa Symbolic, tt)
26-
# evaluate it
27-
return operation(t)(tt...)
28-
else
29-
return maketerm(typeof(t), operation(t), tt, metadata(t))
30-
end
31-
else
32-
return t
33-
end
34-
end
35-
3622
### Predicates
3723

3824
sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T
@@ -53,8 +39,8 @@ function _isone(x)
5339
x isa Array && return isone(x)
5440
return false
5541
end
56-
_isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x) <: Integer)
57-
_isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real)
42+
_isinteger(x) = (x isa Number && isinteger(x)) || (x isa BasicSymbolic && symtype(x) <: Integer)
43+
_isreal(x) = (x isa Number && isreal(x)) || (x isa BasicSymbolic && symtype(x) <: Real)
5844

5945
issortedₑ(args) = issorted(args, lt=<ₑ)
6046
needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(arguments(x))

test/basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicUtils: Symbolic, Sym, FnType, Term, Polyform, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, basicsymbolic_to_polyvar, get_mul_coefficient, PolynomialT, Const
1+
using SymbolicUtils: Sym, FnType, Term, Polyform, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, basicsymbolic_to_polyvar, get_mul_coefficient, PolynomialT, Const
22
using SymbolicUtils
33
using ConstructionBase: setproperties
44
import MultivariatePolynomials as MP
@@ -165,7 +165,7 @@ end
165165

166166
@testset "array-like operations" begin
167167
abstract type SquareDummy end
168-
Base.:*(a::Symbolic{SquareDummy}, b) = b^2
168+
Base.:*(a::BasicSymbolic{SquareDummy}, b) = b^2
169169
@syms s t a::SquareDummy A[1:2, 1:2]
170170

171171
@test isequal(ndims(A), 2)

test/fuzzlib.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SymbolicUtils
2-
using SymbolicUtils: Term, showraw, Symbolic, issym
2+
using SymbolicUtils: BasicSymbolic, Term, showraw, issym
33
using SpecialFunctions
44
using Test
55
using NaNMath
@@ -30,7 +30,7 @@ function rand_input(T)
3030
end
3131
end
3232

33-
rand_input(i::Symbolic{T}) where {T} = rand_input(T)
33+
rand_input(i::BasicSymbolic{T}) where {T} = rand_input(T)
3434

3535
const num_spec = let
3636
@syms a b::Real c::Integer d::Float64 e::Rational f

0 commit comments

Comments
 (0)