@@ -52,22 +52,23 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
5252 matrix::MultiVector* x) const
5353{
5454 // @todo: need precision dispatch
55+ auto dense_b = b->temporary_precision <ValueType>();
56+ auto dense_x = x->temporary_precision <ValueType>();
5557
5658 using std::swap;
5759 constexpr uint8 RelativeStoppingId{1 };
5860
5961 auto exec = this ->get_executor ();
6062 this ->setup_workspace ();
6163
62- auto dense_b = b.get ();
63- GKO_SOLVER_VECTOR (r, dense_b);
64- GKO_SOLVER_VECTOR (z, dense_b);
65- GKO_SOLVER_VECTOR (p, dense_b);
66- GKO_SOLVER_VECTOR (q, dense_b);
64+ GKO_SOLVER_VECTOR (r, dense_b.get ());
65+ GKO_SOLVER_VECTOR (z, dense_b.get ());
66+ GKO_SOLVER_VECTOR (p, dense_b.get ());
67+ GKO_SOLVER_VECTOR (q, dense_b.get ());
6768
68- GKO_SOLVER_SCALAR (beta, dense_b);
69- GKO_SOLVER_SCALAR (prev_rho, dense_b);
70- GKO_SOLVER_SCALAR (rho, dense_b);
69+ GKO_SOLVER_SCALAR (beta, dense_b. get () );
70+ GKO_SOLVER_SCALAR (prev_rho, dense_b. get () );
71+ GKO_SOLVER_SCALAR (rho, dense_b. get () );
7172
7273 GKO_SOLVER_ONE_MINUS_ONE ();
7374
@@ -80,18 +81,19 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
8081 // z = p = q = 0
8182 // @todo: I think the template keyword is necessary because some of these
8283 // variables are defined via auto.
83- exec->run (
84- cg::make_initialize (b ->template create_local_view <ValueType>().get (),
85- r->template create_local_view <ValueType>().get (),
86- z->template create_local_view <ValueType>().get (),
87- p->template create_local_view <ValueType>().get (),
88- q->template create_local_view <ValueType>().get (),
89- prev_rho, rho, &stop_status));
84+ exec->run (cg::make_initialize (
85+ dense_b ->template create_local_view <ValueType>().get (),
86+ r->template create_local_view <ValueType>().get (),
87+ z->template create_local_view <ValueType>().get (),
88+ p->template create_local_view <ValueType>().get (),
89+ q->template create_local_view <ValueType>().get (), prev_rho, rho ,
90+ &stop_status));
9091
91- this ->get_system_matrix ()->apply (neg_one_op, x , one_op, r);
92+ this ->get_system_matrix ()->apply (neg_one_op, dense_x , one_op, r);
9293 auto stop_criterion = this ->get_stop_criterion_factory ()->generate (
9394 this ->get_system_matrix (),
94- std::shared_ptr<const LinOp>(dense_b, [](const LinOp*) {}), x.get (), r);
95+ std::shared_ptr<const LinOp>(dense_b.get (), [](const LinOp*) {}),
96+ dense_x.get (), r);
9597
9698 int iter = -1 ;
9799 /* Memory movement summary:
@@ -115,11 +117,11 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
115117 .num_iterations (iter)
116118 .residual (r)
117119 .implicit_sq_residual_norm (rho)
118- .solution (x .get ())
120+ .solution (dense_x .get ())
119121 .check (RelativeStoppingId, true , &stop_status, &one_changed);
120122 this ->template log <log::Logger::iteration_complete>(
121- this , dense_b, x .get (), iter, r, nullptr , rho, &stop_status ,
122- all_stopped);
123+ this , dense_b. get (), dense_x .get (), iter, r, nullptr , rho,
124+ &stop_status, all_stopped);
123125 if (all_stopped) {
124126 break ;
125127 }
@@ -137,12 +139,12 @@ void Cg<ValueType>::apply(const matrix::MultiVector* b,
137139 // tmp = rho / beta
138140 // x = x + tmp * p
139141 // r = r - tmp * q
140- exec->run (
141- cg::make_step_2 (x ->template create_local_view <ValueType>().get (),
142- r->template create_local_view <ValueType>().get (),
143- p->template create_local_view <ValueType>().get (),
144- q->template create_local_view <ValueType>().get (),
145- beta, rho, &stop_status));
142+ exec->run (cg::make_step_2 (
143+ dense_x ->template create_local_view <ValueType>().get (),
144+ r->template create_local_view <ValueType>().get (),
145+ p->template create_local_view <ValueType>().get (),
146+ q->template create_local_view <ValueType>().get (), beta, rho ,
147+ &stop_status));
146148 swap (prev_rho, rho);
147149 }
148150}
0 commit comments