Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/test/reference/distributed_solver.matrix.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"cg::initialize": 1.0,
"advanced_apply(<typename>)": 1.0,
"dense::row_gather": 1.0,
"event::record_event": 1.0,
"csr::advanced_spmv": 1.0,
"dense::compute_squared_norm2": 1.0,
"dense::compute_sqrt": 1.0,
Expand Down
18 changes: 18 additions & 0 deletions benchmark/test/reference/distributed_solver.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ DEBUG: end cg::initialize
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin csr::advanced_spmv
DEBUG: end csr::advanced_spmv
Expand Down Expand Up @@ -182,6 +184,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -222,6 +226,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -262,6 +268,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -302,6 +310,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -342,6 +352,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -382,6 +394,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -422,6 +436,8 @@ DEBUG: end cg::step_1
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down Expand Up @@ -463,6 +479,8 @@ DEBUG: end copy(<typename>)
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin csr::advanced_spmv
DEBUG: end csr::advanced_spmv
Expand Down
1 change: 1 addition & 0 deletions benchmark/test/reference/distributed_solver.simple.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"cg::initialize": 1.0,
"advanced_apply(<typename>)": 1.0,
"dense::row_gather": 1.0,
"event::record_event": 1.0,
"csr::advanced_spmv": 1.0,
"dense::compute_squared_norm2": 1.0,
"dense::compute_sqrt": 1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"cg::initialize": 1.0,
"advanced_apply(<typename>)": 1.0,
"dense::row_gather": 1.0,
"event::record_event": 1.0,
"csr::advanced_spmv": 1.0,
"dense::compute_squared_norm2": 1.0,
"dense::compute_sqrt": 1.0,
Expand Down
2 changes: 2 additions & 0 deletions benchmark/test/reference/spmv_distributed.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ DEBUG: begin repetition
DEBUG: begin apply(<typename>)
DEBUG: begin dense::row_gather
DEBUG: end dense::row_gather
DEBUG: begin event::record_event
DEBUG: end event::record_event
DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
Expand Down
1 change: 1 addition & 0 deletions common/unified/components/fill_array_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void fill_array(std::shared_ptr<const DefaultExecutor> exec, ValueType* array,

GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL);
template GKO_DECLARE_FILL_ARRAY_KERNEL(bool);
template GKO_DECLARE_FILL_ARRAY_KERNEL(char);
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint16);
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint32);
#ifndef GKO_SIZE_T_IS_UINT64_T
Expand Down
1 change: 1 addition & 0 deletions core/base/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ ValueType reduce_add(const array<ValueType>& input_arr,

GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_FILL);
template GKO_DECLARE_ARRAY_FILL(bool);
template GKO_DECLARE_ARRAY_FILL(char);
template GKO_DECLARE_ARRAY_FILL(uint16);
template GKO_DECLARE_ARRAY_FILL(uint32);
#ifndef GKO_SIZE_T_IS_UINT64_T
Expand Down
36 changes: 36 additions & 0 deletions core/base/event.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_BASE_EVENT_HPP_
#define GKO_CORE_BASE_EVENT_HPP_

#include <memory>

#include <ginkgo/core/base/event.hpp>
#include <ginkgo/core/base/executor.hpp>


namespace gko {


/**
* NotAsyncEvent is to provide an Event implementation on unsupported executor
* like reference. It will ensure the kernels are finished when recording this
* event.
*/
class NotAsyncEvent : public Event {
public:
NotAsyncEvent(std::shared_ptr<const Executor> exec) { exec->synchronize(); }

void synchronize() const override
{
// we have sync in the recording phase
}
};


} // namespace gko


#endif // #ifndef GKO_CORE_BASE_EVENT_HPP_
38 changes: 38 additions & 0 deletions core/base/event_kernels.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_BASE_EVENT_KERNELS_HPP_
#define GKO_CORE_BASE_EVENT_KERNELS_HPP_


#include <memory>

#include <ginkgo/core/base/event.hpp>
#include <ginkgo/core/base/executor.hpp>

#include "core/base/kernel_declaration.hpp"


namespace gko {
namespace kernels {


#define GKO_DECLARE_EVENT_RECORD_EVENT \
void record_event(std::shared_ptr<const DefaultExecutor> exec, \
std::shared_ptr<const Event>& event)


#define GKO_DECLARE_ALL_AS_TEMPLATES GKO_DECLARE_EVENT_RECORD_EVENT


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(event, GKO_DECLARE_ALL_AS_TEMPLATES);


#undef GKO_DECLARE_ALL_AS_TEMPLATES


} // namespace kernels
} // namespace gko

#endif // GKO_CORE_BASE_EVENT_KERNELS_HPP_
4 changes: 2 additions & 2 deletions core/base/executor.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/base/executor.hpp"

#include <ginkgo/core/base/event.hpp>
#include <ginkgo/core/base/exception.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/name_demangling.hpp>


namespace gko {


Expand Down
11 changes: 11 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/base/batch_instantiation.hpp"
#include "core/base/batch_multi_vector_kernels.hpp"
#include "core/base/device_matrix_data_kernels.hpp"
#include "core/base/event_kernels.hpp"
#include "core/base/index_set_kernels.hpp"
#include "core/base/mixed_precision_types.hpp"
#include "core/components/absolute_array_kernels.hpp"
Expand Down Expand Up @@ -253,6 +254,7 @@ template GKO_DECLARE_PREFIX_SUM_NONNEGATIVE_KERNEL(size_type);

GKO_STUB_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL);
template GKO_DECLARE_FILL_ARRAY_KERNEL(bool);
template GKO_DECLARE_FILL_ARRAY_KERNEL(char);
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint16);
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint32);
#ifndef GKO_SIZE_T_IS_UINT64_T
Expand Down Expand Up @@ -316,6 +318,15 @@ GKO_STUB_INDEX_TYPE(GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL);
} // namespace idx_set


namespace event {


GKO_STUB(GKO_DECLARE_EVENT_RECORD_EVENT);


}


namespace partition {


Expand Down
43 changes: 36 additions & 7 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,22 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
? host_recv_vector.get()
: recv_vector.get();
auto req = this->row_gatherer_->apply_async(dense_b, recv_ptr);
local_mtx_->apply(dense_b->get_local_vector(), local_x);
req.wait();
if (dense_b->get_executor() ==
dense_b->get_executor()->get_master()) {
// reference and omp executor does not have event, so we still
// submit the mpi first.
auto req = this->row_gatherer_->apply_async(dense_b, recv_ptr);
local_mtx_->apply(dense_b->get_local_vector(), local_x);
req.wait();
} else {
// we use event here such that we can submit spmv job first
// without waiting for synchronization from the row gatherer.
auto ev = this->row_gatherer_->apply_prepare(dense_b, recv_ptr);
local_mtx_->apply(dense_b->get_local_vector(), local_x);
auto req =
this->row_gatherer_->apply_finalize(dense_b, recv_ptr, ev);
req.wait();
}

if (recv_ptr != recv_vector.get()) {
recv_vector->copy_from(host_recv_vector);
Expand Down Expand Up @@ -524,10 +537,26 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
? host_recv_vector.get()
: recv_vector.get();
auto req = this->row_gatherer_->apply_async(dense_b, recv_ptr);
local_mtx_->apply(local_alpha.get(), dense_b->get_local_vector(),
local_beta.get(), local_x);
req.wait();
if (dense_b->get_executor() ==
dense_b->get_executor()->get_master()) {
// reference and omp executor does not have event, so we still
// submit the mpi first.
auto req = this->row_gatherer_->apply_async(dense_b, recv_ptr);
local_mtx_->apply(local_alpha.get(),
dense_b->get_local_vector(), local_beta.get(),
local_x);
req.wait();
} else {
// we use event here such that we can submit spmv job first
// without waiting for synchronization from the row gatherer.
auto ev = this->row_gatherer_->apply_prepare(dense_b, recv_ptr);
local_mtx_->apply(local_alpha.get(),
dense_b->get_local_vector(), local_beta.get(),
local_x);
auto req =
this->row_gatherer_->apply_finalize(dense_b, recv_ptr, ev);
req.wait();
}

if (recv_ptr != recv_vector.get()) {
recv_vector->copy_from(host_recv_vector);
Expand Down
Loading
Loading