Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/RadialBasisFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,20 @@ include("operators/operators.jl")
export RadialBasisOperator, ScalarValuedOperator, VectorValuedOperator
export update_weights!, is_cache_valid

include("solve_utils.jl")

# Boundary types needed by solve.jl
# Boundary types needed by solve.jl and solve_utils.jl
include("boundary_types.jl")

include("solve_utils.jl")

include("solve.jl")

# New clean Hermite implementation
include("solve_hermite.jl")
export BoundaryCondition, Dirichlet, Neumann, Robin
export α, β, is_dirichlet, is_neumann, is_robin
export HermiteBoundaryInfo, StencilType, StandardStencil, HermiteStencil
export HermiteBoundaryInfo, StencilType, InternalStencil, HermiteStencil
export stencil_type, has_boundary_points


include("operators/custom.jl")
export Custom, custom

Expand Down
106 changes: 89 additions & 17 deletions src/boundary_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ Boundary condition types and utilities for Hermite interpolation.
struct BoundaryCondition{T<:Real}
α::T
β::T

function BoundaryCondition(α::A, β::B) where {A,B}
T = promote_type(A, B)
new{T}(α, β)
return new{T}(α, β)
end
end

Expand All @@ -19,7 +19,7 @@ end

# Predicate functions
is_dirichlet(bc::BoundaryCondition) = isone(bc.α) && iszero(bc.β)
is_neumann(bc::BoundaryCondition) = iszero(bc.α) && isone(bc.β)
is_neumann(bc::BoundaryCondition) = iszero(bc.α) && isone(bc.β)
is_robin(bc::BoundaryCondition) = !iszero(bc.α) && !iszero(bc.β)

# Constructor helpers
Expand All @@ -28,30 +28,102 @@ Neumann(::Type{T}=Float64) where {T<:Real} = BoundaryCondition(zero(T), one(T))
Robin(α::Real, β::Real) = BoundaryCondition(α, β)

# Boundary information for a local stencil
struct HermiteBoundaryInfo{T<:Real}
"""
This struct is meant to be used to correctly broadcast the `_build_stencil!()` function.
When the Hermite scheme is used, it can be passed to `_build_stencil!()` in place of the bare coordinate data.
"""
struct HermiteStencilData{T<:Real}
data::AbstractVector{Vector{T}} # Coordinates of stencil points
is_boundary::Vector{Bool}
boundary_conditions::Vector{BoundaryCondition{T}}
normals::Vector{Vector{T}}

function HermiteBoundaryInfo(

function HermiteStencilData(
data::AbstractVector{Vector{T}},
is_boundary::Vector{Bool},
boundary_conditions::Vector{BoundaryCondition{T}},
normals::Vector{Vector{T}}
normals::Vector{Vector{T}},
) where {T<:Real}
@assert length(is_boundary) == length(boundary_conditions) == length(normals)
new{T}(is_boundary, boundary_conditions, normals)
@assert length(data) ==
length(is_boundary) ==
length(boundary_conditions) ==
length(normals)
return new{T}(data, is_boundary, boundary_conditions, normals)
end
end

"""
    HermiteStencilData{T}(k::Int, dim::Int)

Pre-allocate a `HermiteStencilData{T}` for a stencil of `k` points in `dim` dimensions.

The coordinate and normal vectors are allocated uninitialised (`undef`), every point is
flagged as non-boundary and every boundary condition is set to `Dirichlet(T)`. The
contents are meant to be overwritten in place (see `update_stencil_data!`) before use.
"""
function HermiteStencilData{T}(k::Int, dim::Int) where {T<:Real}
    data = [Vector{T}(undef, dim) for _ in 1:k]
    # `falses(k)` returns a `BitVector`, which does not match the `Vector{Bool}`
    # argument required by the inner constructor; build a real `Vector{Bool}`.
    is_boundary = fill(false, k)
    # Explicit element type so the result does not depend on inference when `k == 0`.
    boundary_conditions = BoundaryCondition{T}[Dirichlet(T) for _ in 1:k]
    normals = [Vector{T}(undef, dim) for _ in 1:k]
    return HermiteStencilData(data, is_boundary, boundary_conditions, normals)
end

"""
    update_stencil_data!(hermite_data, global_data, neighbors, is_boundary,
                         boundary_conditions, normals, global_to_boundary)

Fill the pre-allocated `hermite_data` in place with the local data of one stencil.

For each entry of `neighbors` the coordinates and the boundary flag are copied from the
global arrays. For boundary points the boundary condition and the normal are looked up
through `global_to_boundary`; for all other points a `Dirichlet(T)` placeholder and a
zero normal are written.

# Arguments
- `hermite_data`: Pre-allocated `HermiteStencilData` to overwrite
- `global_data`: Coordinates of all points
- `neighbors`: Global indices of the points in the stencil
- `is_boundary`: Boundary flag for every global point
- `boundary_conditions`: Boundary conditions, indexed by boundary index
- `normals`: Normals, indexed by boundary index
- `global_to_boundary`: Maps a global point index to its boundary index

Returns `nothing`.
"""
function update_stencil_data!(
    hermite_data::HermiteStencilData{T}, # Pre-allocated structure passed in
    global_data::AbstractVector{Vector{T}},
    neighbors::Vector{Int}, # adjl[eval_idx]
    is_boundary::Vector{Bool},
    boundary_conditions::Vector{BoundaryCondition{T}},
    normals::Vector{Vector{T}},
    global_to_boundary::Vector{Int},
) where {T}
    # Everything is written into the existing buffers of `hermite_data`.
    @inbounds for (slot, point) in enumerate(neighbors)
        on_boundary = is_boundary[point]
        hermite_data.data[slot] .= global_data[point]
        hermite_data.is_boundary[slot] = on_boundary

        if !on_boundary
            # Placeholder values for interior points (not used but keeps type consistency)
            hermite_data.boundary_conditions[slot] = Dirichlet(T)
            hermite_data.normals[slot] .= zero(T)
        else
            bnd = global_to_boundary[point]
            hermite_data.boundary_conditions[slot] = boundary_conditions[bnd]
            hermite_data.normals[slot] .= normals[bnd]
        end
    end

    return nothing
end

# Trait types for dispatch
abstract type StencilType end
struct StandardStencil <: StencilType end
struct InternalStencil <: StencilType end
struct DirichletStencil <: StencilType end
struct HermiteStencil <: StencilType end

# Trait function to determine stencil type
stencil_type(boundary_info::Nothing) = StandardStencil()
stencil_type(boundary_info::HermiteBoundaryInfo) = any(boundary_info.is_boundary) ? HermiteStencil() : StandardStencil()

# Convenience function to check if any boundary points in stencil
has_boundary_points(boundary_info::Nothing) = false
has_boundary_points(boundary_info::HermiteBoundaryInfo) = any(boundary_info.is_boundary)
"""
    stencil_type(is_boundary, boundary_conditions, eval_idx, neighbors, global_to_boundary)

Return the trait object describing the stencil around `eval_idx`.

- `InternalStencil()` when none of the `neighbors` is a boundary point.
- `DirichletStencil()` when the evaluation point itself is a boundary point carrying a
  Dirichlet condition.
- `HermiteStencil()` in every other case.

`is_boundary` is indexed by global point index, while `boundary_conditions` is indexed by
the boundary index obtained from `global_to_boundary`.
"""
function stencil_type(
    is_boundary::Vector{Bool},
    # `Vector{<:BoundaryCondition}` rather than `Vector{BoundaryCondition}`: type
    # parameters are invariant, so the latter would reject `Vector{BoundaryCondition{T}}`.
    boundary_conditions::Vector{<:BoundaryCondition},
    eval_idx::Int,
    neighbors::Vector{Int},
    global_to_boundary::Vector{Int},
)
    # Short-circuiting check; avoids allocating `is_boundary[neighbors]`.
    if !any(i -> is_boundary[i], neighbors)
        return InternalStencil()
    elseif is_boundary[eval_idx] &&
        is_dirichlet(boundary_conditions[global_to_boundary[eval_idx]])
        return DirichletStencil()
    else
        return HermiteStencil()
    end
end
2 changes: 1 addition & 1 deletion src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function Interpolator(x, y, basis::B=PHS()) where {B<:AbstractRadialBasis}
mon = MonomialBasis(dim, basis.poly_deg)
data_type = promote_type(eltype(first(x)), eltype(y))
A = Symmetric(zeros(data_type, n, n))
_build_collocation_matrix!(A, x, basis, mon, k, StandardStencil())
_build_collocation_matrix!(A, x, basis, mon, k)
b = data_type[i < k ? y[i] : 0 for i in 1:n]
w = A \ b
return Interpolator(x, y, w[1:k], w[(k + 1):end], basis, mon)
Expand Down
96 changes: 11 additions & 85 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,92 +20,20 @@ end
function _build_weights(
data, eval_points, adjl, basis, ℒrbf, ℒmon, mon; batch_size=10, device=CPU()
)
TD = eltype(first(data))
dim = length(first(data)) # dimension of data
nmon = binomial(dim + basis.poly_deg, basis.poly_deg)
k = length(first(adjl)) # number of data in influence/support domain
sizes = (k, nmon)

# allocate arrays to build sparse matrix
Na = length(adjl)
I = zeros(Int, k * Na)
J = reduce(vcat, adjl)
V = zeros(TD, k * Na, _num_ops(ℒrbf))

# Calculate number of batches
n_batches = ceil(Int, Na / batch_size)

# Create kernel for building stencils in batches
@kernel function build_stencils_kernel(
I, J, V, data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, k, batch_size, Na
)
# Get the batch index for this thread
batch_idx = @index(Global)

# Calculate the range of points for this batch
start_idx = (batch_idx - 1) * batch_size + 1
end_idx = min(batch_idx * batch_size, Na)

# Pre-allocate work arrays for this thread
n = k + nmon
A = Symmetric(zeros(TD, n, n), :U)
b = _prepare_b(ℒrbf, TD, n)

# Process each point in the batch sequentially
for i in start_idx:end_idx
# Set row indices for sparse matrix
for idx in 1:k
I[(i - 1) * k + idx] = i
end

# Get data points in the influence domain
local_data = [data[j] for j in adjl[i]]

# Build stencil and store in global weight matrix
stencil = _build_stencil!(
A, b, ℒrbf, ℒmon, local_data, eval_points[i], basis, mon, k, StandardStencil()
)

# Store the stencil weights in the value array
for op in axes(V, 2)
for idx in 1:k
V[(i - 1) * k + idx, op] = stencil[idx, op]
end
end
end
end

# Launch kernel with one thread per batch
kernel = build_stencils_kernel(device)
kernel(
I,
J,
V,
# Use the unified kernel infrastructure with standard allocation strategy
return _build_weights_unified(
StandardAllocation(),
data,
eval_points,
adjl,
basis,
ℒrbf,
ℒmon,
mon,
k,
batch_size,
Na;
ndrange=n_batches,
workgroupsize=1,
nothing;
batch_size=batch_size,
device=device,
)

# Wait for kernel to complete
KernelAbstractions.synchronize(device)

# Create and return sparse matrix/matrices
nrows = length(adjl)
ncols = length(data)
if size(V, 2) == 1
return sparse(I, J, V[:, 1], nrows, ncols)
else
return ntuple(i -> sparse(I, J, V[:, i], nrows, ncols), size(V, 2))
end
end

function _build_stencil!(
Expand All @@ -118,15 +46,14 @@ function _build_stencil!(
basis::B,
mon::MonomialBasis,
k::Int,
::StandardStencil
) where {TD,TE,B<:AbstractRadialBasis}
_build_collocation_matrix!(A, data, basis, mon, k, StandardStencil())
_build_rhs!(b, ℒrbf, ℒmon, data, eval_point, basis, k, StandardStencil())
_build_collocation_matrix!(A, data, basis, mon, k)
_build_rhs!(b, ℒrbf, ℒmon, data, eval_point, basis, k)
return (A \ b)[1:k, :]
end

function _build_collocation_matrix!(
A::Symmetric, data::AbstractVector, basis::B, mon::MonomialBasis{Dim,Deg}, k::K, ::StandardStencil
A::Symmetric, data::AbstractVector, basis::B, mon::MonomialBasis{Dim,Deg}, k::K
) where {B<:AbstractRadialBasis,K<:Int,Dim,Deg}
# radial basis section
AA = parent(A)
Expand All @@ -147,7 +74,7 @@ function _build_collocation_matrix!(
end

function _build_rhs!(
b, ℒrbf, ℒmon, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K, ::StandardStencil
b, ℒrbf, ℒmon, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K
) where {TD,TE,B<:AbstractRadialBasis,K<:Int}
# radial basis section
@inbounds for i in eachindex(data)
Expand All @@ -165,7 +92,7 @@ function _build_rhs!(
end

function _build_rhs!(
b, ℒrbf::Tuple, ℒmon::Tuple, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K, ::StandardStencil
b, ℒrbf::Tuple, ℒmon::Tuple, data::AbstractVector{TD}, eval_point::TE, basis::B, k::K
) where {TD,TE,B<:AbstractRadialBasis,K<:Int}
@assert size(b, 2) == length(ℒrbf) == length(ℒmon) "b, ℒrbf, ℒmon must have the same length"
# radial basis section
Expand All @@ -186,4 +113,3 @@ function _build_rhs!(

return nothing
end

Loading
Loading