Skip to content

Commit b4fa3f6

Browse files
authored
Merge pull request #45 from JuliaDiffEq/sampling
Sampling Methods for Data
2 parents de617a8 + 83703ea commit b4fa3f6

File tree

4 files changed

+110
-2
lines changed

4 files changed

+110
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1010
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
1111
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1212
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
13+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1314
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415

1516
[compat]
1617
Compat = "2.2, 3.0"
17-
ModelingToolkit = "1.1.3"
18+
ModelingToolkit = "1.2.5"
1819
ProximalOperators = "0.10"
1920
QuadGK = "2.3.1"
21+
StatsBase = "0.32.0"
2022
julia = "1"
2123

2224
[extras]

src/DataDrivenDiffEq.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module DataDrivenDiffEq
22

33
using LinearAlgebra
44
using ModelingToolkit
5-
using QuadGK, Statistics
5+
using QuadGK
6+
using Statistics
67
using Compat
78

89
abstract type abstractBasis end;
@@ -46,5 +47,6 @@ export ISInDy
4647
include("./utils.jl")
4748
export AIC, AICC, BIC
4849
export hankel, optimal_shrinkage, optimal_shrinkage!
50+
export burst_sampling, subsample
4951

5052
end # module

src/utils.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import StatsBase: sample
2+
13
# Model selection
24

35
# Taken from https://royalsocietypublishing.org/doi/pdf/10.1098/rspa.2017.0009
@@ -100,3 +102,72 @@ function optimal_shrinkage!(X::AbstractArray{T, 2}) where T <: Number
100102
X .= U*Diagonal(S)*V'
101103
return
102104
end
105+
106+
107+
"""
    burst_sampling(x::AbstractArray, samplesize::Int64, bursts::Int64)

Pick `bursts` random, distinct start indices along the last dimension of `x` and
return the entries of `x` covered by the bursts `i:i+samplesize` that begin at
those indices.

Every burst therefore spans `samplesize + 1` consecutive indices. Overlapping
bursts are merged, so each index appears at most once and the result is ordered
by increasing index.

# Arguments
- `x`: data array; samples are taken along its last dimension.
- `samplesize`: offset between the first and the last index of one burst.
- `bursts`: number of bursts to draw; has to be at least one.

# Throws
- `AssertionError` if `bursts < 1` or if the last dimension of `x` is shorter
  than `samplesize*bursts`.
"""
@inline function burst_sampling(x::AbstractArray, samplesize::Int64, bursts::Int64)
    n = size(x)[end]
    @assert bursts >= 1 "Number of bursts has to be positive."
    # The message used to interpolate the function `size` instead of `samplesize`.
    @assert n >= samplesize*bursts "Length of data array too small for subsampling of size $samplesize!"
    # Start indices are limited to 1:n-samplesize so that every burst ends inside the data.
    starts = sample(1:n-samplesize, bursts, replace = false)
    # Expand every start index into its burst, then merge overlaps and order the indices.
    inds = sort(unique(reduce(vcat, [collect(i:i+samplesize) for i in starts])))
    return resample(x, inds)
end
113+
114+
115+
"""
    burst_sampling(x::AbstractArray, y::AbstractArray, samplesize::Int64, bursts::Int64)

Pick `bursts` random, distinct start indices along the last dimension of `x` and
return the entries of both `x` and `y` covered by the bursts `i:i+samplesize`
that begin at those indices. The same indices are applied to `x` and `y`.

Every burst spans `samplesize + 1` consecutive indices. Overlapping bursts are
merged, so each index appears at most once and the results are ordered by
increasing index.

# Arguments
- `x`, `y`: data arrays whose last dimensions have the same length.
- `samplesize`: offset between the first and the last index of one burst.
- `bursts`: number of bursts to draw; has to be at least one.

# Returns
A tuple `(xs, ys)` holding the selected entries of `x` and `y`.

# Throws
- `AssertionError` if `bursts < 1`, if the last dimension of `x` is shorter than
  `samplesize*bursts`, or if the last dimensions of `x` and `y` differ.
"""
@inline function burst_sampling(x::AbstractArray, y::AbstractArray, samplesize::Int64, bursts::Int64)
    n = size(x)[end]
    @assert bursts >= 1 "Number of bursts has to be positive."
    # The message used to interpolate the function `size` instead of `samplesize`.
    @assert n >= samplesize*bursts "Length of data array too small for subsampling of size $samplesize!"
    @assert n == size(y)[end]
    # Start indices are limited to 1:n-samplesize so that every burst ends inside the data.
    starts = sample(1:n-samplesize, bursts, replace = false)
    # Expand every start index into its burst, then merge overlaps and order the indices.
    inds = sort(unique(reduce(vcat, [collect(i:i+samplesize) for i in starts])))
    return resample(x, inds), resample(y, inds)
end
122+
123+
124+
"""
    burst_sampling(x::AbstractArray, t::AbstractVector, period::AbstractFloat, bursts::Int64)

Time-based burst sampling: draw `bursts` random start indices from `t` and return
the entries of `x` and `t` covered by the bursts beginning at those indices.

A time stamp `ti` qualifies as a burst start when
`0 <= ti - period <= t[end] - 2*period`. Each burst covers the indices
`i:i+k` with `k = floor(period/(t[end]-t[1])*length(t))`. Overlapping bursts are
merged and the indices are sorted before the data is extracted.

# Returns
A tuple `(xs, ts)` with the selected entries of `x` and `t`.

# Throws
- `AssertionError` if `period` is not positive, if the last dimensions of `x`
  and `t` differ, if `bursts < 1`, or if `t[end]-t[1] < period*bursts`.
"""
@inline function burst_sampling(x::AbstractArray, t::AbstractVector, period::T, bursts::Int64) where T <: AbstractFloat
    t_first, t_last = t[1], t[end]
    @assert period > zero(T) "Sampling period has to be positive."
    @assert size(x)[end] == size(t)[end] "Provide consistent data."
    @assert bursts >= 1 "Number of bursts has to be positive."
    @assert t_last-t_first >= period*bursts "Bursting impossible. Please provide more data or reduce bursts."

    # NOTE(review): the lower bound is measured against zero, not t[1] --
    # presumably the time axis is expected to start at zero; confirm.
    shifted = t .- period
    admissible = (zero(eltype(t)) .<= shifted) .& (shifted .<= t_last - 2*period)

    # Number of index steps that make up one burst.
    steps = floor(Int64, period/(t_last-t_first)*length(t))

    starts = sample((1:length(t))[admissible], bursts, replace = false)

    # Collect the indices of all bursts, then drop duplicates and order them.
    picked = Int[]
    for s in starts
        append!(picked, s:s+steps)
    end
    sort!(unique!(picked))

    return resample(x, picked), resample(t, picked)
end
135+
136+
137+
"""
    subsample(x::AbstractVector, frequency::Int64)

Return every `frequency`-th element of `x`, starting with the element at index 1.

# Throws
- `AssertionError` if `frequency` is not greater than one.
"""
@inline function subsample(x::AbstractVector, frequency::Int64)
    @assert frequency > 1
    stride = 1:frequency:lastindex(x)
    return x[stride]
end
141+
142+
143+
"""
    subsample(x::AbstractArray, frequency::Int64)

Return every `frequency`-th column of `x`, starting with column 1. All entries
along the first dimension are kept.

# Throws
- `AssertionError` if `frequency` is not greater than one.
"""
@inline function subsample(x::AbstractArray, frequency::Int64)
    @assert frequency > 1
    cols = 1:frequency:lastindex(x, 2)
    return x[:, cols]
end
147+
148+
"""
    subsample(x::AbstractArray, t::AbstractVector, period::AbstractFloat)

Thin out `x` and `t` so that consecutive retained time stamps are at least
`period` apart.

The first sample is always kept. Walking through `t` in order, a sample is kept
whenever its time stamp lies at least `period` after the most recently kept one.

# Returns
A tuple `(xs, ts)` with the retained entries of `x` and `t`.

# Throws
- `AssertionError` if `period` is not positive, if the last dimensions of `x`
  and `t` differ, or if `t[end]-t[1] < period`.
"""
@inline function subsample(x::AbstractArray, t::AbstractVector, period::T) where T <: AbstractFloat
    @assert period > zero(T) "Sampling period has to be positive."
    @assert size(x)[end] == size(t)[end] "Provide consistent data."
    @assert t[end]-t[1]>= period "Subsampling impossible. Sampling period exceeds time window."

    kept = Int64[1]
    t_kept = t[1]
    for (k, tk) in enumerate(t)
        # Skip samples that follow the last retained one too closely.
        tk - t_kept >= period || continue
        push!(kept, k)
        t_kept = tk
    end

    return resample(x, kept), resample(t, kept)
end
162+
163+
"""
    resample(x::AbstractVector{<:Number}, indx::AbstractArray{Int64})

Return the elements of the vector `x` at the positions listed in `indx`.

# Throws
- `AssertionError` if the largest index exceeds `length(x)` or the smallest
  index is below one.
"""
@inline function resample(x::AbstractVector{T}, indx::AbstractArray{Int64}) where {T <: Number}
    # Upper bound is checked before the lower bound.
    @assert maximum(indx) <= length(x)
    @assert minimum(indx) >= 1
    selected = getindex(x, indx)
    return selected
end
168+
169+
"""
    resample(x::AbstractMatrix{<:Number}, indx::AbstractArray{Int64})

Return the columns of the matrix `x` listed in `indx`; all rows are kept.

# Throws
- `AssertionError` if the largest index exceeds `size(x, 2)` or the smallest
  index is below one.
"""
@inline function resample(x::AbstractMatrix{T}, indx::AbstractArray{Int64}) where {T <: Number}
    # Upper bound is checked before the lower bound.
    @assert maximum(indx) <= size(x, 2)
    @assert minimum(indx) >= 1
    selected = getindex(x, :, indx)
    return selected
end

test/runtests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,4 +291,37 @@ end
291291
@test BIC(k, X, Y) == -2*log(sum(abs2, X -Y)) + k*log(size(X)[2])
292292
@test AICC(k, X, Y, likelyhood = (X,Y)->sum(abs, X-Y)) == AIC(k, X, Y, likelyhood = (X,Y)->sum(abs, X-Y))+ 2*(k+1)*(k+2)/(size(X)[2]-k-2)
293293

294+
295+
# Sampling
296+
X = randn(Float64, 2, 100)
297+
t = collect(0:0.1:9.99)
298+
Y = randn(size(X))
299+
xt = burst_sampling(X, 5, 10)
300+
@test 10 <= size(xt)[end] <= 60
301+
@test all([any(xi .≈ X) for xi in eachcol(xt)])
302+
xt, tt = burst_sampling(X, t, 5, 10)
303+
@test all(diff(tt) .> 0.0)
304+
@test size(xt)[end] == size(tt)[end]
305+
@test all([any(xi .≈ X) for xi in eachcol(xt)])
306+
@test !all([any(xi .≈ Y) for xi in eachcol(xt)])
307+
xs, ts = burst_sampling(X, t, 2.0, 1)
308+
@test all([any(xi .≈ X) for xi in eachcol(xs)])
309+
@test size(xs)[end] == size(ts)[end]
310+
@test ts[end]-ts[1] 2.0
311+
X2n = subsample(X, 2)
312+
t2n = subsample(t, 2)
313+
@test size(X2n)[end] == size(t2n)[end]
314+
@test size(X2n)[end] == Int(round(size(X)[end]/2))
315+
@test X2n[:, 1] == X[:, 1]
316+
@test X2n[:, end] == X[:, end-1]
317+
@test all([any(xi .≈ X) for xi in eachcol(X2n)])
318+
xs, ts = subsample(X, t, 0.5)
319+
@test size(xs)[end] == size(ts)[end]
320+
@test size(xs)[1] == size(X)[1]
321+
@test all(diff(ts) .≈ 0.5)
322+
# Loop this a few times to be sure its right
323+
@test_nowarn for i in 1:20
324+
xs, ts = burst_sampling(X, t, 2.0, 1)
325+
xs, ts = subsample(X, t, 0.5)
326+
end
294327
end

0 commit comments

Comments
 (0)