Skip to content

Commit 3bc17c9

Browse files
fix: handle Const in new matchers
1 parent 500ad84 commit 3bc17c9

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/matchers.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function matcher(slot::Slot, acSets)
3232
val = get(bindings, slot.name, nothing)
3333
# if slot name already is in bindings, check if it matches
3434
if val !== nothing
35-
if isequal(val, car(data))::Bool
35+
if isequal(val, unwrap_const(car(data)))::Bool
3636
return next(bindings, 1)
3737
end
3838
# elseif the first element of data matches the slot predicate, add it to bindings and call next
@@ -74,7 +74,7 @@ function trymatchexpr(data, value, n)
7474
end
7575

7676
return !islist(value) ? n : nothing
77-
elseif isequal(value, data)
77+
elseif isequal(unwrap_const(value), unwrap_const(data))
7878
return n + 1
7979
end
8080
end
@@ -137,17 +137,17 @@ function term_matcher_constructor(term, acSets)
137137
result !== nothing && return success(result, 1)
138138

139139
frankestein = nothing
140-
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1)
140+
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && _isone(arguments(arguments(data)[1])[1])
141141
# if data is of the alternative form (1/...)^(...)
142142
one_over_smth = arguments(data)[1]
143143
T = vartype(one_over_smth)
144144
frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]])
145-
elseif (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
145+
elseif (operation(data) === /) && _isone(arguments(data)[1]) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
146146
# if data is of the alternative form 1/(...)^(...)
147147
denominator = arguments(data)[2]
148148
T = vartype(denominator)
149149
frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]])
150-
elseif (operation(data) === /) && isequal(arguments(data)[1], 1)
150+
elseif (operation(data) === /) && _isone(arguments(data)[1])
151151
# if data is of the alternative form 1/(...), it might match with exponent = -1
152152
denominator = arguments(data)[2]
153153
T = vartype(denominator)
@@ -172,7 +172,7 @@ function term_matcher_constructor(term, acSets)
172172
return pow_term_matcher
173173
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
174174
elseif acSets!==nothing && operation(term) in [+, *]
175-
has_segment = any([isa(a,Segment) for a in arguments(term)])
175+
has_segment = any([isa(unwrap_const(a),Segment) for a in arguments(term)])
176176
function commutative_term_matcher(success, data, bindings)
177177
!islist(data) && return nothing # if data is not a list, return nothing
178178
data = car(data)
@@ -213,7 +213,7 @@ function term_matcher_constructor(term, acSets)
213213
result = loop(data, bindings, matchers)
214214
result !== nothing && return success(result, 1)
215215

216-
if (operation(data) === ^) && (arguments(data)[2] === 1//2)
216+
if (operation(data) === ^) && (unwrap_const(arguments(data)[2]) === 1//2)
217217
T = vartype(arguments(data)[1])
218218
frankestein = Term{T}(sqrt,[arguments(data)[1]])
219219
result = loop(frankestein, bindings, matchers)
@@ -233,7 +233,7 @@ function term_matcher_constructor(term, acSets)
233233
result = loop(data, bindings, matchers)
234234
result !== nothing && return success(result, 1)
235235

236-
if (operation(data) === ^) && (arguments(data)[1] === ℯ)
236+
if (operation(data) === ^) && (unwrap_const(arguments(data)[1]) === ℯ)
237237
T = vartype(arguments(data)[2])
238238
frankestein = Term{T}(exp,[arguments(data)[2]])
239239
result = loop(frankestein, bindings, matchers)

0 commit comments

Comments
 (0)