Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
name = "SpectralKit"
uuid = "5c252ae7-b5b6-46ab-a016-b0e3d78320b7"
authors = ["Tamas K. Papp <[email protected]>"]
version = "0.16.1"
authors = ["Tamas K. Papp <[email protected]>"]

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ArgCheck = "1, 2"
Compat = "4.18.0"
DocStringExtensions = "0.8, 0.9"
InverseFunctions = "0.1"
OrderedCollections = "1"
StaticArrays = "1"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ makedocs(
format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
authors = "Tamas K. Papp",
sitename = "SpectralKit.jl",
pages = Any["index.md"],
pages = Any["index.md", "experimental.md"],
clean = true,
checkdocs = :exports,
)
Expand Down
11 changes: 11 additions & 0 deletions docs/src/experimental.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Experimental

```@docs
SpectralKit.Experimental
```

## Proposed additions to general API

```@docs
SpectralKit.Experimental.constant_coefficients
```
1 change: 1 addition & 0 deletions src/SpectralKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ include("generic_api.jl")
include("chebyshev.jl")
include("smolyak_traversal.jl")
include("smolyak_api.jl")
include("experimental.jl") # experimental code is not part of the API, see its module docstring

end # module
233 changes: 233 additions & 0 deletions src/experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""
Experimental functionality. Officially not (yet) part of the API.

A best effort is made to

1. only change in breaking ways (including integration into the public API) with major or
minor releases.
2. fix issues (users are encouraged to open them)
"""
module Experimental

# TODO
# - [ ] clean up API, too many functions with positional arguments
# - [ ] write up tutorial-like docs
# - [ ] docs: an ascii diagram of various calculations, with https://asciiflow.com/
# - [ ] solver functions
# - [ ] actually implement a solution and AD it

# migrate to generic API
export constant_coefficients

using Compat: @compat
@compat public model_parameters_dimension, make_model_parameters,
calculate_derived_quantities, make_approximation_basis, describe_policy_transformations,
policy_coefficients_dimension, make_policy_functions, constant_initial_guess,
calculate_initial_guess, sum_of_squared_residuals

import ..SpectralKit

using ArgCheck: @argcheck
using DocStringExtensions: FUNCTIONNAME, SIGNATURES
using InverseFunctions: inverse

####
#### utilities
####

"""
$(SIGNATURES)

FIXME docs
"""
function named_cumulative_ranges(lengths::NamedTuple{N}) where N
# NOTE this approach relies on Base.afoldl, so it unrolled only up to 31 elements
ranges = accumulate(lengths, init = 0:0) do r, l
a = last(r)
(a + 1):(a + l)
end
NamedTuple{N}(ranges)
end

####
#### generic API additions
####

"""
$(FUNCTIONNAME)(basis, y)

Approximate a constant value `y` on `basis`.

Formally, return a set of coefficients `θ` such that `linear_combination(basis, θ, x) ≈
y` for all `x` in the domain.
"""
function constant_coefficients end

function constant_coefficients(basis::SpectralKit.Chebyshev, y)
θ = zeros(basis.N)
θ[1] = y
θ
end

function constant_coefficients(basis::SpectralKit.TransformedBasis, y)
constant_coefficients(parent(basis), y)
end

function constant_coefficients(basis::SpectralKit.SmolyakBasis, y)
θ = zeros(SpectralKit.dimension(basis))
θ[1] = y
θ
end

####
#### modelling API
####

const USERNOTE = "**User should implement this method for the relevant `model_family`.**"

"""
$(FUNCTIONNAME)(model_family)

Dimension of the model parameters as a flat vector.

$(USERNOTE)
"""
function model_parameters_dimension end

"""
$(FUNCTIONNAME)(model_family, x::AbstractVector)

Convert a flat vector of dimension [`model_parameters_dimension`)(@ref) to the model
parameters.

$(USERNOTE)

The returned value can be an arbitrary (eg a `NamedTuple`), as it does not need to be
not used for dispatch.

!!! NOTE
The method should accept all finite numbers in ``ℝ``, and transform them accordingly.
"""
function make_model_parameters end

"""
$(SIGNATURES)(model_family, model_parameters)

Calculate derived quantities (for use in determining the bases and calculating the residuals).

$(USERNOTE)
"""
function calculate_derived_quantities end

"""
$(FUNCTIONNAME)(model_family, derived_quantities, approximation_scheme)

Construct an approximation basis.

$(USERNOTE)
"""
function make_approximation_basis end

"""
$(FUNCTIONNAME)(model_family)

Return a `NamedTuple` is policy function approximation schemes. The values describe the
transformations.

$(USERNOTE)
"""
function describe_policy_transformations end

"""
$(FUNCTIONNAME)(model_family, model_parameters, policy_functions, gridpoint)

Calculate the residuals at the given gridpoint.

$(USERNOTE)
"""
function calculate_residuals end

function policy_coefficients_dimension(policy_transformations::NamedTuple, approximation_basis)
SpectralKit.dimension(approximation_basis) * length(policy_transformations)
end

function make_policy_functions(model_family, policy_transformations::NamedTuple,
approximation_basis, coefficients)
d = SpectralKit.dimension(approximation_basis)
# QUESTION line below assumes all univariate, generalize?
ranges = named_cumulative_ranges(map(_ -> d, policy_transformations))
@argcheck firstindex(coefficients) == 1
@argcheck lastindex(coefficients) == last(last(ranges))
map(ranges, policy_transformations) do r, t
t ∘ SpectralKit.linear_combination(approximation_basis, @view coefficients[r])
end
end

"""
$(FUNCTIONNAME)(model_family, model_parameters, derived_quantities)

Return initial guesses in a type that supports `getproperty` (eg a `NamedTuple`).

Should provide a scalar for each name in [`describe_policy_transformations`](@ref) (can
be in any arbitrary order, and contain other names).
"""
function constant_initial_guess end

"""
$(SIGNATURES)

Return an initial guess for the coefficients. Falls back to using
[`constant_initial_guess`](@ref), or the user may define a method in case that is not
sufficient.
"""
function calculate_initial_guess(model_family, model_parameters, derived_quantities,
policy_transformations::NamedTuple{N},
approximation_basis) where N
constant_guess = constant_initial_guess(model_family, model_parameters, derived_quantities)
d = SpectralKit.dimension(approximation_basis)
# QUESTION line below assumes all univariate, generalize?
ranges = named_cumulative_ranges(map(_ -> d, policy_transformations))
coefficients = zeros(last(last(ranges)))
for (name, transformation) in pairs(policy_transformations)
r = getproperty(ranges, name)
transformed_y = getproperty(constant_guess, name)
y = inverse(transformation)(transformed_y)
# FIXME a constant_coefficient! API would have fewer allocations
coefficients[r] .= constant_coefficients(approximation_basis, y)
end
coefficients
end

"""
$(SIGNATURES)

Return a grid for calculating the residuals. The default implementation just uses the
grid that corresponds to the approximation basis.
"""
function make_approximation_grid(model_family, model_parameters, derived_quantities,
approximation_basis)
SpectralKit.grid(approximation_basis)
end

"""
$(SIGNATURES)

Sum of squares. Works for values returned by [`calculate_residuals`](@ref).
"""
sum_abs2(x::Real) = abs2(x)

sum_abs2(nt::NamedTuple) = sum(abs2, values(nt))

"""
$(SIGNATURES)

Calculate the sum of squared residuals.
"""
function sum_of_squared_residuals(model_family, model_parameters, policy_functions, grid)
mapreduce(+, grid) do gridpoint
sum_abs2(calculate_residuals(model_family, model_parameters, policy_functions,
gridpoint))
end
end

end
4 changes: 2 additions & 2 deletions src/generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ function basis_at(basis::TransformedBasis, x)
basis_at(parent, transform_to(domain(parent), transformation, x))
end

function grid(basis::TransformedBasis)
function grid(::Type{T}, basis::TransformedBasis) where T
(; parent, transformation) = basis
d = domain(parent)
Iterators.map(x -> transform_from(d, transformation, x), grid(parent))
Iterators.map(x -> transform_from(d, transformation, x), grid(T, parent))
end

function Base.:(∘)(linear_combination::LinearCombination, transformation)
Expand Down
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
SpectralKit = "5c252ae7-b5b6-46ab-a016-b0e3d78320b7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
SpectralKit = {path = ".."}
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("test_chebyshev.jl")
include("test_smolyak_traversal.jl")
include("test_smolyak.jl")
include("test_generic_api.jl") # NOTE moved last as it used constructs from above
include("test_experimental.jl") # NOTE experimental code is not public API

using JET
@testset "static analysis with JET.jl" begin
Expand Down
Loading
Loading