From 5c16731da2efacacf7ef5fd7766fe0c32e43e5c9 Mon Sep 17 00:00:00 2001 From: John Lapeyre Date: Wed, 28 Mar 2018 20:20:09 +0200 Subject: [PATCH 1/2] implement Lambert W function This copies the implementation of the Lambert W function and related functions and data from jlapeyre/LambertW.jl to SpecialFunctions.jl --- README.md | 2 +- docs/src/index.md | 1 + docs/src/special.md | 3 + src/SpecialFunctions.jl | 3 + src/lambertw.jl | 386 ++++++++++++++++++++++++++++++++++++++++ test/lambertw_test.jl | 177 ++++++++++++++++++ test/runtests.jl | 4 + 7 files changed, 575 insertions(+), 1 deletion(-) create mode 100644 src/lambertw.jl create mode 100644 test/lambertw_test.jl diff --git a/README.md b/README.md index dea48cb9..709e35e0 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # SpecialFunctions.jl Special mathematical functions in Julia, including Bessel, Hankel, Airy, error, Dawson, sine and cosine integrals, -eta, zeta, digamma, inverse digamma, trigamma, and polygamma functions. +eta, zeta, digamma, inverse digamma, trigamma, polygamma, and Lambert W functions. Most of these functions were formerly part of Base. Note: On Julia 0.7, this package downloads and/or builds diff --git a/docs/src/index.md b/docs/src/index.md index 4dead14b..d2f29f87 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,6 +40,7 @@ libraries. | [`besselix(nu,z)`](@ref SpecialFunctions.besselix) | scaled modified Bessel function of the first kind of order `nu` at `z` | | [`besselk(nu,z)`](@ref SpecialFunctions.besselk) | modified [Bessel function](https://en.wikipedia.org/wiki/Bessel_function) of the second kind of order `nu` at `z` | | [`besselkx(nu,z)`](@ref SpecialFunctions.besselkx) | scaled modified Bessel function of the second kind of order `nu` at `z` | +| [`lambertw(z,k)`](@ref SpecialFunctions.lambertw) | `k`th branch of the Lambert W function at `z` | ## Installation diff --git a/docs/src/special.md b/docs/src/special.md index f8bb48b1..c8e06e8b 100644 --- a/docs/src/special.md +++ b/docs/src/special.md @@ -46,4 +46,7 @@ SpecialFunctions.besselk SpecialFunctions.besselkx SpecialFunctions.eta SpecialFunctions.zeta +SpecialFunctions.lambertw +SpecialFunctions.lambertwbp +SpecialFunctions.omega ``` diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 5994a194..ec251579 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -71,10 +71,13 @@ end export sinint, cosint +export lambertw, lambertwbp + include("bessel.jl") include("erf.jl") include("sincosint.jl") include("gamma.jl") +include("lambertw.jl") include("deprecated.jl") end # module diff --git a/src/lambertw.jl b/src/lambertw.jl new file mode 100644 index 00000000..b98f7aee --- /dev/null +++ b/src/lambertw.jl @@ -0,0 +1,386 @@ +import Base: convert +#export lambertw, lambertwbp +using Compat + +const euler = +if isdefined(Base, :MathConstants) + Base.MathConstants.e +else + e +end + +const omega_const_bf_ = Ref{BigFloat}() + +function __init__() + omega_const_bf_[] = + parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") +end + +#### Lambert W function #### + +const LAMBERTW_USE_NAN = false + +macro baddomain(v) + if LAMBERTW_USE_NAN + return :(return(NaN)) + else + return esc(:(throw(DomainError($v)))) + end +end + +# Use Halley's root-finding method to find x = lambertw(z) with +# initial point x. +function _lambertw(z::T, x::T) where T <: Number + two_t = convert(T,2) + lastx = x + lastdiff = zero(T) + for i in 1:100 + ex = exp(x) + xexz = x * ex - z + x1 = x + 1 + x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) + xdiff = abs(lastx - x) + xdiff <= 2*eps(abs(lastx)) && break + lastdiff == diff && break + lastx = x + lastdiff = xdiff + end + x +end + +### Real z ### + +# Real x, k = 0 + +# The fancy initial condition selection does not seem to help speed, but we leave it for now. +function lambertwk0(x::T)::T where T<:AbstractFloat + x == Inf && return Inf + one_t = one(T) + oneoe = -one_t/convert(T,euler) + x == oneoe && return -one_t + itwo_t = 1/convert(T,2) + oneoe <= x || @baddomain(x) + if x > one_t + lx = log(x) + llx = log(lx) + x1 = lx - llx - log(one_t - llx/lx) * itwo_t + else + x1 = (567//1000) * x + end + _lambertw(x,x1) +end + +# Real x, k = -1 +function _lambertwkm1(x::T) where T<:Real + oneoe = -one(T)/convert(T,euler) + x == oneoe && return -one(T) + oneoe <= x || @baddomain(x) + x == zero(T) && return -convert(T,Inf) + x < zero(T) || @baddomain(x) + _lambertw(x,log(-x)) +end + + +""" + lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer} + lambertw(z::T, k::V=0) where {T<:Real, V<:Integer} + +Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be +either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the +domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is +the complex plane. + +```jldoctest +julia> lambertw(-1/e,-1) +-1.0 + +julia> lambertw(-1/e,0) +-1.0 + +julia> lambertw(0,0) +0.0 + +julia> lambertw(0,-1) +-Inf + +julia> lambertw(Complex(-10.0,3.0), 4) +-0.9274337508660128 + 26.37693445371142im +``` + +!!! note + The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments + outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`. +""" +function lambertw(x::Real, k::Integer) + k == 0 && return lambertwk0(x) + k == -1 && return _lambertwkm1(x) + @baddomain(k) # more informative message like below ? +# error("lambertw: real x must have k == 0 or k == -1") +end + +function lambertw(x::Union{Integer,Rational}, k::Integer) + if k == 0 + x == 0 && return float(zero(x)) + x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way + end + lambertw(float(x),k) +end + +### Complex z ### + +# choose initial value inside correct branch for root finding +function lambertw(z::Complex{T}, k::Integer) where T<:Real + one_t = one(T) + local w::Complex{T} + pointseven = 7//10 + if abs(z) <= one_t/convert(T,euler) + if z == 0 + k == 0 && return z + return complex(-convert(T,Inf),zero(T)) + end + if k == 0 + w = z + elseif k == -1 && imag(z) == 0 && real(z) < 0 + w = complex(log(-real(z)),1//10^7) # need offset for z ≈ -1/e. + else + w = log(z) + k != 0 ? w += complex(0,k * 2 * pi) : nothing + end + elseif k == 0 && imag(z) <= pointseven && abs(z) <= pointseven + w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven,pointseven) : complex(pointseven,-pointseven) : z + else + if real(z) == convert(T,Inf) + k == 0 && return z + return z + complex(0,2*k*pi) + end + real(z) == -convert(T,Inf) && return -z + complex(0,(2*k+1)*pi) + w = log(z) + k != 0 ? w += complex(0, 2*k*pi) : nothing + end + _lambertw(z,w) +end + +lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k) + +# lambertw(e + 0im,k) is ok for all k +#function lambertw(::Irrational{:e}, k::T) where T<:Integer +function lambertw(::typeof(euler), k::T) where T<:Integer + k == 0 && return 1 + @baddomain(k) +end + +# Maybe this should return a float +lambertw(::typeof(euler)) = 1 +#lambertw(::Irrational{:e}) = 1 + +#lambertw{T<:Number}(x::T) = lambertw(x,0) +lambertw(x::Number) = lambertw(x,0) + +lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...) + +### omega constant ### + +const omega_const_ = 0.567143290409783872999968662210355 +# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault + +# maybe compute higher precision. converges very quickly +function omega_const(::Type{BigFloat}) + @compat precision(BigFloat) <= 256 && return omega_const_bf_[] + myeps = eps(BigFloat) + oc = omega_const_bf_[] + for i in 1:100 + nextoc = (1 + oc) / (1 + exp(oc)) + abs(oc - nextoc) <= myeps && break + oc = nextoc + end + return oc +end + +""" + omega + ω + +A constant defined by `ω exp(ω) = 1`. + +```jldoctest +julia> ω +ω = 0.5671432904097... + +julia> omega +ω = 0.5671432904097... + +julia> ω * exp(ω) +1.0 + +julia> big(omega) +5.67143290409783872999968662210355549753815787186512508135131079223045793086683e-01 +``` +""" +const ω = Irrational{:ω}() +@doc (@doc ω) omega = ω + +# The following three lines may be removed when support for v0.6 is dropped +Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) +Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) +Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) + +Base.Float64(::Irrational{:ω}) = omega_const_ # FIXME: This is very slow. Why ? +Base.Float32(::Irrational{:ω}) = Float32(omega_const_) +Base.Float16(::Irrational{:ω}) = Float16(omega_const_) +Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) + +### Expansion about branch point x = -1/e ### + +# Refer to the paper "On the Lambert W function". In (4.22) +# coefficients μ₀ through μ₃ are given explicitly. Recursion relations +# (4.23) and (4.24) for all μ are also given. This code implements the +# recursion relations. + +# (4.23) and (4.24) give zero based coefficients +cset(a,i,v) = a[i+1] = v +cget(a,i) = a[i+1] + +# (4.24) +function compa(k,m,a) + sum0 = zero(eltype(m)) + for j in 2:k-1 + sum0 += cget(m,j) * cget(m,k+1-j) + end + cset(a,k,sum0) + sum0 +end + +# (4.23) +function compm(k,m,a) + kt = convert(eltype(m),k) + mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - + cget(a,k)/2 - cget(m,k-1)/(kt+1) + cset(m,k,mk) + mk +end + +# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and +# solve for α₂. We get α₂ = 0. +# compute array of coefficients μ in (4.22). +# m[1] is μ₀ +function lamwcoeff(T::DataType, n::Int) + # a = @compat Array{T}(undef,n) + # m = @compat Array{T}(undef,n) + a = zeros(T,n) # We don't need initialization, but Compat is a huge PITA. + m = zeros(T,n) + cset(a,0,2) # α₀ literal in paper + cset(a,1,-1) # α₁ literal in paper + cset(a,2,0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper + cset(m,0,-1) # μ₀ literal in paper + cset(m,1,1) # μ₁ literal in paper + cset(m,2,-1//3) # μ₂ literal in paper, but only in (4.22) + for i in 3:n-1 # coeffs are zero indexed + compa(i,m,a) + compm(i,m,a) + end + return m +end + +const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) + +function horner(x, p::AbstractArray,n) + n += 1 + ex = p[n] + for i = n-1:-1:2 + ex = :($(p[i]) + t * $ex) + end + ex = :( t * $ex) + Expr(:block, :(t = $x), ex) +end + +function mkwser(name, n) + iex = horner(:x,LAMWMU_FLOAT64,n) + :(function ($name)(x) $iex end) +end + +eval(mkwser(:wser3, 3)) +eval(mkwser(:wser5, 5)) +eval(mkwser(:wser7, 7)) +eval(mkwser(:wser12, 12)) +eval(mkwser(:wser19, 19)) +eval(mkwser(:wser26, 26)) +eval(mkwser(:wser32, 32)) +eval(mkwser(:wser50, 50)) +eval(mkwser(:wser100, 100)) +eval(mkwser(:wser290, 290)) + +# Converges to Float64 precision +# We could get finer tuning by separating k=0,-1 branches. +function wser(p,x) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/euler && @baddomain(x) # radius of convergence + return wser290(p) # good for x approx .32 +end + +# These may need tuning. +function wser(p::Complex{T},z) where T<:Real + x = abs(z) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/euler && @baddomain(x) # radius of convergence + return wser290(p) +end + +@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 + ps = 2*euler*x; + p = sqrt(ps) + wser(p,x) +end + +@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 + ps = 2*euler*x; + p = -sqrt(ps) + wser(p,x) +end + +""" + lambertwbp(z,k=0) + +Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +Accurate to Float64 precision for abs(z) < 0.32. +If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized. + +```jldoctest +julia> lambertw(-1/e + 1e-18, -1) +-1.0 + +julia> lambertwbp(1e-18, -1) +-2.331643983409312e-9 + +# Same result, but 1000 times slower +julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) +-2.331643983409312e-9 +``` + +!!! note + `lambertwbp` uses a series expansion about the branch point `z=-1/e` to avoid loss of precision. + The loss of precision in `lambertw` is analogous to the loss of precision + in computing the `sqrt(1-x)` for `x` close to `1`. +""" +function lambertwbp(x::Number,k::Integer) + k == 0 && return _lambertw0(x) + k == -1 && return _lambertwm1(x) + error("expansion about branch point only implemented for k = 0 and -1") +end + +lambertwbp(x::Number) = _lambertw0(x) + +nothing diff --git a/test/lambertw_test.jl b/test/lambertw_test.jl new file mode 100644 index 00000000..7614d9d2 --- /dev/null +++ b/test/lambertw_test.jl @@ -0,0 +1,177 @@ +using Compat + +macro test_baddomain(expr) + if SpecialFunctions.LAMBERTW_USE_NAN + :(@test $(esc(expr)) === NaN) + else + :(@test_throws DomainError $(esc(expr))) + end +end + +const euler = +if isdefined(Base, :MathConstants) + Base.MathConstants.e +else + e +end + +### domain errors + +@test_baddomain lambertw(-2.0,0) +@test_baddomain lambertw(-2.0,-1) +@test_baddomain lambertw(-2.0,1) +@test_baddomain lambertw(NaN) + +## math constant e +@test_baddomain lambertw(euler,1) +@test_baddomain lambertw(euler,-1) + +## integer arguments return floating point types +@test typeof(lambertw(0)) <: AbstractFloat +@test lambertw(0) == 0 + +### math constant, euler e + +# could return math const e, but this would break type stability +@test typeof(lambertw(1)) <: AbstractFloat +@test lambertw(euler,0) == 1 + +## value at branch point where real branches meet +@test lambertw(-1/euler,0) == lambertw(-1/euler,-1) == -1 +@test typeof(lambertw(-1/euler,0)) == typeof(lambertw(-1/euler,-1)) <: AbstractFloat + +## convert irrationals to float + +@test isapprox(lambertw(pi), 1.0736581947961492) +@test isapprox(lambertw(pi,0), 1.0736581947961492) + +### infinite args or return values + +@test lambertw(0,-1) == lambertw(0.0,-1) == -Inf +@test lambertw(Inf,0) == Inf +@test lambertw(complex(Inf,1),0) == complex(Inf,1) +@test lambertw(complex(Inf,0),1) == complex(Inf,2pi) +@test lambertw(complex(-Inf,0),1) == complex(Inf,3pi) +@test lambertw(complex(0.0,0.0),-1) == complex(-Inf,0.0) + +## default branch is k = 0 +@test lambertw(1.0) == lambertw(1.0,0) + +## BigInt args return BigFloats +@test typeof(lambertw(BigInt(0))) == BigFloat +@test typeof(lambertw(BigInt(3))) == BigFloat + +## Any Integer type allowed for second argument +@test lambertw(-0.2,-1) == lambertw(-0.2,BigInt(-1)) + +## BigInt for second arg does not promote the type +@test typeof(lambertw(-0.2,-1)) == typeof(lambertw(-0.2,BigInt(-1))) + +for (z,k,res) in [ (0,0 ,0), (complex(0,0),0 ,0), + (complex(0.0,0),0 ,0), (complex(1.0,0),0, 0.567143290409783873) ] + if Int != Int32 + @test isapprox(lambertw(z,k), res) + @test isapprox(lambertw(z), res) + else + @test isapprox(lambertw(z,k), res; rtol = 1e-14) + @test isapprox(lambertw(z), res; rtol = 1e-14) + end +end + +for (z,k) in ((complex(1,1),2), (complex(1,1),0),(complex(.6,.6),0), + (complex(.6,-.6),0)) + let w + @test (w = lambertw(z,k) ; true) + @test abs(w*exp(w) - z) < 1e-15 + end +end + +@test abs(lambertw(complex(-3.0,-4.0),0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 +@test abs(lambertw(complex(-3.0,-4.0),1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 +@test (lambertw(complex(.3,.3),0); true) + +# bug fix +# The routine will start at -1/e + eps * im, rather than -1/e + 0im, +# otherwise root finding will fail +if Int != Int32 + @test abs(lambertw(-1.0/euler + 0im,-1)) == 1 +else + @test abs(lambertw(-1.0/euler + 0im,-1) + 1) < 1e-7 +end +# lambertw for BigFloat is more precise than Float64. Note +# that 70 digits in test is about 35 digits in W +let W + for z in [ BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] + @test (W = lambertw(z); true) + @test abs(z - W * exp(W)) < BigFloat(1)^(-70) + end +end + +### ω constant + +## get ω from recursion and compare to value from lambertw +let sp = precision(BigFloat) + @compat setprecision(512) + @test lambertw(big(1)) == big(SpecialFunctions.omega) + @compat setprecision(sp) +end + +@test lambertw(1) == float(SpecialFunctions.omega) +@test convert(Float16,SpecialFunctions.omega) == convert(Float16,0.5674) +@test convert(Float32,SpecialFunctions.omega) == 0.56714326f0 +@test lambertw(BigInt(1)) == big(SpecialFunctions.omega) + +### expansion about branch point + +# not a domain error, but not implemented +@test_throws ErrorException lambertwbp(1,1) + +@test_throws DomainError lambertw(.3,2) + +# Expansions about branch point converges almost to machine precision +# except near the radius of convergence. +# Complex args are not tested here. + +if Int != Int32 + +let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, diff + @compat setprecision(2048) + for i in 1:300 + # k = 0 + @test (wo = lambertwbp(Float64(z)); diff = abs(-1 + wo - lambertw(z-1/big(euler))); true) + if diff > 5e-16 + println(Float64(z), " ", Float64(diff)) + end + @test diff < 5e-16 + # k = -1 + @test (wo = lambertwbp(Float64(z),-1); diff = abs(-1 + wo - lambertw(z-1/big(euler),-1)); true) + if diff > 5e-16 + println(Float64(z), " ", Float64(diff)) + end + @test diff < 5e-16 + z *= 1.1 + if z > 0.23 break end + end + @compat setprecision(sp) +end + +# test the expansion about branch point for k=-1, +# by comparing to exact BigFloat calculation. +@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(euler)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 + +@test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 + +end + +## vectorization + +if VERSION >= v"0.5" + @test lambertw.([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] + @test lambertw.([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] + @test lambertw.([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] + @test lambertwbp.([.1,.2,.3],-1) == map(x -> lambertwbp(x,-1), [.1,.2,.3]) +else + @test lambertw([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] + @test lambertw([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] + @test lambertw([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] +end diff --git a/test/runtests.jl b/test/runtests.jl index c82a4cfa..4501b256 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,10 @@ relerr(z, x) = z == x ? 0.0 : abs(z - x) / abs(x) relerrc(z, x) = max(relerr(real(z),real(x)), relerr(imag(z),imag(x))) ≅(a,b) = relerrc(a,b) ≤ 1e-13 +@testset "Lambert W" begin + include("lambertw_test.jl") +end + @testset "error functions" begin @test SF.erf(Float16(1)) ≈ 0.84270079294971486934 @test SF.erf(1) ≈ 0.84270079294971486934 From 9ef0205adab63d4e9242e2af2eb4fc01cea3a2e7 Mon Sep 17 00:00:00 2001 From: John Lapeyre Date: Thu, 19 Apr 2018 01:37:39 +0200 Subject: [PATCH 2/2] made changes request in PR review --- LICENSE | 32 ++++++++ src/SpecialFunctions.jl | 7 ++ src/lambertw.jl | 170 +++++++++++++++++----------------------- test/lambertw_test.jl | 93 +++++++++------------- 4 files changed, 146 insertions(+), 156 deletions(-) diff --git a/LICENSE b/LICENSE index 28e07032..b9a9755c 100644 --- a/LICENSE +++ b/LICENSE @@ -22,3 +22,35 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +Portions of this code are derived from SciPy and are licensed under +the Scipy License: + +> Copyright (c) 2001, 2002 Enthought, Inc. +> All rights reserved. + +> Copyright (c) 2003-2012 SciPy Developers. +> All rights reserved. + +> Redistribution and use in source and binary forms, with or without +> modification, are permitted provided that the following conditions are met: + +> a. Redistributions of source code must retain the above copyright notice, +> this list of conditions and the following disclaimer. +> b. Redistributions in binary form must reproduce the above copyright +> notice, this list of conditions and the following disclaimer in the +> documentation and/or other materials provided with the distribution. +> c. Neither the name of Enthought nor the names of the SciPy Developers +> may be used to endorse or promote products derived from this software +> without specific prior written permission. +> +> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +> AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +> IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +> ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS +> BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +> OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +> SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +> INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +> CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +> ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +> THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index ec251579..7ede35a4 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -73,6 +73,13 @@ export sinint, export lambertw, lambertwbp +const omega_const_bf_ = Ref{BigFloat}() +function __init__() + # allocate storage for this BigFloat constant each time this module is loaded + omega_const_bf_[] = + parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") +end + include("bessel.jl") include("erf.jl") include("sincosint.jl") diff --git a/src/lambertw.jl b/src/lambertw.jl index b98f7aee..5fdd9e74 100644 --- a/src/lambertw.jl +++ b/src/lambertw.jl @@ -1,65 +1,47 @@ import Base: convert -#export lambertw, lambertwbp -using Compat - -const euler = -if isdefined(Base, :MathConstants) - Base.MathConstants.e -else - e -end -const omega_const_bf_ = Ref{BigFloat}() - -function __init__() - omega_const_bf_[] = - parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") -end +using Compat +import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number #### Lambert W function #### -const LAMBERTW_USE_NAN = false - -macro baddomain(v) - if LAMBERTW_USE_NAN - return :(return(NaN)) - else - return esc(:(throw(DomainError($v)))) - end -end - -# Use Halley's root-finding method to find x = lambertw(z) with -# initial point x. -function _lambertw(z::T, x::T) where T <: Number +# Use Halley's root-finding method to find +# x = lambertw(z) with initial point x. +function _lambertw(z::T, x::T, maxits) where T <: Number two_t = convert(T,2) lastx = x lastdiff = zero(T) - for i in 1:100 + converged::Bool = false + for i in 1:maxits ex = exp(x) xexz = x * ex - z x1 = x + 1 - x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) + x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) xdiff = abs(lastx - x) - xdiff <= 2*eps(abs(lastx)) && break - lastdiff == diff && break + if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle + converged = true + break + end lastx = x lastdiff = xdiff end - x + converged || warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.") + return x end ### Real z ### # Real x, k = 0 - +# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. # The fancy initial condition selection does not seem to help speed, but we leave it for now. -function lambertwk0(x::T)::T where T<:AbstractFloat - x == Inf && return Inf +function lambertwk0(x::T, maxits)::T where T<:AbstractFloat + isnan(x) && return(NaN) + x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf) one_t = one(T) - oneoe = -one_t/convert(T,euler) + oneoe = -one_t/convert(T,MathConstants.e) # The branch point x == oneoe && return -one_t + oneoe <= x || throw(DomainError(x)) itwo_t = 1/convert(T,2) - oneoe <= x || @baddomain(x) if x > one_t lx = log(x) llx = log(lx) @@ -67,28 +49,28 @@ function lambertwk0(x::T)::T where T<:AbstractFloat else x1 = (567//1000) * x end - _lambertw(x,x1) + return _lambertw(x, x1, maxits) end # Real x, k = -1 -function _lambertwkm1(x::T) where T<:Real - oneoe = -one(T)/convert(T,euler) - x == oneoe && return -one(T) - oneoe <= x || @baddomain(x) - x == zero(T) && return -convert(T,Inf) - x < zero(T) || @baddomain(x) - _lambertw(x,log(-x)) +function lambertwkm1(x::T, maxits) where T<:Real + oneoe = -one(T)/convert(T,MathConstants.e) + x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above + oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e + x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below + x < zero(T) || throw(DomainError(x)) + return _lambertw(x, log(-x), maxits) end - """ - lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer} - lambertw(z::T, k::V=0) where {T<:Real, V<:Integer} + lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer} + lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer} Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is -the complex plane. +the complex plane. When using root finding to compute `W`, a value for `W` is returned +with a warning if it has not converged after `maxits` iterations. ```jldoctest julia> lambertw(-1/e,-1) @@ -107,33 +89,31 @@ julia> lambertw(Complex(-10.0,3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` -!!! note - The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments - outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`. """ -function lambertw(x::Real, k::Integer) - k == 0 && return lambertwk0(x) - k == -1 && return _lambertwkm1(x) - @baddomain(k) # more informative message like below ? -# error("lambertw: real x must have k == 0 or k == -1") +lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits) + +function lambertw_(x::Real, k, maxits) + k == 0 && return lambertwk0(x, maxits) + k == -1 && return lambertwkm1(x, maxits) + throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) end -function lambertw(x::Union{Integer,Rational}, k::Integer) +function lambertw_(x::Union{Integer,Rational}, k, maxits) if k == 0 x == 0 && return float(zero(x)) - x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way + x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way end - lambertw(float(x),k) + return lambertw_(float(x), k, maxits) end ### Complex z ### # choose initial value inside correct branch for root finding -function lambertw(z::Complex{T}, k::Integer) where T<:Real +function lambertw_(z::Complex{T}, k, maxits) where T<:Real one_t = one(T) local w::Complex{T} pointseven = 7//10 - if abs(z) <= one_t/convert(T,euler) + if abs(z) <= one_t/convert(T,MathConstants.e) if z == 0 k == 0 && return z return complex(-convert(T,Inf),zero(T)) @@ -157,27 +137,19 @@ function lambertw(z::Complex{T}, k::Integer) where T<:Real w = log(z) k != 0 ? w += complex(0, 2*k*pi) : nothing end - _lambertw(z,w) + return _lambertw(z, w, maxits) end -lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k) +lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits) +lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits) # lambertw(e + 0im,k) is ok for all k -#function lambertw(::Irrational{:e}, k::T) where T<:Integer -function lambertw(::typeof(euler), k::T) where T<:Integer +# Maybe this should return a float. But, this should cause no type instability in any case +function lambertw_(::typeof(MathConstants.e), k, maxits) k == 0 && return 1 - @baddomain(k) + throw(DomainError(k)) end -# Maybe this should return a float -lambertw(::typeof(euler)) = 1 -#lambertw(::Irrational{:e}) = 1 - -#lambertw{T<:Number}(x::T) = lambertw(x,0) -lambertw(x::Number) = lambertw(x,0) - -lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...) - ### omega constant ### const omega_const_ = 0.567143290409783872999968662210355 @@ -185,7 +157,7 @@ const omega_const_ = 0.567143290409783872999968662210355 # maybe compute higher precision. converges very quickly function omega_const(::Type{BigFloat}) - @compat precision(BigFloat) <= 256 && return omega_const_bf_[] + precision(BigFloat) <= 256 && return omega_const_bf_[] myeps = eps(BigFloat) oc = omega_const_bf_[] for i in 1:100 @@ -200,7 +172,7 @@ end omega ω -A constant defined by `ω exp(ω) = 1`. +The constant defined by `ω exp(ω) = 1`. ```jldoctest julia> ω @@ -219,7 +191,7 @@ julia> big(omega) const ω = Irrational{:ω}() @doc (@doc ω) omega = ω -# The following three lines may be removed when support for v0.6 is dropped +# The following two lines may be removed when support for v0.6 is dropped Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) @@ -236,7 +208,7 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) # (4.23) and (4.24) for all μ are also given. This code implements the # recursion relations. -# (4.23) and (4.24) give zero based coefficients +# (4.23) and (4.24) give zero based coefficients. cset(a,i,v) = a[i+1] = v cget(a,i) = a[i+1] @@ -247,7 +219,7 @@ function compa(k,m,a) sum0 += cget(m,j) * cget(m,k+1-j) end cset(a,k,sum0) - sum0 + return sum0 end # (4.23) @@ -256,7 +228,7 @@ function compm(k,m,a) mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - cget(a,k)/2 - cget(m,k-1)/(kt+1) cset(m,k,mk) - mk + return mk end # We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and @@ -283,19 +255,21 @@ end const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) -function horner(x, p::AbstractArray,n) +# Base.Math.@horner requires literal coefficients +# But, we have an array `p` of computed coefficients +function horner(x, p::AbstractArray, n) n += 1 ex = p[n] for i = n-1:-1:2 - ex = :($(p[i]) + t * $ex) + ex = :(muladd(t, $ex, $(p[i]))) end ex = :( t * $ex) - Expr(:block, :(t = $x), ex) + return Expr(:block, :(t = $x), ex) end function mkwser(name, n) iex = horner(:x,LAMWMU_FLOAT64,n) - :(function ($name)(x) $iex end) + return :(function ($name)(x) $iex end) end eval(mkwser(:wser3, 3)) @@ -320,7 +294,7 @@ function wser(p,x) x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/euler && @baddomain(x) # radius of convergence + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) # good for x approx .32 end @@ -335,28 +309,28 @@ function wser(p::Complex{T},z) where T<:Real x < 5e-2 && return wser32(p) x < 1e-1 && return wser50(p) x < 1.9e-1 && return wser100(p) - x > 1/euler && @baddomain(x) # radius of convergence + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence return wser290(p) end @inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 - ps = 2*euler*x; + ps = 2*MathConstants.e*x; p = sqrt(ps) - wser(p,x) + return wser(p,x) end @inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 - ps = 2*euler*x; + ps = 2*MathConstants.e*x; p = -sqrt(ps) - wser(p,x) + return wser(p,x) end """ lambertwbp(z,k=0) -Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. -Accurate to Float64 precision for abs(z) < 0.32. -If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized. +Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +The result is accurate to Float64 precision for abs(z) < 0.32. +If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. ```jldoctest julia> lambertw(-1/e + 1e-18, -1) @@ -378,9 +352,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) function lambertwbp(x::Number,k::Integer) k == 0 && return _lambertw0(x) k == -1 && return _lambertwm1(x) - error("expansion about branch point only implemented for k = 0 and -1") + throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1.")) end lambertwbp(x::Number) = _lambertw0(x) - -nothing diff --git a/test/lambertw_test.jl b/test/lambertw_test.jl index 7614d9d2..46d5dd57 100644 --- a/test/lambertw_test.jl +++ b/test/lambertw_test.jl @@ -1,44 +1,31 @@ using Compat -macro test_baddomain(expr) - if SpecialFunctions.LAMBERTW_USE_NAN - :(@test $(esc(expr)) === NaN) - else - :(@test_throws DomainError $(esc(expr))) - end -end - -const euler = -if isdefined(Base, :MathConstants) - Base.MathConstants.e -else - e -end +import Compat.MathConstants ### domain errors -@test_baddomain lambertw(-2.0,0) -@test_baddomain lambertw(-2.0,-1) -@test_baddomain lambertw(-2.0,1) -@test_baddomain lambertw(NaN) +@test_throws DomainError lambertw(-2.0,0) +@test_throws DomainError lambertw(-2.0,-1) +@test_throws DomainError lambertw(-2.0,1) +@test isnan(lambertw(NaN)) ## math constant e -@test_baddomain lambertw(euler,1) -@test_baddomain lambertw(euler,-1) +@test_throws DomainError lambertw(MathConstants.e,1) +@test_throws DomainError lambertw(MathConstants.e,-1) ## integer arguments return floating point types @test typeof(lambertw(0)) <: AbstractFloat @test lambertw(0) == 0 -### math constant, euler e +### math constant, MathConstants.e e # could return math const e, but this would break type stability @test typeof(lambertw(1)) <: AbstractFloat -@test lambertw(euler,0) == 1 +@test lambertw(MathConstants.e,0) == 1 ## value at branch point where real branches meet -@test lambertw(-1/euler,0) == lambertw(-1/euler,-1) == -1 -@test typeof(lambertw(-1/euler,0)) == typeof(lambertw(-1/euler,-1)) <: AbstractFloat +@test lambertw(-1/MathConstants.e,0) == lambertw(-1/MathConstants.e,-1) == -1 +@test typeof(lambertw(-1/MathConstants.e,0)) == typeof(lambertw(-1/MathConstants.e,-1)) <: AbstractFloat ## convert irrationals to float @@ -94,9 +81,9 @@ end # The routine will start at -1/e + eps * im, rather than -1/e + 0im, # otherwise root finding will fail if Int != Int32 - @test abs(lambertw(-1.0/euler + 0im,-1)) == 1 + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1)) == 1 else - @test abs(lambertw(-1.0/euler + 0im,-1) + 1) < 1e-7 + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1) + 1) < 1e-7 end # lambertw for BigFloat is more precise than Float64. Note # that 70 digits in test is about 35 digits in W @@ -111,9 +98,9 @@ end ## get ω from recursion and compare to value from lambertw let sp = precision(BigFloat) - @compat setprecision(512) + setprecision(512) @test lambertw(big(1)) == big(SpecialFunctions.omega) - @compat setprecision(sp) + setprecision(sp) end @test lambertw(1) == float(SpecialFunctions.omega) @@ -124,7 +111,7 @@ end ### expansion about branch point # not a domain error, but not implemented -@test_throws ErrorException lambertwbp(1,1) +@test_throws ArgumentError lambertwbp(1,1) @test_throws DomainError lambertw(.3,2) @@ -134,44 +121,36 @@ end if Int != Int32 -let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, diff - @compat setprecision(2048) +# Test double-precision expansion near branch point using BigFloats +let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff + setprecision(2048) for i in 1:300 - # k = 0 - @test (wo = lambertwbp(Float64(z)); diff = abs(-1 + wo - lambertw(z-1/big(euler))); true) - if diff > 5e-16 - println(Float64(z), " ", Float64(diff)) + innerarg = z-1/big(MathConstants.e) + + # branch k = 0 + @test (wo = lambertwbp(Float64(z)); xdiff = abs(-1 + wo - lambertw(innerarg)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) end - @test diff < 5e-16 - # k = -1 - @test (wo = lambertwbp(Float64(z),-1); diff = abs(-1 + wo - lambertw(z-1/big(euler),-1)); true) - if diff > 5e-16 - println(Float64(z), " ", Float64(diff)) + @test xdiff < 5e-16 + + # branch k = -1 + @test (wo = lambertwbp(Float64(z),-1); xdiff = abs(-1 + wo - lambertw(innerarg,-1)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) end - @test diff < 5e-16 + @test xdiff < 5e-16 z *= 1.1 if z > 0.23 break end + end - @compat setprecision(sp) + setprecision(sp) end # test the expansion about branch point for k=-1, # by comparing to exact BigFloat calculation. -@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(euler)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 +@test lambertwbp(1e-20,-1) - 1 - lambertw(-BigFloat(1)/big(MathConstants.e)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1) < 1e-16 @test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 -end - -## vectorization - -if VERSION >= v"0.5" - @test lambertw.([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] - @test lambertw.([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] - @test lambertw.([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] - @test lambertwbp.([.1,.2,.3],-1) == map(x -> lambertwbp(x,-1), [.1,.2,.3]) -else - @test lambertw([0.1,0.2]) == [lambertw(0.1),lambertw(0.2)] - @test lambertw([0.1+im ,0.2-im]) == [lambertw(0.1+im),lambertw(0.2-im)] - @test lambertw([0.1,-0.2],[0,-1]) == [lambertw(0.1,0),lambertw(-0.2,-1)] -end +end # if Int != Int32