Skip to content

Commit bdeb418

Browse files
Merge pull request #53 from chriselrod/supportmorethreadsthantrajectories
Support more threads than trajectories.
2 parents 130c723 + 6f23b72 commit bdeb418

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
- uses: julia-actions/julia-runtest@v1
3434
env:
3535
GROUP: ${{ matrix.group }}
36+
JULIA_NUM_THREADS: 11
3637
- uses: julia-actions/julia-processcoverage@v1
3738
- uses: codecov/codecov-action@v1
3839
with:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "1.11.3"
4+
version = "1.11.4"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/ensemble/basic_ensemble_solve.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,22 +199,21 @@ function SciMLBase.solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwar
199199
end
200200

201201
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
202-
203-
if length(II) == 1 || Threads.nthreads() == 1
202+
nthreads = min(Threads.nthreads(), length(II))
203+
if length(II) == 1 || nthreads == 1
204204
return solve_batch(prob,alg,EnsembleSerial(),II,pmap_batch_size;kwargs...)
205205
end
206206

207207
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
208-
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
208+
probs = [deepcopy(prob.prob) for i in 1:nthreads]
209209
else
210210
probs = prob.prob
211211
end
212212

213213
#
214-
batch_size = length(II)÷Threads.nthreads()
215-
216-
batch_data = tmap(1:Threads.nthreads()) do i
217-
if i == Threads.nthreads()
214+
batch_size = length(II)÷nthreads
215+
batch_data = tmap(1:nthreads) do i
216+
if i == nthreads
218217
I_local = II[(batch_size*(i-1)+1):end]
219218
else
220219
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]

0 commit comments

Comments
 (0)