@@ -735,6 +735,22 @@ TYPED_TEST(Csr, MultipliesWithCsrMatrix)
735735
736736
737737TYPED_TEST (Csr, MultipliesReuseWithCsrMatrix)
738+ {
739+ using Vec = typename TestFixture::Vec;
740+ using T = typename TestFixture::value_type;
741+
742+ auto [result, _reuse] = this ->mtx ->multiply_reuse (this ->mtx3_unsorted );
743+
744+ auto expected = this ->mtx ->multiply (this ->mtx3_unsorted );
745+ ASSERT_EQ (result->get_size (), gko::dim<2 >(2 , 3 ));
746+ ASSERT_EQ (result->get_num_stored_elements (), 6 );
747+ ASSERT_TRUE (result->is_sorted_by_column_index ());
748+ GKO_ASSERT_MTX_EQ_SPARSITY (result, expected);
749+ GKO_ASSERT_MTX_NEAR (result, expected, 0 );
750+ }
751+
752+
753+ TYPED_TEST (Csr, MultipliesReuseUpdateWithCsrMatrix)
738754{
739755 using Vec = typename TestFixture::Vec;
740756 using T = typename TestFixture::value_type;
@@ -747,37 +763,10 @@ TYPED_TEST(Csr, MultipliesReuseWithCsrMatrix)
747763
748764 reuse.update_values (this ->mtx , this ->mtx3_unsorted , result);
749765
750- GKO_ASSERT_MTX_EQ_SPARSITY (result, orig_result);
751- ASSERT_EQ (result->get_size (), gko::dim<2 >(2 , 3 ));
752- ASSERT_EQ (result->get_num_stored_elements (), 6 );
766+ auto expected = this ->mtx ->multiply (this ->mtx3_unsorted );
753767 ASSERT_TRUE (result->is_sorted_by_column_index ());
754- auto r = result->get_const_row_ptrs ();
755- auto c = result->get_const_col_idxs ();
756- auto v = result->get_const_values ();
757- auto v2 = orig_result->get_const_values ();
758- // -26 -10 -62
759- // -30 -10 -80
760- EXPECT_EQ (r[0 ], 0 );
761- EXPECT_EQ (r[1 ], 3 );
762- EXPECT_EQ (r[2 ], 6 );
763- EXPECT_EQ (c[0 ], 0 );
764- EXPECT_EQ (c[1 ], 1 );
765- EXPECT_EQ (c[2 ], 2 );
766- EXPECT_EQ (c[3 ], 0 );
767- EXPECT_EQ (c[4 ], 1 );
768- EXPECT_EQ (c[5 ], 2 );
769- EXPECT_EQ (v[0 ], T{-26 });
770- EXPECT_EQ (v[1 ], T{-10 });
771- EXPECT_EQ (v[2 ], T{-62 });
772- EXPECT_EQ (v[3 ], T{-30 });
773- EXPECT_EQ (v[4 ], T{-10 });
774- EXPECT_EQ (v[5 ], T{-80 });
775- EXPECT_EQ (v2[0 ], T{13 });
776- EXPECT_EQ (v2[1 ], T{5 });
777- EXPECT_EQ (v2[2 ], T{31 });
778- EXPECT_EQ (v2[3 ], T{15 });
779- EXPECT_EQ (v2[4 ], T{5 });
780- EXPECT_EQ (v2[5 ], T{40 });
768+ GKO_ASSERT_MTX_EQ_SPARSITY (result, expected);
769+ GKO_ASSERT_MTX_NEAR (result, expected, 0 );
781770}
782771
783772
@@ -853,6 +842,24 @@ TYPED_TEST(Csr, MultiplyAddsWithCsrMatrices)
853842
854843
855844TYPED_TEST (Csr, MultiplyAddsReuseWithCsrMatrices)
845+ {
846+ using Vec = typename TestFixture::Vec;
847+ using T = typename TestFixture::value_type;
848+ auto alpha = gko::initialize<Vec>({-1.0 }, this ->exec );
849+ auto beta = gko::initialize<Vec>({2.0 }, this ->exec );
850+
851+ auto [result, reuse] = this ->mtx ->multiply_add_reuse (
852+ alpha, this ->mtx3_unsorted , beta, this ->mtx2 );
853+
854+ auto expected =
855+ this ->mtx ->multiply_add (alpha, this ->mtx3_unsorted , beta, this ->mtx2 );
856+ ASSERT_TRUE (result->is_sorted_by_column_index ());
857+ GKO_ASSERT_MTX_EQ_SPARSITY (result, expected);
858+ GKO_ASSERT_MTX_NEAR (result, expected, 0 );
859+ }
860+
861+
862+ TYPED_TEST (Csr, MultiplyAddsReuseUpdateWithCsrMatrices)
856863{
857864 using Vec = typename TestFixture::Vec;
858865 using T = typename TestFixture::value_type;
@@ -868,50 +875,17 @@ TYPED_TEST(Csr, MultiplyAddsReuseWithCsrMatrices)
868875 T{});
869876 std::fill_n (zero_mtx3->get_values (), zero_mtx3->get_num_stored_elements (),
870877 T{});
871-
872- auto [zero_result, reuse] =
878+ auto [result, reuse] =
873879 zero_mtx->multiply_add_reuse (zero, zero_mtx3, zero, zero_mtx2);
874- // we want to test that both the initial result and the updated result are
875- // correct!
876- auto [result, _reuse] = this ->mtx ->multiply_add_reuse (
877- alpha, this ->mtx3_unsorted , beta, this ->mtx2 );
878-
879- ASSERT_EQ (result->get_size (), gko::dim<2 >(2 , 3 ));
880- ASSERT_EQ (result->get_num_stored_elements (), 6 );
881- ASSERT_TRUE (result->is_sorted_by_column_index ());
882- GKO_ASSERT_MTX_EQ_SPARSITY (zero_result, result);
883- auto r = zero_result->get_const_row_ptrs ();
884- auto c = zero_result->get_const_col_idxs ();
885- auto v = zero_result->get_const_values ();
886- auto v2 = result->get_const_values ();
887- EXPECT_EQ (r[0 ], 0 );
888- EXPECT_EQ (r[1 ], 3 );
889- EXPECT_EQ (r[2 ], 6 );
890- EXPECT_EQ (c[0 ], 0 );
891- EXPECT_EQ (c[1 ], 1 );
892- EXPECT_EQ (c[2 ], 2 );
893- EXPECT_EQ (c[3 ], 0 );
894- EXPECT_EQ (c[4 ], 1 );
895- EXPECT_EQ (c[5 ], 2 );
896- // the values should be 0
897- EXPECT_EQ (v[0 ], T{});
898- EXPECT_EQ (v[1 ], T{});
899- EXPECT_EQ (v[2 ], T{});
900- EXPECT_EQ (v[3 ], T{});
901- EXPECT_EQ (v[4 ], T{});
902- EXPECT_EQ (v[5 ], T{});
903880
904881 reuse.update_values (this ->mtx , alpha, this ->mtx3_sorted , beta, this ->mtx2 ,
905882 result);
906883
907- // -11 1 -27
908- // -15 5 -40
909- EXPECT_EQ (v2[0 ], T{-11 });
910- EXPECT_EQ (v2[1 ], T{1 });
911- EXPECT_EQ (v2[2 ], T{-27 });
912- EXPECT_EQ (v2[3 ], T{-15 });
913- EXPECT_EQ (v2[4 ], T{5 });
914- EXPECT_EQ (v2[5 ], T{-40 });
884+ auto expected =
885+ this ->mtx ->multiply_add (alpha, this ->mtx3_sorted , beta, this ->mtx2 );
886+ ASSERT_TRUE (result->is_sorted_by_column_index ());
887+ GKO_ASSERT_MTX_EQ_SPARSITY (result, expected);
888+ GKO_ASSERT_MTX_NEAR (result, expected, 0 );
915889}
916890
917891
@@ -979,6 +953,32 @@ TYPED_TEST(Csr, AddsScaleCsrMatrices)
979953
980954
981955TYPED_TEST (Csr, AddsScaleReuseCsrMatrices)
956+ {
957+ using T = typename TestFixture::value_type;
958+ using Vec = typename TestFixture::Vec;
959+ using Mtx = typename TestFixture::Mtx;
960+ auto alpha = gko::initialize<Vec>({-3.0 }, this ->exec );
961+ auto beta = gko::initialize<Vec>({2.0 }, this ->exec );
962+ auto a = gko::initialize<Mtx>(
963+ {I<T>{2.0 , 0.0 , 3.0 }, I<T>{0.0 , 1.0 , -1.5 }, I<T>{0.0 , -2.0 , 0.0 },
964+ I<T>{5.0 , 0.0 , 0.0 }, I<T>{1.0 , 0.0 , 4.0 }, I<T>{2.0 , -2.0 , 0.0 },
965+ I<T>{0.0 , 0.0 , 0.0 }},
966+ this ->exec );
967+ auto b = gko::initialize<Mtx>(
968+ {I<T>{2.0 , -2.0 , 0.0 }, I<T>{1.0 , 0.0 , 4.0 }, I<T>{2.0 , 0.0 , 3.0 },
969+ I<T>{0.0 , 1.0 , -1.5 }, I<T>{1.0 , 0.0 , 0.0 }, I<T>{0.0 , 0.0 , 0.0 },
970+ I<T>{0.0 , 0.0 , 0.0 }},
971+ this ->exec );
972+
973+ auto [result, reuse] = a->add_scale_reuse (alpha, beta, b);
974+
975+ auto expected = a->add_scale (alpha, beta, b);
976+ GKO_ASSERT_MTX_EQ_SPARSITY (expected, result);
977+ ASSERT_TRUE (result->is_sorted_by_column_index ());
978+ }
979+
980+
981+ TYPED_TEST (Csr, AddsScaleReuseUpdateCsrMatrices)
982982{
983983 using T = typename TestFixture::value_type;
984984 using Vec = typename TestFixture::Vec;
@@ -1000,20 +1000,12 @@ TYPED_TEST(Csr, AddsScaleReuseCsrMatrices)
10001000 auto zero = gko::initialize<Vec>({0.0 }, this ->exec );
10011001 std::fill_n (zero_a->get_values (), a->get_num_stored_elements (), T{});
10021002 std::fill_n (zero_b->get_values (), b->get_num_stored_elements (), T{});
1003- auto expect = gko::initialize<Mtx>(
1004- {I<T>{-2.0 , -4.0 , -9.0 }, I<T>{2.0 , -3.0 , 12.5 }, I<T>{4.0 , 6.0 , 6.0 },
1005- I<T>{-15.0 , 2.0 , -3.0 }, I<T>{-1.0 , 0.0 , -12.0 }, I<T>{-6.0 , 6.0 , 0.0 },
1006- I<T>{0.0 , 0.0 , 0.0 }},
1007- this ->exec );
1008-
1003+ auto expected = a->add_scale (alpha, beta, b);
10091004 auto [result, reuse] = zero_a->add_scale_reuse (zero, zero, zero_b);
10101005
1011- GKO_ASSERT_MTX_EQ_SPARSITY (expect, result);
1012- ASSERT_TRUE (result->is_sorted_by_column_index ());
1013-
10141006 reuse.update_values (alpha, a, beta, b, result);
10151007
1016- GKO_ASSERT_MTX_NEAR (result, expect , r<T>::value);
1008+ GKO_ASSERT_MTX_NEAR (result, expected , r<T>::value);
10171009}
10181010
10191011
@@ -1029,8 +1021,9 @@ TYPED_TEST(Csr, MultiplyReuseFailsOnWrongDimensions)
10291021 auto & m1 = this ->mtx ;
10301022 auto & m2 = this ->mtx3_sorted ;
10311023 auto m2_nnz = Mtx::create (this ->exec , m2->get_size ());
1032- ASSERT_THROW (m1->multiply_reuse (m1), gko::DimensionMismatch);
10331024 auto [m3, reuse] = m1->multiply_reuse (m2);
1025+
1026+ ASSERT_THROW (m1->multiply_reuse (m1), gko::DimensionMismatch);
10341027 // Check for mismatching dimensions or nnz for every parameter
10351028 ASSERT_THROW (reuse.update_values (m2, m2, m3), gko::DimensionMismatch);
10361029 ASSERT_THROW (reuse.update_values (m3, m2, m3), gko::ValueMismatch);
@@ -1049,6 +1042,7 @@ TYPED_TEST(Csr, MultiplyAddFailsOnWrongDimensions)
10491042 auto & m1 = this ->mtx ;
10501043 auto & m2 = this ->mtx2 ;
10511044 auto & m3 = this ->mtx3_sorted ;
1045+
10521046 ASSERT_THROW (m1->multiply_add (v, m3, s, m2), gko::DimensionMismatch);
10531047 ASSERT_THROW (m1->multiply_add (s, m1, s, m2), gko::DimensionMismatch);
10541048 ASSERT_THROW (m1->multiply_add (s, m3, v, m2), gko::DimensionMismatch);
@@ -1066,11 +1060,12 @@ TYPED_TEST(Csr, MultiplyAddReuseFailsOnWrongDimensions)
10661060 auto & m2 = this ->mtx2 ;
10671061 auto & m3 = this ->mtx3_sorted ;
10681062 auto m3_nnz = Mtx::create (this ->exec , m3->get_size ());
1063+ auto [m4, r] = m1->multiply_add_reuse (s, m3, s, m2);
1064+
10691065 ASSERT_THROW (m1->multiply_add_reuse (v, m3, s, m2), gko::DimensionMismatch);
10701066 ASSERT_THROW (m1->multiply_add_reuse (s, m2, s, m2), gko::DimensionMismatch);
10711067 ASSERT_THROW (m1->multiply_add_reuse (s, m3, v, m2), gko::DimensionMismatch);
10721068 ASSERT_THROW (m1->multiply_add_reuse (s, m3, s, m3), gko::DimensionMismatch);
1073- auto [m4, r] = m1->multiply_add_reuse (s, m3, s, m2);
10741069 ASSERT_THROW (r.update_values (m3, s, m3, s, m2, m4), gko::DimensionMismatch);
10751070 ASSERT_THROW (r.update_values (m2, s, m3, s, m2, m4), gko::ValueMismatch);
10761071 ASSERT_THROW (r.update_values (m1, v, m3, s, m2, m4), gko::DimensionMismatch);
@@ -1092,6 +1087,7 @@ TYPED_TEST(Csr, AddScaleFailsOnWrongDimensions)
10921087 auto & m1 = this ->mtx ;
10931088 auto & m2 = this ->mtx2 ;
10941089 auto & m3 = this ->mtx3_sorted ;
1090+
10951091 ASSERT_THROW (m1->add_scale (v, s, m2), gko::DimensionMismatch);
10961092 ASSERT_THROW (m1->add_scale (s, v, m2), gko::DimensionMismatch);
10971093 ASSERT_THROW (m1->add_scale (s, s, m3), gko::DimensionMismatch);
@@ -1107,10 +1103,11 @@ TYPED_TEST(Csr, AddScaleReuseFailsOnWrongDimensions)
11071103 auto & m1 = this ->mtx ;
11081104 auto & m2 = this ->mtx2 ;
11091105 auto & m3 = this ->mtx3_sorted ;
1106+ auto [m4, r] = m1->add_scale_reuse (s, s, m2);
1107+
11101108 ASSERT_THROW (m1->add_scale_reuse (v, s, m2), gko::DimensionMismatch);
11111109 ASSERT_THROW (m1->add_scale_reuse (s, v, m2), gko::DimensionMismatch);
11121110 ASSERT_THROW (m1->add_scale_reuse (s, s, m3), gko::DimensionMismatch);
1113- auto [m4, r] = m1->add_scale_reuse (s, s, m2);
11141111 ASSERT_THROW (r.update_values (v, m1, s, m2, m4), gko::DimensionMismatch);
11151112 ASSERT_THROW (r.update_values (s, m3, s, m2, m4), gko::DimensionMismatch);
11161113 ASSERT_THROW (r.update_values (s, m2, s, m2, m4), gko::ValueMismatch);
0 commit comments