Skip to content

Commit d2288f9

Browse files
committed
fixup! test using multivector in cg
1 parent a72fedb commit d2288f9

File tree

2 files changed

+49
-26
lines changed

2 files changed

+49
-26
lines changed

core/solver/cg.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,23 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
5252
matrix::MultiVector* x) const
5353
{
5454
// @todo: need precision dispatch
55+
auto dense_b = b->temporary_precision<ValueType>();
56+
auto dense_x = x->temporary_precision<ValueType>();
5557

5658
using std::swap;
5759
constexpr uint8 RelativeStoppingId{1};
5860

5961
auto exec = this->get_executor();
6062
this->setup_workspace();
6163

62-
auto dense_b = b.get();
63-
GKO_SOLVER_VECTOR(r, dense_b);
64-
GKO_SOLVER_VECTOR(z, dense_b);
65-
GKO_SOLVER_VECTOR(p, dense_b);
66-
GKO_SOLVER_VECTOR(q, dense_b);
64+
GKO_SOLVER_VECTOR(r, dense_b.get());
65+
GKO_SOLVER_VECTOR(z, dense_b.get());
66+
GKO_SOLVER_VECTOR(p, dense_b.get());
67+
GKO_SOLVER_VECTOR(q, dense_b.get());
6768

68-
GKO_SOLVER_SCALAR(beta, dense_b);
69-
GKO_SOLVER_SCALAR(prev_rho, dense_b);
70-
GKO_SOLVER_SCALAR(rho, dense_b);
69+
GKO_SOLVER_SCALAR(beta, dense_b.get());
70+
GKO_SOLVER_SCALAR(prev_rho, dense_b.get());
71+
GKO_SOLVER_SCALAR(rho, dense_b.get());
7172

7273
GKO_SOLVER_ONE_MINUS_ONE();
7374

@@ -80,18 +81,19 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
8081
// z = p = q = 0
8182
// @todo: I think the template keyword is necessary because some of these
8283
// variables are defined via auto.
83-
exec->run(
84-
cg::make_initialize(b->template create_local_view<ValueType>().get(),
85-
r->template create_local_view<ValueType>().get(),
86-
z->template create_local_view<ValueType>().get(),
87-
p->template create_local_view<ValueType>().get(),
88-
q->template create_local_view<ValueType>().get(),
89-
prev_rho, rho, &stop_status));
84+
exec->run(cg::make_initialize(
85+
dense_b->template create_local_view<ValueType>().get(),
86+
r->template create_local_view<ValueType>().get(),
87+
z->template create_local_view<ValueType>().get(),
88+
p->template create_local_view<ValueType>().get(),
89+
q->template create_local_view<ValueType>().get(), prev_rho, rho,
90+
&stop_status));
9091

91-
this->get_system_matrix()->apply(neg_one_op, x, one_op, r);
92+
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op, r);
9293
auto stop_criterion = this->get_stop_criterion_factory()->generate(
9394
this->get_system_matrix(),
94-
std::shared_ptr<const LinOp>(dense_b, [](const LinOp*) {}), x.get(), r);
95+
std::shared_ptr<const LinOp>(dense_b.get(), [](const LinOp*) {}),
96+
dense_x.get(), r);
9597

9698
int iter = -1;
9799
/* Memory movement summary:
@@ -115,11 +117,11 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
115117
.num_iterations(iter)
116118
.residual(r)
117119
.implicit_sq_residual_norm(rho)
118-
.solution(x.get())
120+
.solution(dense_x.get())
119121
.check(RelativeStoppingId, true, &stop_status, &one_changed);
120122
this->template log<log::Logger::iteration_complete>(
121-
this, dense_b, x.get(), iter, r, nullptr, rho, &stop_status,
122-
all_stopped);
123+
this, dense_b.get(), dense_x.get(), iter, r, nullptr, rho,
124+
&stop_status, all_stopped);
123125
if (all_stopped) {
124126
break;
125127
}
@@ -137,12 +139,12 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
137139
// tmp = rho / beta
138140
// x = x + tmp * p
139141
// r = r - tmp * q
140-
exec->run(
141-
cg::make_step_2(x->template create_local_view<ValueType>().get(),
142-
r->template create_local_view<ValueType>().get(),
143-
p->template create_local_view<ValueType>().get(),
144-
q->template create_local_view<ValueType>().get(),
145-
beta, rho, &stop_status));
142+
exec->run(cg::make_step_2(
143+
dense_x->template create_local_view<ValueType>().get(),
144+
r->template create_local_view<ValueType>().get(),
145+
p->template create_local_view<ValueType>().get(),
146+
q->template create_local_view<ValueType>().get(), beta, rho,
147+
&stop_status));
146148
swap(prev_rho, rho);
147149
}
148150
}

reference/test/solver/cg_kernels.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,27 @@ TYPED_TEST(Cg, SolvesStencilSystemMultiVector)
246246
r<value_type>::value);
247247
}
248248

249+
TYPED_TEST(Cg, SolvesStencilSystemMultiVectorMixed)
250+
{
251+
using value_type = typename TestFixture::value_type;
252+
using snd_value_type = gko::next_precision<value_type>;
253+
using Mtx = gko::matrix::Dense<snd_value_type>;
254+
auto solver =
255+
gko::solver::Cg<value_type>::build()
256+
.with_criteria(gko::stop::Iteration::build().with_max_iters(3u))
257+
.on(this->exec)
258+
->generate(this->mtx);
259+
std::unique_ptr<gko::matrix::MultiVector> b =
260+
gko::initialize<Mtx>({-1.0, 3.0, 1.0}, this->exec);
261+
std::unique_ptr<gko::matrix::MultiVector> x =
262+
gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
263+
264+
solver->apply(b.get(), x.get());
265+
266+
GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(x.get()), l({1.0, 3.0, 2.0}),
267+
(r_mixed<value_type, TypeParam>()));
268+
}
269+
249270

250271
TYPED_TEST(Cg, SolvesStencilSystemMixed)
251272
{

0 commit comments

Comments
 (0)