diff --git a/core/solver/cg.cpp b/core/solver/cg.cpp index 04fd3c4de22..23e904247ba 100644 --- a/core/solver/cg.cpp +++ b/core/solver/cg.cpp @@ -47,6 +47,108 @@ typename Cg::parameters_type Cg::parse( return params; } +template +void Cg::apply_mv(ptr_param b, + ptr_param x) const +{ + // @todo: need precision dispatch + auto dense_b = b->temporary_precision(); + auto dense_x = x->temporary_precision(); + + using std::swap; + constexpr uint8 RelativeStoppingId{1}; + + auto exec = this->get_executor(); + this->setup_workspace(); + + GKO_SOLVER_VECTOR(r, dense_b.get()); + GKO_SOLVER_VECTOR(z, dense_b.get()); + GKO_SOLVER_VECTOR(p, dense_b.get()); + GKO_SOLVER_VECTOR(q, dense_b.get()); + + GKO_SOLVER_SCALAR(beta, dense_b.get()); + GKO_SOLVER_SCALAR(prev_rho, dense_b.get()); + GKO_SOLVER_SCALAR(rho, dense_b.get()); + + GKO_SOLVER_ONE_MINUS_ONE(); + + bool one_changed{}; + GKO_SOLVER_STOP_REDUCTION_ARRAYS(); + + // r = dense_b + // rho = 0.0 + // prev_rho = 1.0 + // z = p = q = 0 + // @todo: I think the template keyword is necessary because some of these + // variables are defined via auto. + exec->run(cg::make_initialize( + dense_b->template create_local_view().get(), + r->template create_local_view().get(), + z->template create_local_view().get(), + p->template create_local_view().get(), + q->template create_local_view().get(), prev_rho, rho, + &stop_status)); + + this->get_system_matrix()->apply(neg_one_op, dense_x, one_op, r); + auto stop_criterion = this->get_stop_criterion_factory()->generate( + this->get_system_matrix(), + std::shared_ptr(dense_b.get(), [](const LinOp*) {}), + dense_x.get(), r); + + int iter = -1; + /* Memory movement summary: + * 18n * values + matrix/preconditioner storage + * 1x SpMV: 2n * values + storage + * 1x Preconditioner: 2n * values + storage + * 2x dot 4n + * 1x step 1 (axpy) 3n + * 1x step 2 (axpys) 6n + * 1x norm2 residual n + */ + while (true) { + // z = preconditioner * r + this->get_preconditioner()->apply(r, z); + // rho = dot(r, z) + r->compute_conj_dot(z, rho, reduction_tmp); + + ++iter; + bool all_stopped = + stop_criterion->update() + .num_iterations(iter) + .residual(r) + .implicit_sq_residual_norm(rho) + .solution(dense_x.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed); + this->template log( + this, dense_b.get(), dense_x.get(), iter, r, nullptr, rho, + &stop_status, all_stopped); + if (all_stopped) { + break; + } + + // tmp = rho / prev_rho + // p = z + tmp * p + exec->run( + cg::make_step_1(p->template create_local_view().get(), + z->template create_local_view().get(), + rho, prev_rho, &stop_status)); + // q = A * p + this->get_system_matrix()->apply(p, q); + // beta = dot(p, q) + p->compute_conj_dot(q, beta, reduction_tmp); + // tmp = rho / beta + // x = x + tmp * p + // r = r - tmp * q + exec->run(cg::make_step_2( + dense_x->template create_local_view().get(), + r->template create_local_view().get(), + p->template create_local_view().get(), + q->template create_local_view().get(), beta, rho, + &stop_status)); + swap(prev_rho, rho); + } +} + template std::unique_ptr Cg::transpose() const diff --git a/include/ginkgo/core/solver/cg.hpp b/include/ginkgo/core/solver/cg.hpp index 984d5d1f104..5a4744c87f3 100644 --- a/include/ginkgo/core/solver/cg.hpp +++ b/include/ginkgo/core/solver/cg.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -54,6 +54,7 @@ class Cg : public EnableLinOp>, public: using value_type = ValueType; using transposed_type = Cg; + using EnableLinOp::apply; std::unique_ptr transpose() const override; @@ -93,6 +94,9 @@ class Cg : public EnableLinOp>, const config::type_descriptor& td_for_child = config::make_type_descriptor()); + void apply_mv(ptr_param b, + ptr_param x) const; + protected: void apply_impl(const LinOp* b, LinOp* x) const override; diff --git a/reference/test/solver/cg_kernels.cpp b/reference/test/solver/cg_kernels.cpp index c4987bb5b17..be4a353a6f0 100644 --- a/reference/test/solver/cg_kernels.cpp +++ b/reference/test/solver/cg_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -226,6 +226,48 @@ TYPED_TEST(Cg, SolvesStencilSystem) } +TYPED_TEST(Cg, SolvesStencilSystemMultiVector) +{ + using Mtx = typename TestFixture::Mtx; + using value_type = typename TestFixture::value_type; + auto solver = + gko::solver::Cg::build() + .with_criteria(gko::stop::Iteration::build().with_max_iters(3u)) + .on(this->exec) + ->generate(this->mtx); + std::unique_ptr b = + gko::initialize({-1.0, 3.0, 1.0}, this->exec); + std::unique_ptr x = + gko::initialize({0.0, 0.0, 0.0}, this->exec); + + solver->apply_mv(b, x); + + GKO_ASSERT_MTX_NEAR(gko::as(x.get()), l({1.0, 3.0, 2.0}), + r::value); +} + +TYPED_TEST(Cg, SolvesStencilSystemMultiVectorMixed) +{ + using value_type = typename TestFixture::value_type; + using snd_value_type = gko::next_precision; + using Mtx = gko::matrix::Dense; + auto solver = + gko::solver::Cg::build() + .with_criteria(gko::stop::Iteration::build().with_max_iters(3u)) + .on(this->exec) + ->generate(this->mtx); + std::unique_ptr b = + gko::initialize({-1.0, 3.0, 1.0}, this->exec); + std::unique_ptr x = + gko::initialize({0.0, 0.0, 0.0}, this->exec); + + solver->apply_mv(b, x); + + GKO_ASSERT_MTX_NEAR(gko::as(x.get()), l({1.0, 3.0, 2.0}), + (r_mixed())); +} + + TYPED_TEST(Cg, SolvesStencilSystemMixed) { using value_type = gko::next_precision;