Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
db962f1
multiple_dispatch for build_weights
Davide-Miotti Aug 31, 2025
b84464b
tests missing, integration complete
Davide-Miotti Aug 31, 2025
6ea5269
more unified approach
kylebeggs Sep 5, 2025
f72e1d5
removed duplicated code in solve_hermite.jl
Davide-Miotti Sep 14, 2025
7a0d317
renamed StandardStencil -> InternalStencil
Davide-Miotti Sep 14, 2025
3bdbea1
boundary_types quick fix with test
Davide-Miotti Sep 21, 2025
7343fd0
fix formatting
Davide-Miotti Sep 21, 2025
ed9c1d4
unit tests added
Davide-Miotti Sep 21, 2025
9c2a1b5
integration with RBF operators (temp)
Davide-Miotti Sep 21, 2025
93808f5
update julia compat, bump CI accordinyl, remove claude CI
kylebeggs Oct 15, 2025
2f65324
fix format CI and PrecompileTools compat
kylebeggs Oct 15, 2025
6cbb51c
PrecompileTools compat
kylebeggs Oct 15, 2025
3666230
bump LinearSOlve.jl compat
kylebeggs Oct 15, 2025
4168755
try newer LinearSolve.jl
kylebeggs Oct 15, 2025
9fa306a
remove LinearSOlve as we aren't actually using it!!
kylebeggs Oct 15, 2025
58f1730
format
kylebeggs Oct 15, 2025
14b95ff
update docs and documenter
kylebeggs Oct 18, 2025
aa17ff1
fix headline
kylebeggs Oct 18, 2025
76859ee
introduced Internal BC for stencil data
Davide-Miotti Oct 19, 2025
141c3ca
complete and refactor unit tests
Davide-Miotti Oct 19, 2025
f43e9a0
refactor hermite_integration
Davide-Miotti Oct 19, 2025
9ee249f
clean hermite_integration test
Davide-Miotti Oct 19, 2025
f9bb6ed
remove hermite_simple
Davide-Miotti Oct 19, 2025
0315123
add end to end tests
Davide-Miotti Oct 26, 2025
c012ac2
remove extra tests
Davide-Miotti Oct 26, 2025
3af4bb0
fix bug in solve_hermite.jl
Davide-Miotti Oct 27, 2025
ed70f0e
lower tolerances
Davide-Miotti Oct 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/RadialBasisFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ include("operators/operators.jl")
export RadialBasisOperator, ScalarValuedOperator, VectorValuedOperator
export update_weights!, is_cache_valid

include("solve_utils.jl")

# Boundary types needed by solve.jl
# Boundary types needed by solve.jl and solve_utils.jl
include("boundary_types.jl")

include("solve_utils.jl")

include("solve.jl")

# New clean Hermite implementation
Expand Down
104 changes: 88 additions & 16 deletions src/boundary_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ Boundary condition types and utilities for Hermite interpolation.
struct BoundaryCondition{T<:Real}
α::T
β::T

function BoundaryCondition(α::A, β::B) where {A,B}
T = promote_type(A, B)
new{T}(α, β)
return new{T}(α, β)
end
end

Expand All @@ -19,7 +19,7 @@ end

# Predicate functions
is_dirichlet(bc::BoundaryCondition) = isone(bc.α) && iszero(bc.β)
is_neumann(bc::BoundaryCondition) = iszero(bc.α) && isone(bc.β)
is_neumann(bc::BoundaryCondition) = iszero(bc.α) && isone(bc.β)
is_robin(bc::BoundaryCondition) = !iszero(bc.α) && !iszero(bc.β)

# Constructor helpers
Expand All @@ -28,30 +28,102 @@ Neumann(::Type{T}=Float64) where {T<:Real} = BoundaryCondition(zero(T), one(T))
Robin(α::Real, β::Real) = BoundaryCondition(α, β)

# Boundary information for a local stencil
struct HermiteBoundaryInfo{T<:Real}
"""
This struct is meant to be used to correctly broadcast the build_stencil() function
When Hermite scheme is used, it can be given to _build_stencil!() in place of the sole data.
"""
struct HermiteStencilData{T<:Real}
data::AbstractVector{Vector{T}} # Coordinates of stencil points
is_boundary::Vector{Bool}
boundary_conditions::Vector{BoundaryCondition{T}}
normals::Vector{Vector{T}}

function HermiteBoundaryInfo(

function HermiteStencilData(
data::AbstractVector{Vector{T}},
is_boundary::Vector{Bool},
boundary_conditions::Vector{BoundaryCondition{T}},
normals::Vector{Vector{T}}
normals::Vector{Vector{T}},
) where {T<:Real}
@assert length(is_boundary) == length(boundary_conditions) == length(normals)
new{T}(is_boundary, boundary_conditions, normals)
@assert length(data) ==
length(is_boundary) ==
length(boundary_conditions) ==
length(normals)
return new{T}(data, is_boundary, boundary_conditions, normals)
end
end

#pre-allocation constructor
function HermiteStencilData{T}(k::Int, dim::Int) where {T<:Real}
data = [Vector{T}(undef, dim) for _ in 1:k] # Pre-allocate with correct dimension
is_boundary = falses(k)
boundary_conditions = [Dirichlet(T) for _ in 1:k]
normals = [Vector{T}(undef, dim) for _ in 1:k] # Pre-allocate with correct dimension
return HermiteStencilData(data, is_boundary, boundary_conditions, normals)
end

"""
Populate local boundary information for a specific stencil within a kernel.
This function extracts boundary data for the neighbors of eval_idx and fills
the pre-allocated HermiteBoundaryInfo structure.

# Arguments
- `boundary_info`: Pre-allocated HermiteBoundaryInfo structure to fill (for this batch)
- `eval_idx`: Current evaluation point index
- `adjl`: Adjacency list for eval_idx (the neighbors)
- `is_boundary`: Global is_boundary vector for all points
- `boundary_conditions`: Global boundary_conditions vector for all points
- `normals`: Global normals vector for all points
"""
function update_stencil_data!(
hermite_data::HermiteStencilData{T}, # Pre-allocated structure passed in
global_data::AbstractVector{Vector{T}},
neighbors::Vector{Int}, # adjl[eval_idx]
is_boundary::Vector{Bool},
boundary_conditions::Vector{BoundaryCondition{T}},
normals::Vector{Vector{T}},
global_to_boundary::Vector{Int},
) where {T}
k = length(neighbors)

# Fill local boundary info for each neighbor (in-place, no allocation)
@inbounds for local_idx in 1:k
global_idx = neighbors[local_idx]
hermite_data.data[local_idx] .= global_data[global_idx]
hermite_data.is_boundary[local_idx] = is_boundary[global_idx]

if is_boundary[global_idx]
boundary_idx = global_to_boundary[global_idx]
hermite_data.boundary_conditions[local_idx] = boundary_conditions[boundary_idx]
hermite_data.normals[local_idx] .= normals[boundary_idx]
else
# Set default Dirichlet for interior points (not used but keeps type consistency)
hermite_data.boundary_conditions[local_idx] = Dirichlet(T)
fill!(hermite_data.normals[local_idx], zero(T))
end
end

return nothing
end

# Trait types for dispatch
abstract type StencilType end
struct StandardStencil <: StencilType end
struct DirichletStencil <: StencilType end
struct HermiteStencil <: StencilType end

# Trait function to determine stencil type
stencil_type(boundary_info::Nothing) = StandardStencil()
stencil_type(boundary_info::HermiteBoundaryInfo) = any(boundary_info.is_boundary) ? HermiteStencil() : StandardStencil()

# Convenience function to check if any boundary points in stencil
has_boundary_points(boundary_info::Nothing) = false
has_boundary_points(boundary_info::HermiteBoundaryInfo) = any(boundary_info.is_boundary)
function stencil_type(
is_boundary::Vector{Bool},
boundary_conditions::Vector{BoundaryCondition},
eval_idx::Int,
neighbors::Vector{Int},
global_to_boundary::Vector{Int},
)
if sum(is_boundary[neighbors]) == 0
return StandardStencil()
elseif is_boundary[eval_idx] &&
is_dirichlet(boundary_conditions[global_to_boundary[eval_idx]])
return DirichletStencil()
else
return HermiteStencil()
end
end
2 changes: 1 addition & 1 deletion src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function Interpolator(x, y, basis::B=PHS()) where {B<:AbstractRadialBasis}
mon = MonomialBasis(dim, basis.poly_deg)
data_type = promote_type(eltype(first(x)), eltype(y))
A = Symmetric(zeros(data_type, n, n))
_build_collocation_matrix!(A, x, basis, mon, k, StandardStencil())
_build_collocation_matrix!(A, x, basis, mon, k)
b = data_type[i < k ? y[i] : 0 for i in 1:n]
w = A \ b
return Interpolator(x, y, w[1:k], w[(k + 1):end], basis, mon)
Expand Down
95 changes: 11 additions & 84 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,92 +20,20 @@ end
function _build_weights(
data, eval_points, adjl, basis, ℒrbf, ℒmon, mon; batch_size=10, device=CPU()
)
TD = eltype(first(data))
dim = length(first(data)) # dimension of data
nmon = binomial(dim + basis.poly_deg, basis.poly_deg)
k = length(first(adjl)) # number of data in influence/support domain
sizes = (k, nmon)

# allocate arrays to build sparse matrix
Na = length(adjl)
I = zeros(Int, k * Na)
J = reduce(vcat, adjl)
V = zeros(TD, k * Na, _num_ops(ℒrbf))

# Calculate number of batches
n_batches = ceil(Int, Na / batch_size)

# Create kernel for building stencils in batches
@kernel function build_stencils_kernel(
I, J, V, data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, k, batch_size, Na
)
# Get the batch index for this thread
batch_idx = @index(Global)

# Calculate the range of points for this batch
start_idx = (batch_idx - 1) * batch_size + 1
end_idx = min(batch_idx * batch_size, Na)

# Pre-allocate work arrays for this thread
n = k + nmon
A = Symmetric(zeros(TD, n, n), :U)
b = _prepare_b(ℒrbf, TD, n)

# Process each point in the batch sequentially
for i in start_idx:end_idx
# Set row indices for sparse matrix
for idx in 1:k
I[(i - 1) * k + idx] = i
end

# Get data points in the influence domain
local_data = [data[j] for j in adjl[i]]

# Build stencil and store in global weight matrix
stencil = _build_stencil!(
A, b, ℒrbf, ℒmon, local_data, eval_points[i], basis, mon, k, StandardStencil()
)

# Store the stencil weights in the value array
for op in axes(V, 2)
for idx in 1:k
V[(i - 1) * k + idx, op] = stencil[idx, op]
end
end
end
end

# Launch kernel with one thread per batch
kernel = build_stencils_kernel(device)
kernel(
I,
J,
V,
# Use the unified kernel infrastructure with standard allocation strategy
return _build_weights_unified(
StandardAllocation(),
data,
eval_points,
adjl,
basis,
ℒrbf,
ℒmon,
mon,
k,
batch_size,
Na;
ndrange=n_batches,
workgroupsize=1,
nothing;
batch_size=batch_size,
device=device,
)

# Wait for kernel to complete
KernelAbstractions.synchronize(device)

# Create and return sparse matrix/matrices
nrows = length(adjl)
ncols = length(data)
if size(V, 2) == 1
return sparse(I, J, V[:, 1], nrows, ncols)
else
return ntuple(i -> sparse(I, J, V[:, i], nrows, ncols), size(V, 2))
end
end

function _build_stencil!(
Expand All @@ -118,15 +46,14 @@ function _build_stencil!(
basis::B,
mon::MonomialBasis,
k::Int,
::StandardStencil
) where {TD,TE,B<:AbstractRadialBasis}
_build_collocation_matrix!(A, data, basis, mon, k, StandardStencil())
_build_rhs!(b, ℒrbf, ℒmon, data, eval_point, basis, k, StandardStencil())
_build_collocation_matrix!(A, data, basis, mon, k)
_build_rhs!(b, ℒrbf, ℒmon, data, eval_point, basis, k)
return (A \ b)[1:k, :]
end

function _build_collocation_matrix!(
A::Symmetric, data::AbstractVector, basis::B, mon::MonomialBasis{Dim,Deg}, k::K, ::StandardStencil
A::Symmetric, data::AbstractVector, basis::B, mon::MonomialBasis{Dim,Deg}, k::K
) where {B<:AbstractRadialBasis,K<:Int,Dim,Deg}
# radial basis section
AA = parent(A)
Expand All @@ -147,7 +74,7 @@ function _build_collocation_matrix!(
end

function _build_rhs!(
b, ℒrbf, ℒmon, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K, ::StandardStencil
b, ℒrbf, ℒmon, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K
) where {TD,TE,B<:AbstractRadialBasis,K<:Int}
# radial basis section
@inbounds for i in eachindex(data)
Expand All @@ -165,7 +92,7 @@ function _build_rhs!(
end

function _build_rhs!(
b, ℒrbf::Tuple, ℒmon::Tuple, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K, ::StandardStencil
b, ℒrbf::Tuple, ℒmon::Tuple, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K
) where {TD,TE,B<:AbstractRadialBasis,K<:Int}
@assert size(b, 2) == length(ℒrbf) == length(ℒmon) "b, ℒrbf, ℒmon must have the same length"
# radial basis section
Expand Down
Loading
Loading