From 44deb6e76a76ee8f4ad32b5a306b20a7c00ed15e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 14:49:08 +0100 Subject: [PATCH 01/24] implement Lambert W function This copies the implementation of the Lambert W function and related functions and data from jlapeyre/LambertW.jl to SpecialFunctions.jl --- README.md | 2 +- docs/src/functions_list.md | 3 + docs/src/functions_overview.md | 6 + src/SpecialFunctions.jl | 5 +- src/lambertw.jl | 386 +++++++++++++++++++++++++++++++++ test/lambertw.jl | 177 +++++++++++++++ test/runtests.jl | 1 + 7 files changed, 578 insertions(+), 2 deletions(-) create mode 100644 src/lambertw.jl create mode 100644 test/lambertw.jl diff --git a/README.md b/README.md index 206679b8..d4e5ef42 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # SpecialFunctions.jl Special mathematical functions in Julia, including Bessel, Hankel, Airy, error, Dawson, exponential (or sine and cosine) integrals, -eta, zeta, digamma, inverse digamma, trigamma, and polygamma functions. +eta, zeta, Lambert's W, digamma, inverse digamma, trigamma, and polygamma functions. Most of these functions were formerly part of Base in early versions of Julia. CI (Linux, macOS, FreeBSD, Windows): diff --git a/docs/src/functions_list.md b/docs/src/functions_list.md index 56d36c45..d8583f16 100644 --- a/docs/src/functions_list.md +++ b/docs/src/functions_list.md @@ -69,4 +69,7 @@ SpecialFunctions.beta SpecialFunctions.logbeta SpecialFunctions.logabsbeta SpecialFunctions.logabsbinomial +SpecialFunctions.lambertw +SpecialFunctions.lambertwbp +SpecialFunctions.omega ``` diff --git a/docs/src/functions_overview.md b/docs/src/functions_overview.md index 01c6089b..350fb671 100644 --- a/docs/src/functions_overview.md +++ b/docs/src/functions_overview.md @@ -95,3 +95,9 @@ Here the *Special Functions* are listed according to the structure of [NIST Digi |:-------- |:----------- | | [`eta(x)`](@ref SpecialFunctions.eta) | [Dirichlet eta function](https://en.wikipedia.org/wiki/Dirichlet_eta_function) at `x` | | [`zeta(x)`](@ref SpecialFunctions.zeta) | [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function) at `x` | + +## [Lambert's W Function](https://dlmf.nist.gov/4.13) +| Function | Description | +|:-------- |:----------- | +| [`lambertw(x, [k=0])`](@ref SpecialFunctions.lambertw) | [Lambert's W function](https://en.wikipedia.org/wiki/Lambert_W_function) at `x` for `k`-th branch | +| [`lambertwbp(x, [k=0])`](@ref SpecialFunctions.lambertwbp) | Accurate value of ``1 + W_k(-\frac{1}{\mathrm{e}} + x)`` for small `x` | \ No newline at end of file diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 9c482350..9c31524b 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -77,7 +77,9 @@ export expintx, sinint, cosint, - lbinomial + lbinomial, + lambertw, + lambertwbp include("bessel.jl") include("erf.jl") @@ -90,6 +92,7 @@ include("gamma.jl") include("gamma_inc.jl") include("betanc.jl") include("beta_inc.jl") +include("lambertw.jl") if !isdefined(Base, :get_extension) include("../ext/SpecialFunctionsChainRulesCoreExt.jl") end diff --git a/src/lambertw.jl b/src/lambertw.jl new file mode 100644 index 00000000..b98f7aee --- /dev/null +++ b/src/lambertw.jl @@ -0,0 +1,386 @@ +import Base: convert +#export lambertw, lambertwbp +using Compat + +const euler = +if isdefined(Base, :MathConstants) + Base.MathConstants.e +else + e +end + +const omega_const_bf_ = Ref{BigFloat}() + +function __init__() + omega_const_bf_[] = + parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") +end + +#### Lambert W function #### + +const LAMBERTW_USE_NAN = false + +macro baddomain(v) + if LAMBERTW_USE_NAN + return :(return(NaN)) + else + return esc(:(throw(DomainError($v)))) + end +end + +# Use Halley's root-finding method to find x = lambertw(z) with +# initial point x. +function _lambertw(z::T, x::T) where T <: Number + two_t = convert(T,2) + lastx = x + lastdiff = zero(T) + for i in 1:100 + ex = exp(x) + xexz = x * ex - z + x1 = x + 1 + x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) + xdiff = abs(lastx - x) + xdiff <= 2*eps(abs(lastx)) && break + lastdiff == diff && break + lastx = x + lastdiff = xdiff + end + x +end + +### Real z ### + +# Real x, k = 0 + +# The fancy initial condition selection does not seem to help speed, but we leave it for now. +function lambertwk0(x::T)::T where T<:AbstractFloat + x == Inf && return Inf + one_t = one(T) + oneoe = -one_t/convert(T,euler) + x == oneoe && return -one_t + itwo_t = 1/convert(T,2) + oneoe <= x || @baddomain(x) + if x > one_t + lx = log(x) + llx = log(lx) + x1 = lx - llx - log(one_t - llx/lx) * itwo_t + else + x1 = (567//1000) * x + end + _lambertw(x,x1) +end + +# Real x, k = -1 +function _lambertwkm1(x::T) where T<:Real + oneoe = -one(T)/convert(T,euler) + x == oneoe && return -one(T) + oneoe <= x || @baddomain(x) + x == zero(T) && return -convert(T,Inf) + x < zero(T) || @baddomain(x) + _lambertw(x,log(-x)) +end + + +""" + lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer} + lambertw(z::T, k::V=0) where {T<:Real, V<:Integer} + +Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be +either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the +domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is +the complex plane. + +```jldoctest +julia> lambertw(-1/e,-1) +-1.0 + +julia> lambertw(-1/e,0) +-1.0 + +julia> lambertw(0,0) +0.0 + +julia> lambertw(0,-1) +-Inf + +julia> lambertw(Complex(-10.0,3.0), 4) +-0.9274337508660128 + 26.37693445371142im +``` + +!!! note + The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments + outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`. +""" +function lambertw(x::Real, k::Integer) + k == 0 && return lambertwk0(x) + k == -1 && return _lambertwkm1(x) + @baddomain(k) # more informative message like below ? +# error("lambertw: real x must have k == 0 or k == -1") +end + +function lambertw(x::Union{Integer,Rational}, k::Integer) + if k == 0 + x == 0 && return float(zero(x)) + x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way + end + lambertw(float(x),k) +end + +### Complex z ### + +# choose initial value inside correct branch for root finding +function lambertw(z::Complex{T}, k::Integer) where T<:Real + one_t = one(T) + local w::Complex{T} + pointseven = 7//10 + if abs(z) <= one_t/convert(T,euler) + if z == 0 + k == 0 && return z + return complex(-convert(T,Inf),zero(T)) + end + if k == 0 + w = z + elseif k == -1 && imag(z) == 0 && real(z) < 0 + w = complex(log(-real(z)),1//10^7) # need offset for z ≈ -1/e. + else + w = log(z) + k != 0 ? w += complex(0,k * 2 * pi) : nothing + end + elseif k == 0 && imag(z) <= pointseven && abs(z) <= pointseven + w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven,pointseven) : complex(pointseven,-pointseven) : z + else + if real(z) == convert(T,Inf) + k == 0 && return z + return z + complex(0,2*k*pi) + end + real(z) == -convert(T,Inf) && return -z + complex(0,(2*k+1)*pi) + w = log(z) + k != 0 ? w += complex(0, 2*k*pi) : nothing + end + _lambertw(z,w) +end + +lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k) + +# lambertw(e + 0im,k) is ok for all k +#function lambertw(::Irrational{:e}, k::T) where T<:Integer +function lambertw(::typeof(euler), k::T) where T<:Integer + k == 0 && return 1 + @baddomain(k) +end + +# Maybe this should return a float +lambertw(::typeof(euler)) = 1 +#lambertw(::Irrational{:e}) = 1 + +#lambertw{T<:Number}(x::T) = lambertw(x,0) +lambertw(x::Number) = lambertw(x,0) + +lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...) + +### omega constant ### + +const omega_const_ = 0.567143290409783872999968662210355 +# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault + +# maybe compute higher precision. converges very quickly +function omega_const(::Type{BigFloat}) + @compat precision(BigFloat) <= 256 && return omega_const_bf_[] + myeps = eps(BigFloat) + oc = omega_const_bf_[] + for i in 1:100 + nextoc = (1 + oc) / (1 + exp(oc)) + abs(oc - nextoc) <= myeps && break + oc = nextoc + end + return oc +end + +""" + omega + ω + +A constant defined by `ω exp(ω) = 1`. + +```jldoctest +julia> ω +ω = 0.5671432904097... + +julia> omega +ω = 0.5671432904097... + +julia> ω * exp(ω) +1.0 + +julia> big(omega) +5.67143290409783872999968662210355549753815787186512508135131079223045793086683e-01 +``` +""" +const ω = Irrational{:ω}() +@doc (@doc ω) omega = ω + +# The following three lines may be removed when support for v0.6 is dropped +Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) +Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) +Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) + +Base.Float64(::Irrational{:ω}) = omega_const_ # FIXME: This is very slow. Why ? +Base.Float32(::Irrational{:ω}) = Float32(omega_const_) +Base.Float16(::Irrational{:ω}) = Float16(omega_const_) +Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) + +### Expansion about branch point x = -1/e ### + +# Refer to the paper "On the Lambert W function". In (4.22) +# coefficients μ₀ through μ₃ are given explicitly. Recursion relations +# (4.23) and (4.24) for all μ are also given. This code implements the +# recursion relations. + +# (4.23) and (4.24) give zero based coefficients +cset(a,i,v) = a[i+1] = v +cget(a,i) = a[i+1] + +# (4.24) +function compa(k,m,a) + sum0 = zero(eltype(m)) + for j in 2:k-1 + sum0 += cget(m,j) * cget(m,k+1-j) + end + cset(a,k,sum0) + sum0 +end + +# (4.23) +function compm(k,m,a) + kt = convert(eltype(m),k) + mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - + cget(a,k)/2 - cget(m,k-1)/(kt+1) + cset(m,k,mk) + mk +end + +# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and +# solve for α₂. We get α₂ = 0. +# compute array of coefficients μ in (4.22). +# m[1] is μ₀ +function lamwcoeff(T::DataType, n::Int) + # a = @compat Array{T}(undef,n) + # m = @compat Array{T}(undef,n) + a = zeros(T,n) # We don't need initialization, but Compat is a huge PITA. + m = zeros(T,n) + cset(a,0,2) # α₀ literal in paper + cset(a,1,-1) # α₁ literal in paper + cset(a,2,0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper + cset(m,0,-1) # μ₀ literal in paper + cset(m,1,1) # μ₁ literal in paper + cset(m,2,-1//3) # μ₂ literal in paper, but only in (4.22) + for i in 3:n-1 # coeffs are zero indexed + compa(i,m,a) + compm(i,m,a) + end + return m +end + +const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) + +function horner(x, p::AbstractArray,n) + n += 1 + ex = p[n] + for i = n-1:-1:2 + ex = :($(p[i]) + t * $ex) + end + ex = :( t * $ex) + Expr(:block, :(t = $x), ex) +end + +function mkwser(name, n) + iex = horner(:x,LAMWMU_FLOAT64,n) + :(function ($name)(x) $iex end) +end + +eval(mkwser(:wser3, 3)) +eval(mkwser(:wser5, 5)) +eval(mkwser(:wser7, 7)) +eval(mkwser(:wser12, 12)) +eval(mkwser(:wser19, 19)) +eval(mkwser(:wser26, 26)) +eval(mkwser(:wser32, 32)) +eval(mkwser(:wser50, 50)) +eval(mkwser(:wser100, 100)) +eval(mkwser(:wser290, 290)) + +# Converges to Float64 precision +# We could get finer tuning by separating k=0,-1 branches. +function wser(p,x) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/euler && @baddomain(x) # radius of convergence + return wser290(p) # good for x approx .32 +end + +# These may need tuning. +function wser(p::Complex{T},z) where T<:Real + x = abs(z) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/euler && @baddomain(x) # radius of convergence + return wser290(p) +end + +@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 + ps = 2*euler*x; + p = sqrt(ps) + wser(p,x) +end + +@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 + ps = 2*euler*x; + p = -sqrt(ps) + wser(p,x) +end + +""" + lambertwbp(z,k=0) + +Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +Accurate to Float64 precision for abs(z) < 0.32. +If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized. + +```jldoctest +julia> lambertw(-1/e + 1e-18, -1) +-1.0 + +julia> lambertwbp(1e-18, -1) +-2.331643983409312e-9 + +# Same result, but 1000 times slower +julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) +-2.331643983409312e-9 +``` + +!!! note + `lambertwbp` uses a series expansion about the branch point `z=-1/e` to avoid loss of precision. + The loss of precision in `lambertw` is analogous to the loss of precision + in computing the `sqrt(1-x)` for `x` close to `1`. +""" +function lambertwbp(x::Number,k::Integer) + k == 0 && return _lambertw0(x) + k == -1 && return _lambertwm1(x) + error("expansion about branch point only implemented for k = 0 and -1") +end + +lambertwbp(x::Number) = _lambertw0(x) + +nothing diff --git a/test/lambertw.jl b/test/lambertw.jl new file mode 100644 index 00000000..7614d9d2 --- /dev/null +++ b/test/lambertw.jl @@ -0,0 +1,177 @@ +using Compat + +macro test_baddomain(expr) + if SpecialFunctions.LAMBERTW_USE_NAN + :(@test $(esc(expr)) === NaN) + else + :(@test_throws DomainError $(esc(expr))) + end +end + +const euler = +if isdefined(Base, :MathConstants) + Base.MathConstants.e +else + e +end + +### domain errors + +@test_baddomain lambertw(-2.0,0) +@test_baddomain lambertw(-2.0,-1) +@test_baddomain lambertw(-2.0,1) +@test_baddomain lambertw(NaN) + +## math constant e +@test_baddomain lambertw(euler,1) +@test_baddomain lambertw(euler,-1) + +## integer arguments return floating point types +@test typeof(lambertw(0)) <: AbstractFloat +@test lambertw(0) == 0 + +### math constant, euler e + +# could return math const e, but this would break type stability +@test typeof(lambertw(1)) <: AbstractFloat +@test lambertw(euler,0) == 1 + +## value at branch point where real branches meet +@test lambertw(-1/euler,0) == lambertw(-1/euler,-1) == -1 +@test typeof(lambertw(-1/euler,0)) == typeof(lambertw(-1/euler,-1)) <: AbstractFloat + +## convert irrationals to float + +@test isapprox(lambertw(pi), 1.0736581947961492) +@test isapprox(lambertw(pi,0), 1.0736581947961492) + +### infinite args or return values + +@test lambertw(0,-1) == lambertw(0.0,-1) == -Inf +@test lambertw(Inf,0) == Inf +@test lambertw(complex(Inf,1),0) == complex(Inf,1) +@test lambertw(complex(Inf,0),1) == complex(Inf,2pi) +@test lambertw(complex(-Inf,0),1) == complex(Inf,3pi) +@test lambertw(complex(0.0,0.0),-1) == complex(-Inf,0.0) + +## default branch is k = 0 +@test lambertw(1.0) == lambertw(1.0,0) + +## BigInt args return BigFloats +@test typeof(lambertw(BigInt(0))) == BigFloat +@test typeof(lambertw(BigInt(3))) == BigFloat + +## Any Integer type allowed for second argument +@test lambertw(-0.2,-1) == lambertw(-0.2,BigInt(-1)) + +## BigInt for second arg does not promote the type +@test typeof(lambertw(-0.2,-1)) == typeof(lambertw(-0.2,BigInt(-1))) + +for (z,k,res) in [ (0,0 ,0), (complex(0,0),0 ,0), + (complex(0.0,0),0 ,0), (complex(1.0,0),0, 0.567143290409783873) ] + if Int != Int32 + @test isapprox(lambertw(z,k), res) + @test isapprox(lambertw(z), res) + else + @test isapprox(lambertw(z,k), res; rtol = 1e-14) + @test isapprox(lambertw(z), res; rtol = 1e-14) + end +end + +for (z,k) in ((complex(1,1),2), (complex(1,1),0),(complex(.6,.6),0), + (complex(.6,-.6),0)) + let w + @test (w = lambertw(z,k) ; true) + @test abs(w*exp(w) - z) < 1e-15 + end +end + +@test abs(lambertw(complex(-3.0,-4.0),0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 +@test abs(lambertw(complex(-3.0,-4.0),1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 +@test (lambertw(complex(.3,.3),0); true) + +# bug fix +# The routine will start at -1/e + eps * im, rather than -1/e + 0im, +# otherwise root finding will fail +if Int != Int32 + @test abs(lambertw(-1.0/euler + 0im,-1)) == 1 +else + @test abs(lambertw(-1.0/euler + 0im,-1) + 1) < 1e-7 +end +# lambertw for BigFloat is more precise than Float64. Note +# that 70 digits in test is about 35 digits in W +let W + for z in [ BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] + @test (W = lambertw(z); true) + @test abs(z - W * exp(W)) < BigFloat(1)^(-70) + end +end + +### ω constant + +## get ω from recursion and compare to value from lambertw +let sp = precision(BigFloat) + @compat setprecision(512) + @test lambertw(big(1)) == big(SpecialFunctions.omega) + @compat setprecision(sp) +end + +@test lambertw(1) == float(SpecialFunctions.omega) +@test convert(Float16,SpecialFunctions.omega) == convert(Float16,0.5674) +@test convert(Float32,SpecialFunctions.omega) == 0.56714326f0 +@test lambertw(BigInt(1)) == big(SpecialFunctions.omega) + +### expansion about branch point + +# not a domain error, but not implemented +@test_throws ErrorException lambertwbp(1,1) + +@test_throws DomainError lambertw(.3,2) + +# Expansions about branch point converges almost to machine precision +# except near the radius of convergence. +# Complex args are not tested here. + +if Int != Int32 + +let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, diff + @compat setprecision(2048) + for i in 1:300 + # k = 0 + @test (wo = lambertwbp(Float64(z)); diff = abs(-1 + wo - lambertw(z-1/big(euler))); true) + if diff > 5e-16 + println(Float64(z), " ", Float64(diff)) + end + @test diff < 5e-16 + # k = -1 + @test (wo = lambertwbp(Float64(z),-1); diff = abs(-1 + wo - lambertw(z-1/big(euler),-1)); true) + if diff > 5e-16 + println(Float64(z), " ", Float64(diff)) + end + @test diff < 5e-16 + z *= 1.1 + if z > 0.23 break end + end + @compat setprecision(sp) +end + +# test the expansion about branch point for k=-1, +# by comparing to exact BigFloat calculation. +@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(euler)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 + +@test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 + +end + +## vectorization + +if VERSION >= v"0.5" + @test lambertw.([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] + @test lambertw.([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] + @test lambertw.([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] + @test lambertwbp.([.1,.2,.3],-1) == map(x -> lambertwbp(x,-1), [.1,.2,.3]) +else + @test lambertw([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] + @test lambertw([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] + @test lambertw([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] +end diff --git a/test/runtests.jl b/test/runtests.jl index 69dcafe7..f9d27d31 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,6 +32,7 @@ tests = [ "gamma", "logabsgamma", "sincosint", + "lambertw", "other_tests", "chainrules" ] From 87b5509d3eed71ace5d2a3909fcef3044bd01bf2 Mon Sep 17 00:00:00 2001 From: John Lapeyre Date: Thu, 19 Apr 2018 01:37:39 +0200 Subject: [PATCH 02/24] made changes request in PR review --- LICENSE | 32 ++++++++ src/SpecialFunctions.jl | 7 ++ src/lambertw.jl | 170 +++++++++++++++++----------------------- test/lambertw.jl | 93 +++++++++------------- 4 files changed, 146 insertions(+), 156 deletions(-) diff --git a/LICENSE b/LICENSE index 28e07032..b9a9755c 100644 --- a/LICENSE +++ b/LICENSE @@ -22,3 +22,35 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +Portions of this code are derived from SciPy and are licensed under +the Scipy License: + +> Copyright (c) 2001, 2002 Enthought, Inc. +> All rights reserved. + +> Copyright (c) 2003-2012 SciPy Developers. +> All rights reserved. + +> Redistribution and use in source and binary forms, with or without +> modification, are permitted provided that the following conditions are met: + +> a. Redistributions of source code must retain the above copyright notice, +> this list of conditions and the following disclaimer. +> b. Redistributions in binary form must reproduce the above copyright +> notice, this list of conditions and the following disclaimer in the +> documentation and/or other materials provided with the distribution. +> c. Neither the name of Enthought nor the names of the SciPy Developers +> may be used to endorse or promote products derived from this software +> without specific prior written permission. +> +> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +> AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +> IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +> ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS +> BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +> OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +> SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +> INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +> CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +> ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +> THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 9c31524b..1090abb2 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -81,6 +81,13 @@ export lambertw, lambertwbp +const omega_const_bf_ = Ref{BigFloat}() +function __init__() + # allocate storage for this BigFloat constant each time this module is loaded + omega_const_bf_[] = + parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") +end + include("bessel.jl") include("erf.jl") include("ellip.jl") diff --git a/src/lambertw.jl b/src/lambertw.jl index b98f7aee..5fdd9e74 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -1,65 +1,47 @@ import Base: convert -#export lambertw, lambertwbp -using Compat - -const euler = -if isdefined(Base, :MathConstants) - Base.MathConstants.e -else - e -end -const omega_const_bf_ = Ref{BigFloat}() - -function __init__() - omega_const_bf_[] = - parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") -end +using Compat +import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number #### Lambert W function #### -const LAMBERTW_USE_NAN = false - -macro baddomain(v) - if LAMBERTW_USE_NAN - return :(return(NaN)) - else - return esc(:(throw(DomainError($v)))) - end -end - -# Use Halley's root-finding method to find x = lambertw(z) with -# initial point x. -function _lambertw(z::T, x::T) where T <: Number +# Use Halley's root-finding method to find +# x = lambertw(z) with initial point x. +function _lambertw(z::T, x::T, maxits) where T <: Number two_t = convert(T,2) lastx = x lastdiff = zero(T) - for i in 1:100 + converged::Bool = false + for i in 1:maxits ex = exp(x) xexz = x * ex - z x1 = x + 1 - x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) + x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) xdiff = abs(lastx - x) - xdiff <= 2*eps(abs(lastx)) && break - lastdiff == diff && break + if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle + converged = true + break + end lastx = x lastdiff = xdiff end - x + converged || warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.") + return x end ### Real z ### # Real x, k = 0 - +# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. # The fancy initial condition selection does not seem to help speed, but we leave it for now. -function lambertwk0(x::T)::T where T<:AbstractFloat - x == Inf && return Inf +function lambertwk0(x::T, maxits)::T where T<:AbstractFloat + isnan(x) && return(NaN) + x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf) one_t = one(T) - oneoe = -one_t/convert(T,euler) + oneoe = -one_t/convert(T,MathConstants.e) # The branch point x == oneoe && return -one_t + oneoe <= x || throw(DomainError(x)) itwo_t = 1/convert(T,2) - oneoe <= x || @baddomain(x) if x > one_t lx = log(x) llx = log(lx) @@ -67,28 +49,28 @@ function lambertwk0(x::T)::T where T<:AbstractFloat else x1 = (567//1000) * x end - _lambertw(x,x1) + return _lambertw(x, x1, maxits) end # Real x, k = -1 -function _lambertwkm1(x::T) where T<:Real - oneoe = -one(T)/convert(T,euler) - x == oneoe && return -one(T) - oneoe <= x || @baddomain(x) - x == zero(T) && return -convert(T,Inf) - x < zero(T) || @baddomain(x) - _lambertw(x,log(-x)) +function lambertwkm1(x::T, maxits) where T<:Real + oneoe = -one(T)/convert(T,MathConstants.e) + x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above + oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e + x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below + x < zero(T) || throw(DomainError(x)) + return _lambertw(x, log(-x), maxits) end - """ - lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer} - lambertw(z::T, k::V=0) where {T<:Real, V<:Integer} + lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer} + lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer} Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is -the complex plane. +the complex plane. When using root finding to compute `W`, a value for `W` is returned +with a warning if it has not converged after `maxits` iterations. ```jldoctest julia> lambertw(-1/e,-1) @@ -107,33 +89,31 @@ julia> lambertw(Complex(-10.0,3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` -!!! note - The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments - outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`. """ -function lambertw(x::Real, k::Integer) - k == 0 && return lambertwk0(x) - k == -1 && return _lambertwkm1(x) - @baddomain(k) # more informative message like below ? -# error("lambertw: real x must have k == 0 or k == -1") +lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits) + +function lambertw_(x::Real, k, maxits) + k == 0 && return lambertwk0(x, maxits) + k == -1 && return lambertwkm1(x, maxits) + throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) end -function lambertw(x::Union{Integer,Rational}, k::Integer) +function lambertw_(x::Union{Integer,Rational}, k, maxits) if k == 0 x == 0 && return float(zero(x)) - x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way + x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way end - lambertw(float(x),k) + return lambertw_(float(x), k, maxits) end ### Complex z ### # choose initial value inside correct branch for root finding -function lambertw(z::Complex{T}, k::Integer) where T<:Real +function lambertw_(z::Complex{T}, k, maxits) where T<:Real one_t = one(T) local w::Complex{T} pointseven = 7//10 - if abs(z) <= one_t/convert(T,euler) + if abs(z) <= one_t/convert(T,MathConstants.e) if z == 0 k == 0 && return z return complex(-convert(T,Inf),zero(T)) @@ -157,27 +137,19 @@ function lambertw(z::Complex{T}, k::Integer) where T<:Real w = log(z) k != 0 ? w += complex(0, 2*k*pi) : nothing end - _lambertw(z,w) + return _lambertw(z, w, maxits) end -lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k) +lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits) +lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits) # lambertw(e + 0im,k) is ok for all k -#function lambertw(::Irrational{:e}, k::T) where T<:Integer -function lambertw(::typeof(euler), k::T) where T<:Integer +# Maybe this should return a float. But, this should cause no type instability in any case +function lambertw_(::typeof(MathConstants.e), k, maxits) k == 0 && return 1 - @baddomain(k) + throw(DomainError(k)) end -# Maybe this should return a float -lambertw(::typeof(euler)) = 1 -#lambertw(::Irrational{:e}) = 1 - -#lambertw{T<:Number}(x::T) = lambertw(x,0) -lambertw(x::Number) = lambertw(x,0) - -lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...) - ### omega constant ### const omega_const_ = 0.567143290409783872999968662210355 @@ -185,7 +157,7 @@ const omega_const_ = 0.567143290409783872999968662210355 # maybe compute higher precision. converges very quickly function omega_const(::Type{BigFloat}) - @compat precision(BigFloat) <= 256 && return omega_const_bf_[] + precision(BigFloat) <= 256 && return omega_const_bf_[] myeps = eps(BigFloat) oc = omega_const_bf_[] for i in 1:100 @@ -200,7 +172,7 @@ end omega ω -A constant defined by `ω exp(ω) = 1`. +The constant defined by `ω exp(ω) = 1`. ```jldoctest julia> ω @@ -219,7 +191,7 @@ julia> big(omega) const ω = Irrational{:ω}() @doc (@doc ω) omega = ω -# The following three lines may be removed when support for v0.6 is dropped +# The following two lines may be removed when support for v0.6 is dropped Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) @@ -236,7 +208,7 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) # (4.23) and (4.24) for all μ are also given. This code implements the # recursion relations. -# (4.23) and (4.24) give zero based coefficients +# (4.23) and (4.24) give zero based coefficients. cset(a,i,v) = a[i+1] = v cget(a,i) = a[i+1] @@ -247,7 +219,7 @@ function compa(k,m,a) sum0 += cget(m,j) * cget(m,k+1-j) end cset(a,k,sum0) - sum0 + return sum0 end # (4.23) @@ -256,7 +228,7 @@ function compm(k,m,a) mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - cget(a,k)/2 - cget(m,k-1)/(kt+1) cset(m,k,mk) - mk + return mk end # We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and @@ -283,19 +255,21 @@ end const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) -function horner(x, p::AbstractArray,n) +# Base.Math.@horner requires literal coefficients +# But, we have an array `p` of computed coefficients +function horner(x, p::AbstractArray, n) n += 1 ex = p[n] for i = n-1:-1:2 - ex = :($(p[i]) + t * $ex) + ex = :(muladd(t, $ex, $(p[i]))) end ex = :( t * $ex) - Expr(:block, :(t = $x), ex) + return Expr(:block, :(t = $x), ex) end function mkwser(name, n) iex = horner(:x,LAMWMU_FLOAT64,n) - :(function ($name)(x) $iex end) + return :(function ($name)(x) $iex end) end eval(mkwser(:wser3, 3)) @@ -320,7 +294,7 @@ function wser(p,x) x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/euler && @baddomain(x) # radius of convergence + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) # good for x approx .32 end @@ -335,28 +309,28 @@ function wser(p::Complex{T},z) where T<:Real x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/euler && @baddomain(x) # radius of convergence + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) end @inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 - ps = 2*euler*x; + ps = 2*MathConstants.e*x; p = sqrt(ps) - wser(p,x) + return wser(p,x) end @inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 - ps = 2*euler*x; + ps = 2*MathConstants.e*x; p = -sqrt(ps) - wser(p,x) + return wser(p,x) end """ lambertwbp(z,k=0) -Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. -Accurate to Float64 precision for abs(z) < 0.32. -If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized. +Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +The result is accurate to Float64 precision for abs(z) < 0.32. +If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. ```jldoctest julia> lambertw(-1/e + 1e-18, -1) @@ -378,9 +352,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) function lambertwbp(x::Number,k::Integer) k == 0 && return _lambertw0(x) k == -1 && return _lambertwm1(x) - error("expansion about branch point only implemented for k = 0 and -1") + throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1.")) end lambertwbp(x::Number) = _lambertw0(x) - -nothing diff --git a/test/lambertw.jl b/test/lambertw.jl index 7614d9d2..46d5dd57 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -1,44 +1,31 @@ using Compat -macro test_baddomain(expr) - if SpecialFunctions.LAMBERTW_USE_NAN - :(@test $(esc(expr)) === NaN) - else - :(@test_throws DomainError $(esc(expr))) - end -end - -const euler = -if isdefined(Base, :MathConstants) - Base.MathConstants.e -else - e -end +import Compat.MathConstants ### domain errors -@test_baddomain lambertw(-2.0,0) -@test_baddomain lambertw(-2.0,-1) -@test_baddomain lambertw(-2.0,1) -@test_baddomain lambertw(NaN) +@test_throws DomainError lambertw(-2.0,0) +@test_throws DomainError lambertw(-2.0,-1) +@test_throws DomainError lambertw(-2.0,1) +@test isnan(lambertw(NaN)) ## math constant e -@test_baddomain lambertw(euler,1) -@test_baddomain lambertw(euler,-1) +@test_throws DomainError lambertw(MathConstants.e,1) +@test_throws DomainError lambertw(MathConstants.e,-1) ## integer arguments return floating point types @test typeof(lambertw(0)) <: AbstractFloat @test lambertw(0) == 0 -### math constant, euler e +### math constant, MathConstants.e e # could return math const e, but this would break type stability @test typeof(lambertw(1)) <: AbstractFloat -@test lambertw(euler,0) == 1 +@test lambertw(MathConstants.e,0) == 1 ## value at branch point where real branches meet -@test lambertw(-1/euler,0) == lambertw(-1/euler,-1) == -1 -@test typeof(lambertw(-1/euler,0)) == typeof(lambertw(-1/euler,-1)) <: AbstractFloat +@test lambertw(-1/MathConstants.e,0) == lambertw(-1/MathConstants.e,-1) == -1 +@test typeof(lambertw(-1/MathConstants.e,0)) == typeof(lambertw(-1/MathConstants.e,-1)) <: AbstractFloat ## convert irrationals to float @@ -94,9 +81,9 @@ end # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail if Int != Int32 - @test abs(lambertw(-1.0/euler + 0im,-1)) == 1 + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1)) == 1 else - @test abs(lambertw(-1.0/euler + 0im,-1) + 1) < 1e-7 + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1) + 1) < 1e-7 end # lambertw for BigFloat is more precise than Float64. Note # that 70 digits in test is about 35 digits in W @@ -111,9 +98,9 @@ end ## get ω from recursion and compare to value from lambertw let sp = precision(BigFloat) - @compat setprecision(512) + setprecision(512) @test lambertw(big(1)) == big(SpecialFunctions.omega) - @compat setprecision(sp) + setprecision(sp) end @test lambertw(1) == float(SpecialFunctions.omega) @@ -124,7 +111,7 @@ end ### expansion about branch point # not a domain error, but not implemented -@test_throws ErrorException lambertwbp(1,1) +@test_throws ArgumentError lambertwbp(1,1) @test_throws DomainError lambertw(.3,2) @@ -134,44 +121,36 @@ end if Int != Int32 -let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, diff - @compat setprecision(2048) +# Test double-precision expansion near branch point using BigFloats +let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff + setprecision(2048) for i in 1:300 - # k = 0 - @test (wo = lambertwbp(Float64(z)); diff = abs(-1 + wo - lambertw(z-1/big(euler))); true) - if diff > 5e-16 - println(Float64(z), " ", Float64(diff)) + innerarg = z-1/big(MathConstants.e) + + # branch k = 0 + @test (wo = lambertwbp(Float64(z)); xdiff = abs(-1 + wo - lambertw(innerarg)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) end - @test diff < 5e-16 - # k = -1 - @test (wo = lambertwbp(Float64(z),-1); diff = abs(-1 + wo - lambertw(z-1/big(euler),-1)); true) - if diff > 5e-16 - println(Float64(z), " ", Float64(diff)) + @test xdiff < 5e-16 + + # branch k = -1 + @test (wo = lambertwbp(Float64(z),-1); xdiff = abs(-1 + wo - lambertw(innerarg,-1)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) end - @test diff < 5e-16 + @test xdiff < 5e-16 z *= 1.1 if z > 0.23 break end + end - @compat setprecision(sp) + setprecision(sp) end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. -@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(euler)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 +@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(MathConstants.e)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 @test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 -end - -## vectorization - -if VERSION >= v"0.5" - @test lambertw.([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] - @test lambertw.([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] - @test lambertw.([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] - @test lambertwbp.([.1,.2,.3],-1) == map(x -> lambertwbp(x,-1), [.1,.2,.3]) -else - @test lambertw([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] - @test lambertw([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] - @test lambertw([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] -end +end # if Int != Int32 From 8bc181496d2a9362d129e838c799579591db8101 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 00:20:34 +0100 Subject: [PATCH 03/24] sync with latest LambertW.jl master --- src/lambertw.jl | 338 +++++++++++++++++++++++------------------------ test/lambertw.jl | 86 ++++++------ 2 files changed, 205 insertions(+), 219 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 5fdd9e74..0f76c668 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -1,161 +1,158 @@ -import Base: convert - -using Compat -import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number - #### Lambert W function #### -# Use Halley's root-finding method to find -# x = lambertw(z) with initial point x. -function _lambertw(z::T, x::T, maxits) where T <: Number - two_t = convert(T,2) - lastx = x - lastdiff = zero(T) - converged::Bool = false - for i in 1:maxits - ex = exp(x) - xexz = x * ex - z - x1 = x + 1 - x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) - xdiff = abs(lastx - x) - if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle - converged = true - break - end - lastx = x - lastdiff = xdiff - end - converged || warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.") - return x -end - -### Real z ### - -# Real x, k = 0 -# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. -# The fancy initial condition selection does not seem to help speed, but we leave it for now. -function lambertwk0(x::T, maxits)::T where T<:AbstractFloat - isnan(x) && return(NaN) - x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf) - one_t = one(T) - oneoe = -one_t/convert(T,MathConstants.e) # The branch point - x == oneoe && return -one_t - oneoe <= x || throw(DomainError(x)) - itwo_t = 1/convert(T,2) - if x > one_t - lx = log(x) - llx = log(lx) - x1 = lx - llx - log(one_t - llx/lx) * itwo_t - else - x1 = (567//1000) * x - end - return _lambertw(x, x1, maxits) -end - -# Real x, k = -1 -function lambertwkm1(x::T, maxits) where T<:Real - oneoe = -one(T)/convert(T,MathConstants.e) - x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above - oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e - x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below - x < zero(T) || throw(DomainError(x)) - return _lambertw(x, log(-x), maxits) -end - """ lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer} lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer} Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be -either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the -domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is -the complex plane. When using root finding to compute `W`, a value for `W` is returned -with a warning if it has not converged after `maxits` iterations. +either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e, 0]` and the +domain of the branch `k = 0` is `[-1/e, Inf]`. For `Complex` `z`, and all `k`, the domain is +the complex plane. ```jldoctest -julia> lambertw(-1/e,-1) +julia> lambertw(-1/e, -1) -1.0 -julia> lambertw(-1/e,0) +julia> lambertw(-1/e, 0) -1.0 -julia> lambertw(0,0) +julia> lambertw(0, 0) 0.0 -julia> lambertw(0,-1) +julia> lambertw(0, -1) -Inf -julia> lambertw(Complex(-10.0,3.0), 4) +julia> lambertw(Complex(-10.0, 3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` - """ -lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits) +lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits) -function lambertw_(x::Real, k, maxits) - k == 0 && return lambertwk0(x, maxits) - k == -1 && return lambertwkm1(x, maxits) - throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) +# lambertw(e + 0im, k) is ok for all k +# Maybe this should return a float. But, this should cause no type instability in any case +function _lambertw(::typeof(MathConstants.e), k, maxits) + k == 0 && return 1 + throw(DomainError(k)) end - -function lambertw_(x::Union{Integer,Rational}, k, maxits) +_lambertw(x::Irrational, k, maxits) = _lambertw(float(x), k, maxits) +function _lambertw(x::Union{Integer, Rational}, k, maxits) if k == 0 x == 0 && return float(zero(x)) x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way end - return lambertw_(float(x), k, maxits) + return _lambertw(float(x), k, maxits) end -### Complex z ### +### Real x +function _lambertw(x::Real, k, maxits) + k == 0 && return lambertw_branch_zero(x, maxits) + k == -1 && return lambertw_branch_one(x, maxits) + throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) +end + +# Real x, k = 0 +# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. +# There is a magic number here. It could be noted, or possibly removed. +# In particular, the fancy initial condition selection does not seem to help speed. +function lambertw_branch_zero(x::T, maxits)::T where T<:Real + isnan(x) && return(NaN) + x == Inf && return Inf # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf) + one_t = one(T) + oneoe = -one_t / convert(T, MathConstants.e) # The branch point + x == oneoe && return -one_t + oneoe <= x || throw(DomainError(x)) + itwo_t = 1 / convert(T, 2) + if x > one_t + lx = log(x) + llx = log(lx) + x0 = lx - llx - log(one_t - llx / lx) * itwo_t + else + x0 = (567//1000) * x + end + return lambertw_root_finding(x, x0, maxits) +end + +# Real x, k = -1 +function lambertw_branch_one(x::T, maxits) where T<:Real + oneoe = -one(T) / convert(T, MathConstants.e) + x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above + oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e + x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below + x < zero(T) || throw(DomainError(x)) + return lambertw_root_finding(x, log(-x), maxits) +end + +### Complex z + +_lambertw(z::Complex{<:Integer}, k, maxits) = _lambertw(float(z), k, maxits) # choose initial value inside correct branch for root finding -function lambertw_(z::Complex{T}, k, maxits) where T<:Real +function _lambertw(z::Complex{T}, k, maxits) where T<:Real one_t = one(T) local w::Complex{T} pointseven = 7//10 - if abs(z) <= one_t/convert(T,MathConstants.e) + if abs(z) <= one_t/convert(T, MathConstants.e) if z == 0 k == 0 && return z - return complex(-convert(T,Inf),zero(T)) + return complex(-convert(T, Inf), zero(T)) end if k == 0 w = z elseif k == -1 && imag(z) == 0 && real(z) < 0 - w = complex(log(-real(z)),1//10^7) # need offset for z ≈ -1/e. + w = complex(log(-real(z)), 1//10^7) # need offset for z ≈ -1/e. else w = log(z) - k != 0 ? w += complex(0,k * 2 * pi) : nothing + k != 0 ? w += complex(0, k * 2 * pi) : nothing end elseif k == 0 && imag(z) <= pointseven && abs(z) <= pointseven - w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven,pointseven) : complex(pointseven,-pointseven) : z + w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven, pointseven) : complex(pointseven, -pointseven) : z else - if real(z) == convert(T,Inf) + if real(z) == convert(T, Inf) k == 0 && return z - return z + complex(0,2*k*pi) + return z + complex(0, 2*k*pi) end - real(z) == -convert(T,Inf) && return -z + complex(0,(2*k+1)*pi) + real(z) == -convert(T, Inf) && return -z + complex(0, (2*k+1)*pi) w = log(z) k != 0 ? w += complex(0, 2*k*pi) : nothing end - return _lambertw(z, w, maxits) + return lambertw_root_finding(z, w, maxits) end -lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits) -lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits) +### root finding, iterative solution -# lambertw(e + 0im,k) is ok for all k -# Maybe this should return a float. But, this should cause no type instability in any case -function lambertw_(::typeof(MathConstants.e), k, maxits) - k == 0 && return 1 - throw(DomainError(k)) +# Use Halley's root-finding method to find +# x = lambertw(z) with initial point x0. +function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number + two_t = convert(T, 2) + x = x0 + lastx = x + lastdiff = zero(T) + converged::Bool = false + for i in 1:maxits + ex = exp(x) + xexz = x * ex - z + x1 = x + 1 + x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 )) + xdiff = abs(lastx - x) + if xdiff <= 3 * eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle + converged = true + break + end + lastx = x + lastdiff = xdiff + end + converged || @warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.") + return x end -### omega constant ### +### omega constant + +const _omega_const = 0.567143290409783872999968662210355 -const omega_const_ = 0.567143290409783872999968662210355 # The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault -# maybe compute higher precision. converges very quickly +# compute omega constant via root finding +# We could compute higher precision. This converges very quickly. function omega_const(::Type{BigFloat}) precision(BigFloat) <= 256 && return omega_const_bf_[] myeps = eps(BigFloat) @@ -174,6 +171,7 @@ end The constant defined by `ω exp(ω) = 1`. +# Example ```jldoctest julia> ω ω = 0.5671432904097... @@ -191,17 +189,12 @@ julia> big(omega) const ω = Irrational{:ω}() @doc (@doc ω) omega = ω -# The following two lines may be removed when support for v0.6 is dropped -Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) -Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) -Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) +Base.Float64(::Irrational{:ω}) = _omega_const +Base.Float32(::Irrational{:ω}) = Float32(_omega_const) +Base.Float16(::Irrational{:ω}) = Float16(_omega_const) +Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat) -Base.Float64(::Irrational{:ω}) = omega_const_ # FIXME: This is very slow. Why ? -Base.Float32(::Irrational{:ω}) = Float32(omega_const_) -Base.Float16(::Irrational{:ω}) = Float16(omega_const_) -Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) - -### Expansion about branch point x = -1/e ### +### Expansion about branch point x = -1/e # Refer to the paper "On the Lambert W function". In (4.22) # coefficients μ₀ through μ₃ are given explicitly. Recursion relations @@ -209,83 +202,78 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) # recursion relations. # (4.23) and (4.24) give zero based coefficients. -cset(a,i,v) = a[i+1] = v -cget(a,i) = a[i+1] +cset(a, i, v) = a[i+1] = v +cget(a, i) = a[i+1] # (4.24) -function compa(k,m,a) +function compute_a_coeffs(k, m, a) sum0 = zero(eltype(m)) - for j in 2:k-1 - sum0 += cget(m,j) * cget(m,k+1-j) + for j in 2:(k - 1) + sum0 += cget(m, j) * cget(m, k + 1 - j) end - cset(a,k,sum0) + cset(a, k, sum0) return sum0 end # (4.23) -function compm(k,m,a) - kt = convert(eltype(m),k) - mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - - cget(a,k)/2 - cget(m,k-1)/(kt+1) - cset(m,k,mk) +function compute_m_coefficients(k, m, a) + kt = convert(eltype(m), k) + mk = (kt - 1) / (kt + 1) *(cget(m, k - 2) / 2 + cget(a, k - 2) / 4) - + cget(a, k) / 2 - cget(m, k - 1) / (kt + 1) + cset(m, k, mk) return mk end # We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and # solve for α₂. We get α₂ = 0. -# compute array of coefficients μ in (4.22). +# Compute array of coefficients μ in (4.22). # m[1] is μ₀ -function lamwcoeff(T::DataType, n::Int) - # a = @compat Array{T}(undef,n) - # m = @compat Array{T}(undef,n) - a = zeros(T,n) # We don't need initialization, but Compat is a huge PITA. - m = zeros(T,n) - cset(a,0,2) # α₀ literal in paper - cset(a,1,-1) # α₁ literal in paper - cset(a,2,0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper - cset(m,0,-1) # μ₀ literal in paper - cset(m,1,1) # μ₁ literal in paper - cset(m,2,-1//3) # μ₂ literal in paper, but only in (4.22) - for i in 3:n-1 # coeffs are zero indexed - compa(i,m,a) - compm(i,m,a) +function compute_branch_point_coeffs(T::DataType, n::Int) + a = Array{T}(undef, n) + m = Array{T}(undef, n) + cset(a, 0, 2) # α₀ literal in paper + cset(a, 1, -1) # α₁ literal in paper + cset(a, 2, 0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper + cset(m, 0, -1) # μ₀ literal in paper + cset(m, 1, 1) # μ₁ literal in paper + cset(m, 2, -1//3) # μ₂ literal in paper, but only in (4.22) + for i in 3:(n - 1) # coeffs are zero indexed + compute_a_coeffs(i, m, a) + compute_m_coefficients(i, m, a) end return m end -const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) +const BRANCH_POINT_COEFFS_FLOAT64 = compute_branch_point_coeffs(Float64, 500) # Base.Math.@horner requires literal coefficients -# But, we have an array `p` of computed coefficients -function horner(x, p::AbstractArray, n) +# It cannot be used here because we have an array of computed coefficients +function horner(x, coeffs::AbstractArray, n) n += 1 - ex = p[n] - for i = n-1:-1:2 - ex = :(muladd(t, $ex, $(p[i]))) + ex = coeffs[n] + for i = (n - 1):-1:2 + ex = :(muladd(t, $ex, $(coeffs[i]))) end ex = :( t * $ex) return Expr(:block, :(t = $x), ex) end -function mkwser(name, n) - iex = horner(:x,LAMWMU_FLOAT64,n) - return :(function ($name)(x) $iex end) +# write functions that evaluate the branch point series +# with `num_terms` number of terms. +for (func_name, num_terms) in ( + (:wser3, 3), (:wser5, 5), (:wser7, 7), (:wser12, 12), + (:wser19, 19), (:wser26, 26), (:wser32, 32), + (:wser50, 50), (:wser100, 100), (:wser290, 290)) + iex = horner(:x, BRANCH_POINT_COEFFS_FLOAT64, num_terms) + @eval function ($func_name)(x) $iex end end -eval(mkwser(:wser3, 3)) -eval(mkwser(:wser5, 5)) -eval(mkwser(:wser7, 7)) -eval(mkwser(:wser12, 12)) -eval(mkwser(:wser19, 19)) -eval(mkwser(:wser26, 26)) -eval(mkwser(:wser32, 32)) -eval(mkwser(:wser50, 50)) -eval(mkwser(:wser100, 100)) -eval(mkwser(:wser290, 290)) - # Converges to Float64 precision -# We could get finer tuning by separating k=0,-1 branches. -function wser(p,x) +# We could get finer tuning by separating k=0, -1 branches. +# Why is wser5 omitted ? +# p is the argument to the series which is computed +# from x before calling `branch_point_series`. +function branch_point_series(p, x) x < 4e-11 && return wser3(p) x < 1e-5 && return wser7(p) x < 1e-3 && return wser12(p) @@ -294,12 +282,12 @@ function wser(p,x) x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence + x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) # good for x approx .32 end # These may need tuning. -function wser(p::Complex{T},z) where T<:Real +function branch_point_series(p::Complex{T}, z) where T<:Real x = abs(z) x < 4e-11 && return wser3(p) x < 1e-5 && return wser7(p) @@ -309,29 +297,31 @@ function wser(p::Complex{T},z) where T<:Real x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence + x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) end -@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 - ps = 2*MathConstants.e*x; - p = sqrt(ps) - return wser(p,x) +function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 + ps = 2 * MathConstants.e * x + series_arg = sqrt(ps) + branch_point_series(series_arg, x) end -@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 - ps = 2*MathConstants.e*x; - p = -sqrt(ps) - return wser(p,x) +function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 + ps = 2 * MathConstants.e * x + series_arg = -sqrt(ps) + branch_point_series(series_arg, x) end """ - lambertwbp(z,k=0) + lambertwbp(z, k=0) -Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0, 1/e]` for `k` either `0` or `-1`. +This function is faster and more accurate near the branch point `-1/e` between `k=0` and `k=1`. The result is accurate to Float64 precision for abs(z) < 0.32. If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. +# Example ```jldoctest julia> lambertw(-1/e + 1e-18, -1) -1.0 @@ -340,7 +330,7 @@ julia> lambertwbp(1e-18, -1) -2.331643983409312e-9 # Same result, but 1000 times slower -julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) +julia> convert(Float64, (lambertw(-BigFloat(1)/e + BigFloat(10)^(-18), -1) + 1)) -2.331643983409312e-9 ``` @@ -349,7 +339,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) The loss of precision in `lambertw` is analogous to the loss of precision in computing the `sqrt(1-x)` for `x` close to `1`. """ -function lambertwbp(x::Number,k::Integer) +function lambertwbp(x::Number, k::Integer) k == 0 && return _lambertw0(x) k == -1 && return _lambertwm1(x) throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1.")) diff --git a/test/lambertw.jl b/test/lambertw.jl index 46d5dd57..8a761645 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -1,94 +1,90 @@ -using Compat - -import Compat.MathConstants - ### domain errors -@test_throws DomainError lambertw(-2.0,0) -@test_throws DomainError lambertw(-2.0,-1) -@test_throws DomainError lambertw(-2.0,1) +@test_throws DomainError lambertw(-2.0, 0) +@test_throws DomainError lambertw(-2.0, -1) +@test_throws DomainError lambertw(-2.0, 1) @test isnan(lambertw(NaN)) ## math constant e -@test_throws DomainError lambertw(MathConstants.e,1) -@test_throws DomainError lambertw(MathConstants.e,-1) +@test_throws DomainError lambertw(MathConstants.e, 1) +@test_throws DomainError lambertw(MathConstants.e, -1) ## integer arguments return floating point types -@test typeof(lambertw(0)) <: AbstractFloat +@test lambertw(0) isa AbstractFloat @test lambertw(0) == 0 ### math constant, MathConstants.e e # could return math const e, but this would break type stability -@test typeof(lambertw(1)) <: AbstractFloat -@test lambertw(MathConstants.e,0) == 1 +@test lambertw(1) isa AbstractFloat +@test lambertw(MathConstants.e, 0) == 1 ## value at branch point where real branches meet -@test lambertw(-1/MathConstants.e,0) == lambertw(-1/MathConstants.e,-1) == -1 -@test typeof(lambertw(-1/MathConstants.e,0)) == typeof(lambertw(-1/MathConstants.e,-1)) <: AbstractFloat +@test lambertw(-inv(MathConstants.e), 0) == lambertw(-inv(MathConstants.e), -1) == -1 +@test typeof(lambertw(-inv(MathConstants.e), 0)) == typeof(lambertw(-inv(MathConstants.e), -1)) <: AbstractFloat ## convert irrationals to float @test isapprox(lambertw(pi), 1.0736581947961492) -@test isapprox(lambertw(pi,0), 1.0736581947961492) +@test isapprox(lambertw(pi, 0), 1.0736581947961492) ### infinite args or return values -@test lambertw(0,-1) == lambertw(0.0,-1) == -Inf -@test lambertw(Inf,0) == Inf -@test lambertw(complex(Inf,1),0) == complex(Inf,1) -@test lambertw(complex(Inf,0),1) == complex(Inf,2pi) -@test lambertw(complex(-Inf,0),1) == complex(Inf,3pi) -@test lambertw(complex(0.0,0.0),-1) == complex(-Inf,0.0) +@test lambertw(0, -1) == lambertw(0.0, -1) == -Inf +@test lambertw(Inf, 0) == Inf +@test lambertw(complex(Inf, 1), 0) == complex(Inf, 1) +@test lambertw(complex(Inf, 0), 1) == complex(Inf, 2pi) +@test lambertw(complex(-Inf, 0), 1) == complex(Inf, 3pi) +@test lambertw(complex(0.0, 0.0), -1) == complex(-Inf, 0.0) ## default branch is k = 0 -@test lambertw(1.0) == lambertw(1.0,0) +@test lambertw(1.0) == lambertw(1.0, 0) ## BigInt args return BigFloats @test typeof(lambertw(BigInt(0))) == BigFloat @test typeof(lambertw(BigInt(3))) == BigFloat ## Any Integer type allowed for second argument -@test lambertw(-0.2,-1) == lambertw(-0.2,BigInt(-1)) +@test lambertw(-0.2, -1) == lambertw(-0.2, BigInt(-1)) ## BigInt for second arg does not promote the type -@test typeof(lambertw(-0.2,-1)) == typeof(lambertw(-0.2,BigInt(-1))) +@test typeof(lambertw(-0.2, -1)) == typeof(lambertw(-0.2, BigInt(-1))) -for (z,k,res) in [ (0,0 ,0), (complex(0,0),0 ,0), - (complex(0.0,0),0 ,0), (complex(1.0,0),0, 0.567143290409783873) ] +for (z, k, res) in [(0, 0 , 0), (complex(0, 0), 0 , 0), + (complex(0.0, 0), 0 , 0), (complex(1.0, 0), 0, 0.567143290409783873)] if Int != Int32 - @test isapprox(lambertw(z,k), res) + @test isapprox(lambertw(z, k), res) @test isapprox(lambertw(z), res) else - @test isapprox(lambertw(z,k), res; rtol = 1e-14) + @test isapprox(lambertw(z, k), res; rtol = 1e-14) @test isapprox(lambertw(z), res; rtol = 1e-14) end end -for (z,k) in ((complex(1,1),2), (complex(1,1),0),(complex(.6,.6),0), - (complex(.6,-.6),0)) +for (z, k) in ((complex(1, 1), 2), (complex(1, 1), 0), (complex(.6, .6), 0), + (complex(.6, -.6), 0)) let w - @test (w = lambertw(z,k) ; true) + @test (w = lambertw(z, k) ; true) @test abs(w*exp(w) - z) < 1e-15 end end -@test abs(lambertw(complex(-3.0,-4.0),0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 -@test abs(lambertw(complex(-3.0,-4.0),1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 -@test (lambertw(complex(.3,.3),0); true) +@test abs(lambertw(complex(-3.0, -4.0), 0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 +@test abs(lambertw(complex(-3.0, -4.0), 1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 +@test (lambertw(complex(.3, .3), 0); true) # bug fix # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail if Int != Int32 - @test abs(lambertw(-1.0/MathConstants.e + 0im,-1)) == 1 + @test abs(lambertw(-inv(MathConstants.e) + 0im, -1)) == 1 else - @test abs(lambertw(-1.0/MathConstants.e + 0im,-1) + 1) < 1e-7 + @test abs(lambertw(-inv(MathConstants.e) + 0im, -1) + 1) < 1e-7 end # lambertw for BigFloat is more precise than Float64. Note # that 70 digits in test is about 35 digits in W let W - for z in [ BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] + for z in [BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] @test (W = lambertw(z); true) @test abs(z - W * exp(W)) < BigFloat(1)^(-70) end @@ -104,16 +100,16 @@ let sp = precision(BigFloat) end @test lambertw(1) == float(SpecialFunctions.omega) -@test convert(Float16,SpecialFunctions.omega) == convert(Float16,0.5674) -@test convert(Float32,SpecialFunctions.omega) == 0.56714326f0 +@test convert(Float16, SpecialFunctions.omega) == convert(Float16, 0.5674) +@test convert(Float32, SpecialFunctions.omega) == 0.56714326f0 @test lambertw(BigInt(1)) == big(SpecialFunctions.omega) ### expansion about branch point # not a domain error, but not implemented -@test_throws ArgumentError lambertwbp(1,1) +@test_throws ArgumentError lambertwbp(1, 1) -@test_throws DomainError lambertw(.3,2) +@test_throws DomainError lambertw(.3, 2) # Expansions about branch point converges almost to machine precision # except near the radius of convergence. @@ -125,7 +121,7 @@ if Int != Int32 let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff setprecision(2048) for i in 1:300 - innerarg = z-1/big(MathConstants.e) + innerarg = z - inv(big(MathConstants.e)) # branch k = 0 @test (wo = lambertwbp(Float64(z)); xdiff = abs(-1 + wo - lambertw(innerarg)); true) @@ -135,7 +131,7 @@ let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff @test xdiff < 5e-16 # branch k = -1 - @test (wo = lambertwbp(Float64(z),-1); xdiff = abs(-1 + wo - lambertw(innerarg,-1)); true) + @test (wo = lambertwbp(Float64(z), -1); xdiff = abs(-1 + wo - lambertw(innerarg, -1)); true) if xdiff > 5e-16 println(Float64(z), " ", Float64(xdiff)) end @@ -149,8 +145,8 @@ end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. -@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(MathConstants.e)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 +@test lambertwbp(1e-20, -1) - 1 - lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^BigFloat(-20), -1) < 1e-16 -@test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 +@test abs(lambertwbp(Complex(.01, .01), -1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 end # if Int != Int32 From 0f6bce336e01aa33810f26358f636fd1776dcf3c Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 04:35:38 +0100 Subject: [PATCH 04/24] lambertw tests: check correct type inference --- test/lambertw.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index 8a761645..d638f594 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -3,21 +3,21 @@ @test_throws DomainError lambertw(-2.0, 0) @test_throws DomainError lambertw(-2.0, -1) @test_throws DomainError lambertw(-2.0, 1) -@test isnan(lambertw(NaN)) +@test isnan(@inferred(lambertw(NaN))) ## math constant e @test_throws DomainError lambertw(MathConstants.e, 1) @test_throws DomainError lambertw(MathConstants.e, -1) ## integer arguments return floating point types -@test lambertw(0) isa AbstractFloat -@test lambertw(0) == 0 +@test @inferred(lambertw(0)) isa AbstractFloat +@test @inferred(lambertw(0)) == 0 ### math constant, MathConstants.e e # could return math const e, but this would break type stability -@test lambertw(1) isa AbstractFloat -@test lambertw(MathConstants.e, 0) == 1 +@test @inferred(lambertw(1)) isa AbstractFloat +@test @inferred(lambertw(MathConstants.e, 0)) == 1 ## value at branch point where real branches meet @test lambertw(-inv(MathConstants.e), 0) == lambertw(-inv(MathConstants.e), -1) == -1 @@ -25,24 +25,24 @@ ## convert irrationals to float -@test isapprox(lambertw(pi), 1.0736581947961492) -@test isapprox(lambertw(pi, 0), 1.0736581947961492) +@test isapprox(@inferred(lambertw(pi)), 1.0736581947961492) +@test isapprox(@inferred(lambertw(pi, 0)), 1.0736581947961492) ### infinite args or return values @test lambertw(0, -1) == lambertw(0.0, -1) == -Inf @test lambertw(Inf, 0) == Inf -@test lambertw(complex(Inf, 1), 0) == complex(Inf, 1) +@test @inferred(lambertw(complex(Inf, 1), 0)) == complex(Inf, 1) @test lambertw(complex(Inf, 0), 1) == complex(Inf, 2pi) @test lambertw(complex(-Inf, 0), 1) == complex(Inf, 3pi) -@test lambertw(complex(0.0, 0.0), -1) == complex(-Inf, 0.0) +@test @inferred(lambertw(complex(0.0, 0.0), -1)) == complex(-Inf, 0.0) ## default branch is k = 0 @test lambertw(1.0) == lambertw(1.0, 0) ## BigInt args return BigFloats -@test typeof(lambertw(BigInt(0))) == BigFloat -@test typeof(lambertw(BigInt(3))) == BigFloat +@test @inferred(lambertw(BigInt(0))) isa BigFloat +@test @inferred(lambertw(BigInt(3))) isa BigFloat ## Any Integer type allowed for second argument @test lambertw(-0.2, -1) == lambertw(-0.2, BigInt(-1)) @@ -145,8 +145,8 @@ end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. -@test lambertwbp(1e-20, -1) - 1 - lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^BigFloat(-20), -1) < 1e-16 +@test @inferred(lambertwbp(1e-20, -1)) - 1 - lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^BigFloat(-20), -1) < 1e-16 -@test abs(lambertwbp(Complex(.01, .01), -1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 +@test abs(@inferred(lambertwbp(Complex(.01, .01), -1)) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 end # if Int != Int32 From 46f5b545e3654fe569103bae159f8b61387c4833 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 04:44:10 +0100 Subject: [PATCH 05/24] lambertw tests: correct approx equality tests --- test/lambertw.jl | 79 +++++++++++++++++++----------------------------- 1 file changed, 31 insertions(+), 48 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index d638f594..9ba2e8e0 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -25,8 +25,8 @@ ## convert irrationals to float -@test isapprox(@inferred(lambertw(pi)), 1.0736581947961492) -@test isapprox(@inferred(lambertw(pi, 0)), 1.0736581947961492) +@test @inferred(lambertw(pi)) ≈ 1.0736581947961492 +@test @inferred(lambertw(pi, 0)) ≈ 1.0736581947961492 ### infinite args or return values @@ -53,41 +53,36 @@ for (z, k, res) in [(0, 0 , 0), (complex(0, 0), 0 , 0), (complex(0.0, 0), 0 , 0), (complex(1.0, 0), 0, 0.567143290409783873)] if Int != Int32 - @test isapprox(lambertw(z, k), res) - @test isapprox(lambertw(z), res) + @test lambertw(z, k) ≈ res + @test lambertw(z) ≈ res else - @test isapprox(lambertw(z, k), res; rtol = 1e-14) - @test isapprox(lambertw(z), res; rtol = 1e-14) + @test lambertw(z, k) ≈ res rtol=1e-14 + @test lambertw(z) ≈ res rtol=1e-14 end end -for (z, k) in ((complex(1, 1), 2), (complex(1, 1), 0), (complex(.6, .6), 0), - (complex(.6, -.6), 0)) - let w - @test (w = lambertw(z, k) ; true) - @test abs(w*exp(w) - z) < 1e-15 - end +@testset "complex z=$z, k=$k" for (z, k) in + ((complex(1, 1), 2), (complex(1, 1), 0), (complex(.6, .6), 0), + (complex(.6, -.6), 0)) + w = lambertw(z, k) + @test w*exp(w) ≈ z atol=1e-15 end -@test abs(lambertw(complex(-3.0, -4.0), 0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 -@test abs(lambertw(complex(-3.0, -4.0), 1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 -@test (lambertw(complex(.3, .3), 0); true) +@test lambertw(complex(-3.0, -4.0), 0) ≈ Complex(1.075073066569255, -1.3251023817343588) atol=1e-14 +@test lambertw(complex(-3.0, -4.0), 1) ≈ Complex(0.5887666813694675, 2.7118802109452247) atol=1e-14 +@test lambertw(complex(.3, .3)) ≈ Complex(0.26763519642648767, 0.1837481231767825) # bug fix # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail -if Int != Int32 - @test abs(lambertw(-inv(MathConstants.e) + 0im, -1)) == 1 -else - @test abs(lambertw(-inv(MathConstants.e) + 0im, -1) + 1) < 1e-7 -end +@test lambertw(-inv(MathConstants.e) + 0im, -1) ≈ -1 atol=1e-7 + # lambertw for BigFloat is more precise than Float64. Note # that 70 digits in test is about 35 digits in W -let W - for z in [BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] - @test (W = lambertw(z); true) - @test abs(z - W * exp(W)) < BigFloat(1)^(-70) - end +@testset "lambertw() for BigFloat z=$z" for z in + [BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] + W = lambertw(z) + @test z ≈ W*exp(W) atol=BigFloat(10)^(-70) end ### ω constant @@ -117,36 +112,24 @@ end if Int != Int32 -# Test double-precision expansion near branch point using BigFloats -let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff - setprecision(2048) - for i in 1:300 - innerarg = z - inv(big(MathConstants.e)) +@testset "double-precision expansion near branch point using BigFloats" begin + setprecision(2048) do + z = BigFloat(10)^(-12) + for _ in 1:300 + innerarg = z - inv(big(MathConstants.e)) - # branch k = 0 - @test (wo = lambertwbp(Float64(z)); xdiff = abs(-1 + wo - lambertw(innerarg)); true) - if xdiff > 5e-16 - println(Float64(z), " ", Float64(xdiff)) - end - @test xdiff < 5e-16 + @test lambertwbp(Float64(z)) ≈ 1 + lambertw(innerarg) atol=5e-16 + @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(innerarg, -1) atol=5e-16 + z *= 1.1 + if z > 0.23 break end - # branch k = -1 - @test (wo = lambertwbp(Float64(z), -1); xdiff = abs(-1 + wo - lambertw(innerarg, -1)); true) - if xdiff > 5e-16 - println(Float64(z), " ", Float64(xdiff)) end - @test xdiff < 5e-16 - z *= 1.1 - if z > 0.23 break end - end - setprecision(sp) end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. -@test @inferred(lambertwbp(1e-20, -1)) - 1 - lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^BigFloat(-20), -1) < 1e-16 - -@test abs(@inferred(lambertwbp(Complex(.01, .01), -1)) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 +@test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^(-20), -1) atol=1e-16 +@test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.2755038208041206, -0.1277888928494641) atol=1e-14 end # if Int != Int32 From 0894aade4a9b80d9d03a8238a5ba88720fdc6259 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 04:47:02 +0100 Subject: [PATCH 06/24] lambertwbp: simplify series coeffs code --- src/lambertw.jl | 56 +++++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 0f76c668..3afb9bb5 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -201,45 +201,33 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat) # (4.23) and (4.24) for all μ are also given. This code implements the # recursion relations. -# (4.23) and (4.24) give zero based coefficients. -cset(a, i, v) = a[i+1] = v -cget(a, i) = a[i+1] - -# (4.24) -function compute_a_coeffs(k, m, a) - sum0 = zero(eltype(m)) - for j in 2:(k - 1) - sum0 += cget(m, j) * cget(m, k + 1 - j) - end - cset(a, k, sum0) - return sum0 -end - -# (4.23) -function compute_m_coefficients(k, m, a) - kt = convert(eltype(m), k) - mk = (kt - 1) / (kt + 1) *(cget(m, k - 2) / 2 + cget(a, k - 2) / 4) - - cget(a, k) / 2 - cget(m, k - 1) / (kt + 1) - cset(m, k, mk) - return mk -end - # We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and # solve for α₂. We get α₂ = 0. # Compute array of coefficients μ in (4.22). # m[1] is μ₀ function compute_branch_point_coeffs(T::DataType, n::Int) - a = Array{T}(undef, n) - m = Array{T}(undef, n) - cset(a, 0, 2) # α₀ literal in paper - cset(a, 1, -1) # α₁ literal in paper - cset(a, 2, 0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper - cset(m, 0, -1) # μ₀ literal in paper - cset(m, 1, 1) # μ₁ literal in paper - cset(m, 2, -1//3) # μ₂ literal in paper, but only in (4.22) - for i in 3:(n - 1) # coeffs are zero indexed - compute_a_coeffs(i, m, a) - compute_m_coefficients(i, m, a) + a = Vector{T}(undef, n) + m = Vector{T}(undef, n) + + a[1] = 2 # α₀ literal in paper + a[2] = -1 # α₁ literal in paper + a[3] = 0 # α₂ get this by solving (4.23) for alpha_2 with values printed in paper + m[1] = -1 # μ₀ literal in paper + m[2] = 1 # μ₁ literal in paper + m[3] = -1//3 # μ₂ literal in paper, but only in (4.22) + + for i in 4:n + # (4.24) + msum = zero(T) + @inbounds for j in 2:(i - 2) + msum += m[j + 1] * m[i + 1 - j] + end + a[i] = msum + + # (4.23) + it = convert(T, i) + m[i] = (it - 2) / it *(m[i - 2] / 2 + a[i - 2] / 4) - + a[i] / 2 - m[i - 1] / it end return m end From 59f1f844af4abe0d1a51a71c876d55facdd52ffb Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:46:46 +0100 Subject: [PATCH 07/24] lambertw_root_find: fix diff type --- src/lambertw.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 3afb9bb5..2794b613 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -126,15 +126,15 @@ function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number two_t = convert(T, 2) x = x0 lastx = x - lastdiff = zero(T) - converged::Bool = false + lastdiff = zero(real(T)) + converged = false for i in 1:maxits ex = exp(x) xexz = x * ex - z x1 = x + 1 x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 )) xdiff = abs(lastx - x) - if xdiff <= 3 * eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle + if xdiff <= 3 * eps(lastdiff) || lastdiff == xdiff # second condition catches two-value cycle converged = true break end From edf1820d7be0d83a8247a78fd5de5adf96384ea0 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:35:45 +0100 Subject: [PATCH 08/24] lambertw: simplify handling non-finite vals --- src/lambertw.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 2794b613..bb0e58c1 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -52,16 +52,14 @@ function _lambertw(x::Real, k, maxits) end # Real x, k = 0 -# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. # There is a magic number here. It could be noted, or possibly removed. # In particular, the fancy initial condition selection does not seem to help speed. -function lambertw_branch_zero(x::T, maxits)::T where T<:Real - isnan(x) && return(NaN) - x == Inf && return Inf # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf) +function lambertw_branch_zero(x::T, maxits) where T<:Real + isfinite(x) || return x one_t = one(T) oneoe = -one_t / convert(T, MathConstants.e) # The branch point x == oneoe && return -one_t - oneoe <= x || throw(DomainError(x)) + oneoe < x || throw(DomainError(x)) itwo_t = 1 / convert(T, 2) if x > one_t lx = log(x) @@ -77,7 +75,7 @@ end function lambertw_branch_one(x::T, maxits) where T<:Real oneoe = -one(T) / convert(T, MathConstants.e) x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above - oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e + oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below x < zero(T) || throw(DomainError(x)) return lambertw_root_finding(x, log(-x), maxits) From 83d79cacc3c20a2f1e8baba8a78abdbf51d9a8fc Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:38:56 +0100 Subject: [PATCH 09/24] lambertw: use inv(e) --- src/lambertw.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index bb0e58c1..321e2376 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -57,7 +57,7 @@ end function lambertw_branch_zero(x::T, maxits) where T<:Real isfinite(x) || return x one_t = one(T) - oneoe = -one_t / convert(T, MathConstants.e) # The branch point + oneoe = -inv(convert(T, MathConstants.e)) # The branch point x == oneoe && return -one_t oneoe < x || throw(DomainError(x)) itwo_t = 1 / convert(T, 2) @@ -73,7 +73,7 @@ end # Real x, k = -1 function lambertw_branch_one(x::T, maxits) where T<:Real - oneoe = -one(T) / convert(T, MathConstants.e) + oneoe = -inv(convert(T, MathConstants.e)) x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below @@ -89,7 +89,7 @@ function _lambertw(z::Complex{T}, k, maxits) where T<:Real one_t = one(T) local w::Complex{T} pointseven = 7//10 - if abs(z) <= one_t/convert(T, MathConstants.e) + if abs(z) <= inv(convert(T, MathConstants.e)) if z == 0 k == 0 && return z return complex(-convert(T, Inf), zero(T)) From 87f97493a5939254cd28307017faded0732de26b Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:45:35 +0100 Subject: [PATCH 10/24] lambertw: annotate function args with types --- src/lambertw.jl | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 321e2376..08982f52 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -1,8 +1,7 @@ #### Lambert W function #### """ - lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer} - lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer} + lambertw(z::Number, k::Integer=0, maxits=1000) Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e, 0]` and the @@ -26,16 +25,16 @@ julia> lambertw(Complex(-10.0, 3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` """ -lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits) +lambertw(z::Number, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits) # lambertw(e + 0im, k) is ok for all k # Maybe this should return a float. But, this should cause no type instability in any case -function _lambertw(::typeof(MathConstants.e), k, maxits) +function _lambertw(::typeof(MathConstants.e), k::Integer, maxits::Integer) k == 0 && return 1 throw(DomainError(k)) end -_lambertw(x::Irrational, k, maxits) = _lambertw(float(x), k, maxits) -function _lambertw(x::Union{Integer, Rational}, k, maxits) +_lambertw(x::Irrational, k::Integer, maxits::Integer) = _lambertw(float(x), k, maxits) +function _lambertw(x::Union{Integer, Rational}, k::Integer, maxits::Integer) if k == 0 x == 0 && return float(zero(x)) x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way @@ -45,7 +44,7 @@ end ### Real x -function _lambertw(x::Real, k, maxits) +function _lambertw(x::Real, k::Integer, maxits::Integer) k == 0 && return lambertw_branch_zero(x, maxits) k == -1 && return lambertw_branch_one(x, maxits) throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) @@ -54,7 +53,7 @@ end # Real x, k = 0 # There is a magic number here. It could be noted, or possibly removed. # In particular, the fancy initial condition selection does not seem to help speed. -function lambertw_branch_zero(x::T, maxits) where T<:Real +function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real isfinite(x) || return x one_t = one(T) oneoe = -inv(convert(T, MathConstants.e)) # The branch point @@ -72,7 +71,7 @@ function lambertw_branch_zero(x::T, maxits) where T<:Real end # Real x, k = -1 -function lambertw_branch_one(x::T, maxits) where T<:Real +function lambertw_branch_one(x::T, maxits::Integer) where T<:Real oneoe = -inv(convert(T, MathConstants.e)) x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e @@ -83,10 +82,9 @@ end ### Complex z -_lambertw(z::Complex{<:Integer}, k, maxits) = _lambertw(float(z), k, maxits) +_lambertw(z::Complex{<:Integer}, k::Integer, maxits::Integer) = _lambertw(float(z), k, maxits) # choose initial value inside correct branch for root finding -function _lambertw(z::Complex{T}, k, maxits) where T<:Real - one_t = one(T) +function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real local w::Complex{T} pointseven = 7//10 if abs(z) <= inv(convert(T, MathConstants.e)) @@ -120,7 +118,7 @@ end # Use Halley's root-finding method to find # x = lambertw(z) with initial point x0. -function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number +function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number two_t = convert(T, 2) x = x0 lastx = x @@ -203,7 +201,7 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat) # solve for α₂. We get α₂ = 0. # Compute array of coefficients μ in (4.22). # m[1] is μ₀ -function compute_branch_point_coeffs(T::DataType, n::Int) +function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer) a = Vector{T}(undef, n) m = Vector{T}(undef, n) @@ -259,7 +257,7 @@ end # Why is wser5 omitted ? # p is the argument to the series which is computed # from x before calling `branch_point_series`. -function branch_point_series(p, x) +function branch_point_series(p::Real, x::Real) x < 4e-11 && return wser3(p) x < 1e-5 && return wser7(p) x < 1e-3 && return wser12(p) @@ -273,7 +271,7 @@ function branch_point_series(p, x) end # These may need tuning. -function branch_point_series(p::Complex{T}, z) where T<:Real +function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real x = abs(z) x < 4e-11 && return wser3(p) x < 1e-5 && return wser7(p) @@ -287,13 +285,13 @@ function branch_point_series(p::Complex{T}, z) where T<:Real return wser290(p) end -function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 +function _lambertw0(x::Number) # 1 + W(-1/e + x) , k = 0 ps = 2 * MathConstants.e * x series_arg = sqrt(ps) branch_point_series(series_arg, x) end -function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 +function _lambertwm1(x::Number) # 1 + W(-1/e + x) , k = -1 ps = 2 * MathConstants.e * x series_arg = -sqrt(ps) branch_point_series(series_arg, x) From d39470c0197ea85dc76d2b285c912aa0966ce04a Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:51:48 +0100 Subject: [PATCH 11/24] lambertw: use dispatch to compute diff. branches --- src/lambertw.jl | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 08982f52..02aceabe 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -44,16 +44,12 @@ end ### Real x -function _lambertw(x::Real, k::Integer, maxits::Integer) - k == 0 && return lambertw_branch_zero(x, maxits) - k == -1 && return lambertw_branch_one(x, maxits) - throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) -end +_lambertw(x::Real, k::Integer, maxits::Integer) = _lambertw(x, Val(Int(k)), maxits) # Real x, k = 0 # There is a magic number here. It could be noted, or possibly removed. # In particular, the fancy initial condition selection does not seem to help speed. -function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real +function _lambertw(x::T, ::Val{0}, maxits::Integer) where T<:Real isfinite(x) || return x one_t = one(T) oneoe = -inv(convert(T, MathConstants.e)) # The branch point @@ -71,7 +67,7 @@ function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real end # Real x, k = -1 -function lambertw_branch_one(x::T, maxits::Integer) where T<:Real +function _lambertw(x::T, ::Val{-1}, maxits::Integer) where T<:Real oneoe = -inv(convert(T, MathConstants.e)) x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e @@ -80,6 +76,9 @@ function lambertw_branch_one(x::T, maxits::Integer) where T<:Real return lambertw_root_finding(x, log(-x), maxits) end +_lambertw(x::Real, k::Val, maxits::Integer) = + throw(DomainError(x, "lambertw: for branch k=$k not defined, real x must have branch k == 0 or k == -1")) + ### Complex z _lambertw(z::Complex{<:Integer}, k::Integer, maxits::Integer) = _lambertw(float(z), k, maxits) @@ -285,17 +284,14 @@ function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real return wser290(p) end -function _lambertw0(x::Number) # 1 + W(-1/e + x) , k = 0 - ps = 2 * MathConstants.e * x - series_arg = sqrt(ps) - branch_point_series(series_arg, x) -end +_lambertwbp(x::Number, ::Val{0}) = + branch_point_series(sqrt(2 * MathConstants.e * x), x) -function _lambertwm1(x::Number) # 1 + W(-1/e + x) , k = -1 - ps = 2 * MathConstants.e * x - series_arg = -sqrt(ps) - branch_point_series(series_arg, x) -end +_lambertwbp(x::Number, ::Val{-1}) = + branch_point_series(-sqrt(2 * MathConstants.e * x), x) + +_lambertwbp(_::Number, k::Val) = + throw(ArgumentError("lambertw() expansion about branch point for k=$k not implemented (only implemented for 0 and -1).")) """ lambertwbp(z, k=0) @@ -323,10 +319,4 @@ julia> convert(Float64, (lambertw(-BigFloat(1)/e + BigFloat(10)^(-18), -1) + 1)) The loss of precision in `lambertw` is analogous to the loss of precision in computing the `sqrt(1-x)` for `x` close to `1`. """ -function lambertwbp(x::Number, k::Integer) - k == 0 && return _lambertw0(x) - k == -1 && return _lambertwm1(x) - throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1.")) -end - -lambertwbp(x::Number) = _lambertw0(x) +lambertwbp(x::Number, k::Integer=0) = _lambertwbp(x, Val(Int(k))) From 8f0e616b45b5d848df65811af4bf1542f21c215f Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:56:31 +0100 Subject: [PATCH 12/24] lambertwbp_series(): simplify - use Base.evalpoly() instead of horner macro reimplementation - use Val()-based dispatch instead of function generation --- src/lambertw.jl | 81 +++++++++++++++++-------------------------------- 1 file changed, 28 insertions(+), 53 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 02aceabe..ecb7f697 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -200,7 +200,7 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat) # solve for α₂. We get α₂ = 0. # Compute array of coefficients μ in (4.22). # m[1] is μ₀ -function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer) +function lambertw_coeffs(T::Type{<:Number}, n::Integer) a = Vector{T}(undef, n) m = Vector{T}(undef, n) @@ -227,68 +227,43 @@ function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer) return m end -const BRANCH_POINT_COEFFS_FLOAT64 = compute_branch_point_coeffs(Float64, 500) +const LAMBERTW_COEFFS_FLOAT64 = lambertw_coeffs(Float64, 500) -# Base.Math.@horner requires literal coefficients -# It cannot be used here because we have an array of computed coefficients -function horner(x, coeffs::AbstractArray, n) - n += 1 - ex = coeffs[n] - for i = (n - 1):-1:2 - ex = :(muladd(t, $ex, $(coeffs[i]))) - end - ex = :( t * $ex) - return Expr(:block, :(t = $x), ex) -end - -# write functions that evaluate the branch point series -# with `num_terms` number of terms. -for (func_name, num_terms) in ( - (:wser3, 3), (:wser5, 5), (:wser7, 7), (:wser12, 12), - (:wser19, 19), (:wser26, 26), (:wser32, 32), - (:wser50, 50), (:wser100, 100), (:wser290, 290)) - iex = horner(:x, BRANCH_POINT_COEFFS_FLOAT64, num_terms) - @eval function ($func_name)(x) $iex end -end +(lambertwbp_evalpoly(x::T, ::Val{N})::T) where {T<:Number, N} = + # assume that Julia compiler is smart to decide for which N to unroll at compile time + # note that we skip μ₀=-1 + evalpoly(x, ntuple(i -> LAMBERTW_COEFFS_FLOAT64[i+1], N-1))*x -# Converges to Float64 precision +# how many coefficients of the series to use +# to converge to Float64 precision for given x # We could get finer tuning by separating k=0, -1 branches. -# Why is wser5 omitted ? -# p is the argument to the series which is computed -# from x before calling `branch_point_series`. -function branch_point_series(p::Real, x::Real) - x < 4e-11 && return wser3(p) - x < 1e-5 && return wser7(p) - x < 1e-3 && return wser12(p) - x < 1e-2 && return wser19(p) - x < 3e-2 && return wser26(p) - x < 5e-2 && return wser32(p) - x < 1e-1 && return wser50(p) - x < 1.9e-1 && return wser100(p) - x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence - return wser290(p) # good for x approx .32 +function lambertwbp_series_length(x::Real) + x < 4e-11 && return 3 + # Why N = 5 is omitted? + x < 1e-5 && return 7 + x < 1e-3 && return 12 + x < 1e-2 && return 19 + x < 3e-2 && return 26 + x < 5e-2 && return 32 + x < 1e-1 && return 50 + x < 1.9e-1 && return 100 + x > inv(MathConstants.e) && throw(DomainError(x)) # radius of convergence + return 290 # good for x approx .32 end # These may need tuning. -function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real - x = abs(z) - x < 4e-11 && return wser3(p) - x < 1e-5 && return wser7(p) - x < 1e-3 && return wser12(p) - x < 1e-2 && return wser19(p) - x < 3e-2 && return wser26(p) - x < 5e-2 && return wser32(p) - x < 1e-1 && return wser50(p) - x < 1.9e-1 && return wser100(p) - x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence - return wser290(p) -end +lambertwbp_series_length(z::Complex) = lambertwbp_series_length(abs(z)) + +# p is the argument to the series which is computed from x, +# see `_lambertwbp()`. +lambertwbp_series(p::Number, x::Number) = + lambertwbp_evalpoly(p, Val{lambertwbp_series_length(x)}()) _lambertwbp(x::Number, ::Val{0}) = - branch_point_series(sqrt(2 * MathConstants.e * x), x) + lambertwbp_series(sqrt(2 * MathConstants.e * x), x) _lambertwbp(x::Number, ::Val{-1}) = - branch_point_series(-sqrt(2 * MathConstants.e * x), x) + lambertwbp_series(-sqrt(2 * MathConstants.e * x), x) _lambertwbp(_::Number, k::Val) = throw(ArgumentError("lambertw() expansion about branch point for k=$k not implemented (only implemented for 0 and -1).")) From 6fc569f60f0c0c23e6c6523df698e792436326c6 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:57:59 +0100 Subject: [PATCH 13/24] lambertw(): make maxiter a keyword --- src/lambertw.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index ecb7f697..cef4fd12 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -1,9 +1,9 @@ #### Lambert W function #### """ - lambertw(z::Number, k::Integer=0, maxits=1000) + lambertw(z::Number, k::Integer=0; [maxiter=1000]) -Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be +Compute the `k`th branch of the [Lambert W function](https://en.wikipedia.org/wiki/Lambert_W_function) of `z`. If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e, 0]` and the domain of the branch `k = 0` is `[-1/e, Inf]`. For `Complex` `z`, and all `k`, the domain is the complex plane. @@ -25,7 +25,7 @@ julia> lambertw(Complex(-10.0, 3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` """ -lambertw(z::Number, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits) +lambertw(z::Number, k::Integer=0; maxiter::Integer=1000) = _lambertw(z, k, maxiter) # lambertw(e + 0im, k) is ok for all k # Maybe this should return a float. But, this should cause no type instability in any case From 739a17c35f9a468daac30d1fdfe49501d4014c10 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 10:58:36 +0100 Subject: [PATCH 14/24] lambertw: small fixes --- src/lambertw.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index cef4fd12..7b46195d 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -100,7 +100,9 @@ function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real k != 0 ? w += complex(0, k * 2 * pi) : nothing end elseif k == 0 && imag(z) <= pointseven && abs(z) <= pointseven - w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven, pointseven) : complex(pointseven, -pointseven) : z + w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? + complex(pointseven, pointseven) : + complex(pointseven, -pointseven) : z else if real(z) == convert(T, Inf) k == 0 && return z @@ -118,16 +120,15 @@ end # Use Halley's root-finding method to find # x = lambertw(z) with initial point x0. function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number - two_t = convert(T, 2) x = x0 lastx = x lastdiff = zero(real(T)) converged = false - for i in 1:maxits + for _ in 1:maxits ex = exp(x) xexz = x * ex - z x1 = x + 1 - x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 )) + x -= xexz / (ex * x1 - (x + 2) * xexz / (2 * x1)) xdiff = abs(lastx - x) if xdiff <= 3 * eps(lastdiff) || lastdiff == xdiff # second condition catches two-value cycle converged = true @@ -136,7 +137,7 @@ function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number lastx = x lastdiff = xdiff end - converged || @warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.") + converged || @warn("lambertw(", z, ") did not converge in ", maxits, " iterations.") return x end From e34e8d9f41c52327d544ba9f218b58e777ebae0d Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 11:03:11 +0100 Subject: [PATCH 15/24] lambertwbp(): cleanup tests --- test/lambertw.jl | 57 +++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index 9ba2e8e0..9801b806 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -8,6 +8,7 @@ ## math constant e @test_throws DomainError lambertw(MathConstants.e, 1) @test_throws DomainError lambertw(MathConstants.e, -1) +@test_throws DomainError lambertw(.3, 2) ## integer arguments return floating point types @test @inferred(lambertw(0)) isa AbstractFloat @@ -100,36 +101,32 @@ end @test lambertw(BigInt(1)) == big(SpecialFunctions.omega) ### expansion about branch point - -# not a domain error, but not implemented -@test_throws ArgumentError lambertwbp(1, 1) - -@test_throws DomainError lambertw(.3, 2) - -# Expansions about branch point converges almost to machine precision -# except near the radius of convergence. -# Complex args are not tested here. - -if Int != Int32 - -@testset "double-precision expansion near branch point using BigFloats" begin - setprecision(2048) do - z = BigFloat(10)^(-12) - for _ in 1:300 - innerarg = z - inv(big(MathConstants.e)) - - @test lambertwbp(Float64(z)) ≈ 1 + lambertw(innerarg) atol=5e-16 - @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(innerarg, -1) atol=5e-16 - z *= 1.1 - if z > 0.23 break end - +@testset "lambertwbp()" begin + # not a domain error, but not implemented + @test_throws ArgumentError lambertwbp(1, 1) + @test_throws ArgumentError lambertwbp(inv(MathConstants.e) + 1e-5, 2) + @test_throws DomainError lambertwbp(inv(MathConstants.e) + 1e-5, 0) + @test_throws DomainError lambertwbp(inv(MathConstants.e) + 1e-5, -1) + + # Expansions about branch point converges almost to machine precision + # except near the radius of convergence. + # Complex args are not tested here. + + @testset "double-precision expansion near branch point using BigFloats" begin + setprecision(2048) do + z = BigFloat(10)^(-12) + for _ in 1:300 + @test lambertwbp(Float64(z)) ≈ 1 + lambertw(z - inv(big(MathConstants.e))) atol=5e-16 + @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(z - inv(big(MathConstants.e)), -1) atol=5e-16 + + z *= 1.1 + if z > 0.23 break end + end end end -end - -# test the expansion about branch point for k=-1, -# by comparing to exact BigFloat calculation. -@test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^(-20), -1) atol=1e-16 -@test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.2755038208041206, -0.1277888928494641) atol=1e-14 -end # if Int != Int32 + # test the expansion about branch point for k=-1, + # by comparing to exact BigFloat calculation. + @test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^(-20), -1) atol=1e-16 + @test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.2755038208041206, -0.1277888928494641) atol=1e-14 +end From 4f8b18b8dc59a8900a8ee067291309971fae0d67 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 31 Dec 2021 11:04:33 +0100 Subject: [PATCH 16/24] lambertw(): use inve constant in tests --- test/lambertw.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index 9801b806..fcdd845b 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -1,4 +1,5 @@ ### domain errors +using IrrationalConstants @test_throws DomainError lambertw(-2.0, 0) @test_throws DomainError lambertw(-2.0, -1) @@ -21,8 +22,8 @@ @test @inferred(lambertw(MathConstants.e, 0)) == 1 ## value at branch point where real branches meet -@test lambertw(-inv(MathConstants.e), 0) == lambertw(-inv(MathConstants.e), -1) == -1 -@test typeof(lambertw(-inv(MathConstants.e), 0)) == typeof(lambertw(-inv(MathConstants.e), -1)) <: AbstractFloat +@test lambertw(-inve, 0) == lambertw(-inve, -1) == -1 +@test typeof(lambertw(-inve, 0)) == typeof(lambertw(-inve, -1)) <: AbstractFloat ## convert irrationals to float @@ -76,7 +77,7 @@ end # bug fix # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail -@test lambertw(-inv(MathConstants.e) + 0im, -1) ≈ -1 atol=1e-7 +@test lambertw(-inve + 0im, -1) ≈ -1 atol=1e-7 # lambertw for BigFloat is more precise than Float64. Note # that 70 digits in test is about 35 digits in W @@ -104,9 +105,9 @@ end @testset "lambertwbp()" begin # not a domain error, but not implemented @test_throws ArgumentError lambertwbp(1, 1) - @test_throws ArgumentError lambertwbp(inv(MathConstants.e) + 1e-5, 2) - @test_throws DomainError lambertwbp(inv(MathConstants.e) + 1e-5, 0) - @test_throws DomainError lambertwbp(inv(MathConstants.e) + 1e-5, -1) + @test_throws ArgumentError lambertwbp(inve + 1e-5, 2) + @test_throws DomainError lambertwbp(inve + 1e-5, 0) + @test_throws DomainError lambertwbp(inve + 1e-5, -1) # Expansions about branch point converges almost to machine precision # except near the radius of convergence. @@ -116,8 +117,8 @@ end setprecision(2048) do z = BigFloat(10)^(-12) for _ in 1:300 - @test lambertwbp(Float64(z)) ≈ 1 + lambertw(z - inv(big(MathConstants.e))) atol=5e-16 - @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(z - inv(big(MathConstants.e)), -1) atol=5e-16 + @test lambertwbp(Float64(z)) ≈ 1 + lambertw(z - big(inve)) atol=5e-16 + @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(z - big(inve), -1) atol=5e-15 z *= 1.1 if z > 0.23 break end @@ -127,6 +128,6 @@ end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. - @test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-inv(big(MathConstants.e)) + BigFloat(10)^(-20), -1) atol=1e-16 + @test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-big(inve) + BigFloat(10)^(-20), -1) atol=1e-16 @test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.2755038208041206, -0.1277888928494641) atol=1e-14 end From c90b502f328b18860007dcdcdd8b2cd933c84d74 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 15:11:55 +0100 Subject: [PATCH 17/24] rename omega => LambertW.Omega and update its def using the code from JuliaMath/IrrationalConstants.jl/issues/12 No __init__() section is required --- docs/src/functions_list.md | 2 +- src/SpecialFunctions.jl | 10 ++----- src/lambertw.jl | 60 ++++++++++++++++---------------------- test/lambertw.jl | 35 +++++++++++++++------- 4 files changed, 52 insertions(+), 55 deletions(-) diff --git a/docs/src/functions_list.md b/docs/src/functions_list.md index d8583f16..8757b66c 100644 --- a/docs/src/functions_list.md +++ b/docs/src/functions_list.md @@ -71,5 +71,5 @@ SpecialFunctions.logabsbeta SpecialFunctions.logabsbinomial SpecialFunctions.lambertw SpecialFunctions.lambertwbp -SpecialFunctions.omega +SpecialFunctions.LambertW.Ω ``` diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 1090abb2..fe9aed52 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -79,14 +79,8 @@ export cosint, lbinomial, lambertw, - lambertwbp - -const omega_const_bf_ = Ref{BigFloat}() -function __init__() - # allocate storage for this BigFloat constant each time this module is loaded - omega_const_bf_[] = - parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") -end + lambertwbp, + LambertW include("bessel.jl") include("erf.jl") diff --git a/src/lambertw.jl b/src/lambertw.jl index 7b46195d..43a99489 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -37,7 +37,7 @@ _lambertw(x::Irrational, k::Integer, maxits::Integer) = _lambertw(float(x), k, m function _lambertw(x::Union{Integer, Rational}, k::Integer, maxits::Integer) if k == 0 x == 0 && return float(zero(x)) - x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way + x == 1 && return convert(typeof(float(x)), LambertW.Omega) # must be a more efficient way end return _lambertw(float(x), k, maxits) end @@ -141,54 +141,44 @@ function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number return x end -### omega constant +### Lambert's Omega constant -const _omega_const = 0.567143290409783872999968662210355 - -# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault - -# compute omega constant via root finding -# We could compute higher precision. This converges very quickly. -function omega_const(::Type{BigFloat}) - precision(BigFloat) <= 256 && return omega_const_bf_[] +# compute BigFloat Omega constant at arbitrary precision +function compute_lambertw_Omega() + oc = BigFloat("0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") + precision(oc) <= 256 && return oc + # iteratively improve the precision + # see https://en.wikipedia.org/wiki/Omega_constant#Computation myeps = eps(BigFloat) - oc = omega_const_bf_[] - for i in 1:100 + for _ in 1:1000 nextoc = (1 + oc) / (1 + exp(oc)) - abs(oc - nextoc) <= myeps && break + abs(oc - nextoc) <= myeps && return oc oc = nextoc end + @warn "Omega precision is less than current BigFloat precision ($(precision(BigFloat)))" return oc end -""" - omega - ω +# "private" declaration of Omega constant +Base.@irrational lambertw_Omega 0.567143290409783872999968662210355 compute_lambertw_Omega() -The constant defined by `ω exp(ω) = 1`. +module LambertW -# Example -```jldoctest -julia> ω -ω = 0.5671432904097... - -julia> omega -ω = 0.5671432904097... +""" +Lambert's Omega (*Ω*) constant. -julia> ω * exp(ω) -1.0 +Lambert's *Ω* is the solution to *W(Ω) = 1* equation, +where *W(t) = t exp(t)* is the +[Lambert's *W* function](https://en.wikipedia.org/wiki/Lambert_W_function). -julia> big(omega) -5.67143290409783872999968662210355549753815787186512508135131079223045793086683e-01 -``` +# See also + * https://en.wikipedia.org/wiki/Omega_constant + * [`lambertw()`][@ref SpecialFunctions.lambertw] """ -const ω = Irrational{:ω}() -@doc (@doc ω) omega = ω +const Ω = Irrational{:lambertw_Omega}() +const Omega = Ω # ASCII alias -Base.Float64(::Irrational{:ω}) = _omega_const -Base.Float32(::Irrational{:ω}) = Float32(_omega_const) -Base.Float16(::Irrational{:ω}) = Float16(_omega_const) -Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat) +end ### Expansion about branch point x = -1/e diff --git a/test/lambertw.jl b/test/lambertw.jl index fcdd845b..1b6f33d8 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -19,6 +19,8 @@ using IrrationalConstants # could return math const e, but this would break type stability @test @inferred(lambertw(1)) isa AbstractFloat +@test @inferred(lambertw(1)) == float(LambertW.Omega) +@test @inferred(lambertw(big(1))) == big(LambertW.Omega) @test @inferred(lambertw(MathConstants.e, 0)) == 1 ## value at branch point where real branches meet @@ -87,19 +89,30 @@ end @test z ≈ W*exp(W) atol=BigFloat(10)^(-70) end -### ω constant +@testset "LambertW.Omega" begin + @test isapprox(LambertW.Ω * exp(LambertW.Ω), 1) + @test LambertW.Omega === LambertW.Ω -## get ω from recursion and compare to value from lambertw -let sp = precision(BigFloat) - setprecision(512) - @test lambertw(big(1)) == big(SpecialFunctions.omega) - setprecision(sp) -end + # lower than default precision + setprecision(BigFloat, 196) do + o = big(LambertW.Ω) + @test precision(o) == 196 + @test isapprox(o * exp(o), 1, atol=eps(BigFloat)) + + oalias = big(LambertW.Omega) + @test o == oalias + end -@test lambertw(1) == float(SpecialFunctions.omega) -@test convert(Float16, SpecialFunctions.omega) == convert(Float16, 0.5674) -@test convert(Float32, SpecialFunctions.omega) == 0.56714326f0 -@test lambertw(BigInt(1)) == big(SpecialFunctions.omega) + # higher than default precision + setprecision(BigFloat, 2048) do + o = big(LambertW.Ω) + @test precision(o) == 2048 + @test isapprox(o * exp(o), 1, atol=eps(BigFloat)) + + oalias = big(LambertW.Omega) + @test o == oalias + end +end ### expansion about branch point @testset "lambertwbp()" begin From 87d08281ec8bc2c3e0347de43ac1be230fae36b8 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 1 Jan 2022 11:49:58 +0100 Subject: [PATCH 18/24] lambertw(): more inference tests --- test/lambertw.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index 1b6f33d8..6c7051a0 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -34,15 +34,15 @@ using IrrationalConstants ### infinite args or return values -@test lambertw(0, -1) == lambertw(0.0, -1) == -Inf -@test lambertw(Inf, 0) == Inf +@test @inferred(lambertw(0, -1)) == @inferred(lambertw(0.0, -1)) == -Inf +@test @inferred(lambertw(Inf, 0)) == Inf @test @inferred(lambertw(complex(Inf, 1), 0)) == complex(Inf, 1) -@test lambertw(complex(Inf, 0), 1) == complex(Inf, 2pi) -@test lambertw(complex(-Inf, 0), 1) == complex(Inf, 3pi) +@test @inferred(lambertw(complex(Inf, 0), 1)) == complex(Inf, 2pi) +@test @inferred(lambertw(complex(-Inf, 0), 1)) == complex(Inf, 3pi) @test @inferred(lambertw(complex(0.0, 0.0), -1)) == complex(-Inf, 0.0) -## default branch is k = 0 -@test lambertw(1.0) == lambertw(1.0, 0) +## default branch is k = 0 +@test @inferred(lambertw(1.0)) == @inferred(lambertw(1.0, 0)) ## BigInt args return BigFloats @test @inferred(lambertw(BigInt(0))) isa BigFloat From adf29b492e6663c8cc64a2840bae8b04f8de87fe Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 1 Jan 2022 22:45:26 +0100 Subject: [PATCH 19/24] lambertw: improve test precision --- test/lambertw.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lambertw.jl b/test/lambertw.jl index 6c7051a0..60df57e7 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -131,7 +131,7 @@ end z = BigFloat(10)^(-12) for _ in 1:300 @test lambertwbp(Float64(z)) ≈ 1 + lambertw(z - big(inve)) atol=5e-16 - @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(z - big(inve), -1) atol=5e-15 + @test lambertwbp(Float64(z), -1) ≈ 1 + lambertw(z - big(inve), -1) atol=1e-15 z *= 1.1 if z > 0.23 break end @@ -142,5 +142,5 @@ end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. @test @inferred(lambertwbp(1e-20, -1)) ≈ 1 + lambertw(-big(inve) + BigFloat(10)^(-20), -1) atol=1e-16 - @test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.2755038208041206, -0.1277888928494641) atol=1e-14 + @test @inferred(lambertwbp(Complex(.01, .01), -1)) ≈ Complex(-0.27550382080412062443536, -0.12778889284946406573511) atol=1e-16 end From d277be3f2a9c07cab878cb040def23abaa047356 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 13:15:51 +0100 Subject: [PATCH 20/24] lambertw(): use inve constant --- src/SpecialFunctions.jl | 3 ++- src/lambertw.jl | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index fe9aed52..157e2cf3 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -11,7 +11,8 @@ using IrrationalConstants: invsqrt2π, logtwo, logπ, - log2π + log2π, + inve import LogExpFunctions diff --git a/src/lambertw.jl b/src/lambertw.jl index 43a99489..90fb5233 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -52,7 +52,7 @@ _lambertw(x::Real, k::Integer, maxits::Integer) = _lambertw(x, Val(Int(k)), maxi function _lambertw(x::T, ::Val{0}, maxits::Integer) where T<:Real isfinite(x) || return x one_t = one(T) - oneoe = -inv(convert(T, MathConstants.e)) # The branch point + oneoe = -T(inve) # The branch point x == oneoe && return -one_t oneoe < x || throw(DomainError(x)) itwo_t = 1 / convert(T, 2) @@ -68,7 +68,7 @@ end # Real x, k = -1 function _lambertw(x::T, ::Val{-1}, maxits::Integer) where T<:Real - oneoe = -inv(convert(T, MathConstants.e)) + oneoe = -T(inve) x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below @@ -86,7 +86,7 @@ _lambertw(z::Complex{<:Integer}, k::Integer, maxits::Integer) = _lambertw(float( function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real local w::Complex{T} pointseven = 7//10 - if abs(z) <= inv(convert(T, MathConstants.e)) + if abs(z) <= T(inve) if z == 0 k == 0 && return z return complex(-convert(T, Inf), zero(T)) @@ -238,7 +238,7 @@ function lambertwbp_series_length(x::Real) x < 5e-2 && return 32 x < 1e-1 && return 50 x < 1.9e-1 && return 100 - x > inv(MathConstants.e) && throw(DomainError(x)) # radius of convergence + x > typeof(x)(inve) && throw(DomainError(x)) # radius of convergence return 290 # good for x approx .32 end From 5b38fec7f52fd85cd7b6218d588f3f6945ed04c2 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 14:01:43 +0100 Subject: [PATCH 21/24] use Compat to support evalpoly() on Julia 1.3 --- Project.toml | 2 ++ src/SpecialFunctions.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index cf84beb2..a0bde5eb 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "2.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112" @@ -18,6 +19,7 @@ SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" [compat] ChainRulesCore = "0.9.44, 0.10, 1" ChainRulesTestUtils = "0.6.8, 0.7, 1" +Compat = "3.7, 4" IrrationalConstants = "0.1, 0.2" LogExpFunctions = "0.3.2" OpenLibm_jll = "0.7, 0.8" diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 157e2cf3..73553883 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -1,5 +1,7 @@ module SpecialFunctions +using Compat + using IrrationalConstants: twoπ, halfπ, From afa1d0ec8180dd2b4c5846eb1b69659c94cd9c9d Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 13:40:18 +0100 Subject: [PATCH 22/24] DONTMERGE temporarily define inve here --- src/SpecialFunctions.jl | 6 ++++-- test/lambertw.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 73553883..2725a946 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -13,8 +13,10 @@ using IrrationalConstants: invsqrt2π, logtwo, logπ, - log2π, - inve + log2π + +# FIXME temporary until the fate of inve is decided +Base.@irrational inve 0.367879441171442321595 inv(big(ℯ)) import LogExpFunctions diff --git a/test/lambertw.jl b/test/lambertw.jl index 60df57e7..32ba7c1d 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -1,5 +1,6 @@ ### domain errors -using IrrationalConstants +using SpecialFunctions: inve # FIXME temporary until the fate of inve is decided +#using IrrationalConstants @test_throws DomainError lambertw(-2.0, 0) @test_throws DomainError lambertw(-2.0, -1) From b0c63e56707237b7ffda54a2bd3ebff0b785df67 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 14:32:50 +0100 Subject: [PATCH 23/24] lambertw: test maxiter= and convergence warning --- src/lambertw.jl | 2 +- test/lambertw.jl | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 90fb5233..8ef0c46d 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -137,7 +137,7 @@ function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number lastx = x lastdiff = xdiff end - converged || @warn("lambertw(", z, ") did not converge in ", maxits, " iterations.") + converged || @warn "lambertw($z) did not converge in $maxits iterations." return x end diff --git a/test/lambertw.jl b/test/lambertw.jl index 32ba7c1d..354da1ed 100644 --- a/test/lambertw.jl +++ b/test/lambertw.jl @@ -77,6 +77,12 @@ end @test lambertw(complex(-3.0, -4.0), 1) ≈ Complex(0.5887666813694675, 2.7118802109452247) atol=1e-14 @test lambertw(complex(.3, .3)) ≈ Complex(0.26763519642648767, 0.1837481231767825) +# test maxiter keyword and convergence warning +@test_logs (:warn, "lambertw(-0.2) did not converge in 3 iterations.") @inferred(lambertw(-0.2, -1, maxiter=3)) +@test lambertw(-0.2, -1, maxiter=5) == lambertw(-0.2, -1) +@test_logs (:warn, "lambertw(0.3 + 0.3im) did not converge in 3 iterations.") @inferred(lambertw(complex(.3, .3), maxiter=3)) +@test lambertw(complex(.3, .3), maxiter=5) == lambertw(complex(.3, .3)) + # bug fix # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail From ed3c825a487121807e391e562d334ef2e834bff4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 3 Jan 2022 14:43:50 +0100 Subject: [PATCH 24/24] lambertw: fix jldoctests --- src/lambertw.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/lambertw.jl b/src/lambertw.jl index 8ef0c46d..a739f0cb 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -8,11 +8,11 @@ either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e, domain of the branch `k = 0` is `[-1/e, Inf]`. For `Complex` `z`, and all `k`, the domain is the complex plane. -```jldoctest -julia> lambertw(-1/e, -1) +```jldoctest; setup=:(using SpecialFunctions) +julia> lambertw(-1/ℯ, -1) -1.0 -julia> lambertw(-1/e, 0) +julia> lambertw(-1/ℯ, 0) -1.0 julia> lambertw(0, 0) @@ -268,20 +268,19 @@ The result is accurate to Float64 precision for abs(z) < 0.32. If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. # Example -```jldoctest -julia> lambertw(-1/e + 1e-18, -1) +```jldoctest; setup=:(using SpecialFunctions) +julia> lambertw(-1/ℯ + 1e-18, -1) -1.0 julia> lambertwbp(1e-18, -1) -2.331643983409312e-9 -# Same result, but 1000 times slower -julia> convert(Float64, (lambertw(-BigFloat(1)/e + BigFloat(10)^(-18), -1) + 1)) +julia> convert(Float64, (lambertw(-big(1)/ℯ + big(10)^(-18), -1) + 1)) # Same result, but 1000 times slower -2.331643983409312e-9 ``` !!! note - `lambertwbp` uses a series expansion about the branch point `z=-1/e` to avoid loss of precision. + `lambertwbp` uses a series expansion about the branch point `z=-1/ℯ` to avoid loss of precision. The loss of precision in `lambertw` is analogous to the loss of precision in computing the `sqrt(1-x)` for `x` close to `1`. """