Skip to content

Commit b015950

Browse files
committed
improve test structure
1 parent 8ec67e5 commit b015950

File tree

2 files changed

+96
-99
lines changed

2 files changed

+96
-99
lines changed

reference/test/matrix/csr_kernels.cpp

Lines changed: 79 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,22 @@ TYPED_TEST(Csr, MultipliesWithCsrMatrix)
735735

736736

737737
TYPED_TEST(Csr, MultipliesReuseWithCsrMatrix)
738+
{
739+
using Vec = typename TestFixture::Vec;
740+
using T = typename TestFixture::value_type;
741+
742+
auto [result, _reuse] = this->mtx->multiply_reuse(this->mtx3_unsorted);
743+
744+
auto expected = this->mtx->multiply(this->mtx3_unsorted);
745+
ASSERT_EQ(result->get_size(), gko::dim<2>(2, 3));
746+
ASSERT_EQ(result->get_num_stored_elements(), 6);
747+
ASSERT_TRUE(result->is_sorted_by_column_index());
748+
GKO_ASSERT_MTX_EQ_SPARSITY(result, expected);
749+
GKO_ASSERT_MTX_NEAR(result, expected, 0);
750+
}
751+
752+
753+
TYPED_TEST(Csr, MultipliesReuseUpdateWithCsrMatrix)
738754
{
739755
using Vec = typename TestFixture::Vec;
740756
using T = typename TestFixture::value_type;
@@ -747,37 +763,10 @@ TYPED_TEST(Csr, MultipliesReuseWithCsrMatrix)
747763

748764
reuse.update_values(this->mtx, this->mtx3_unsorted, result);
749765

750-
GKO_ASSERT_MTX_EQ_SPARSITY(result, orig_result);
751-
ASSERT_EQ(result->get_size(), gko::dim<2>(2, 3));
752-
ASSERT_EQ(result->get_num_stored_elements(), 6);
766+
auto expected = this->mtx->multiply(this->mtx3_unsorted);
753767
ASSERT_TRUE(result->is_sorted_by_column_index());
754-
auto r = result->get_const_row_ptrs();
755-
auto c = result->get_const_col_idxs();
756-
auto v = result->get_const_values();
757-
auto v2 = orig_result->get_const_values();
758-
// -26 -10 -62
759-
// -30 -10 -80
760-
EXPECT_EQ(r[0], 0);
761-
EXPECT_EQ(r[1], 3);
762-
EXPECT_EQ(r[2], 6);
763-
EXPECT_EQ(c[0], 0);
764-
EXPECT_EQ(c[1], 1);
765-
EXPECT_EQ(c[2], 2);
766-
EXPECT_EQ(c[3], 0);
767-
EXPECT_EQ(c[4], 1);
768-
EXPECT_EQ(c[5], 2);
769-
EXPECT_EQ(v[0], T{-26});
770-
EXPECT_EQ(v[1], T{-10});
771-
EXPECT_EQ(v[2], T{-62});
772-
EXPECT_EQ(v[3], T{-30});
773-
EXPECT_EQ(v[4], T{-10});
774-
EXPECT_EQ(v[5], T{-80});
775-
EXPECT_EQ(v2[0], T{13});
776-
EXPECT_EQ(v2[1], T{5});
777-
EXPECT_EQ(v2[2], T{31});
778-
EXPECT_EQ(v2[3], T{15});
779-
EXPECT_EQ(v2[4], T{5});
780-
EXPECT_EQ(v2[5], T{40});
768+
GKO_ASSERT_MTX_EQ_SPARSITY(result, expected);
769+
GKO_ASSERT_MTX_NEAR(result, expected, 0);
781770
}
782771

783772

@@ -853,6 +842,24 @@ TYPED_TEST(Csr, MultiplyAddsWithCsrMatrices)
853842

854843

855844
TYPED_TEST(Csr, MultiplyAddsReuseWithCsrMatrices)
845+
{
846+
using Vec = typename TestFixture::Vec;
847+
using T = typename TestFixture::value_type;
848+
auto alpha = gko::initialize<Vec>({-1.0}, this->exec);
849+
auto beta = gko::initialize<Vec>({2.0}, this->exec);
850+
851+
auto [result, reuse] = this->mtx->multiply_add_reuse(
852+
alpha, this->mtx3_unsorted, beta, this->mtx2);
853+
854+
auto expected =
855+
this->mtx->multiply_add(alpha, this->mtx3_unsorted, beta, this->mtx2);
856+
ASSERT_TRUE(result->is_sorted_by_column_index());
857+
GKO_ASSERT_MTX_EQ_SPARSITY(result, expected);
858+
GKO_ASSERT_MTX_NEAR(result, expected, 0);
859+
}
860+
861+
862+
TYPED_TEST(Csr, MultiplyAddsReuseUpdateWithCsrMatrices)
856863
{
857864
using Vec = typename TestFixture::Vec;
858865
using T = typename TestFixture::value_type;
@@ -868,50 +875,17 @@ TYPED_TEST(Csr, MultiplyAddsReuseWithCsrMatrices)
868875
T{});
869876
std::fill_n(zero_mtx3->get_values(), zero_mtx3->get_num_stored_elements(),
870877
T{});
871-
872-
auto [zero_result, reuse] =
878+
auto [result, reuse] =
873879
zero_mtx->multiply_add_reuse(zero, zero_mtx3, zero, zero_mtx2);
874-
// we want to test that both the initial result and the updated result are
875-
// correct!
876-
auto [result, _reuse] = this->mtx->multiply_add_reuse(
877-
alpha, this->mtx3_unsorted, beta, this->mtx2);
878-
879-
ASSERT_EQ(result->get_size(), gko::dim<2>(2, 3));
880-
ASSERT_EQ(result->get_num_stored_elements(), 6);
881-
ASSERT_TRUE(result->is_sorted_by_column_index());
882-
GKO_ASSERT_MTX_EQ_SPARSITY(zero_result, result);
883-
auto r = zero_result->get_const_row_ptrs();
884-
auto c = zero_result->get_const_col_idxs();
885-
auto v = zero_result->get_const_values();
886-
auto v2 = result->get_const_values();
887-
EXPECT_EQ(r[0], 0);
888-
EXPECT_EQ(r[1], 3);
889-
EXPECT_EQ(r[2], 6);
890-
EXPECT_EQ(c[0], 0);
891-
EXPECT_EQ(c[1], 1);
892-
EXPECT_EQ(c[2], 2);
893-
EXPECT_EQ(c[3], 0);
894-
EXPECT_EQ(c[4], 1);
895-
EXPECT_EQ(c[5], 2);
896-
// the values should be 0
897-
EXPECT_EQ(v[0], T{});
898-
EXPECT_EQ(v[1], T{});
899-
EXPECT_EQ(v[2], T{});
900-
EXPECT_EQ(v[3], T{});
901-
EXPECT_EQ(v[4], T{});
902-
EXPECT_EQ(v[5], T{});
903880

904881
reuse.update_values(this->mtx, alpha, this->mtx3_sorted, beta, this->mtx2,
905882
result);
906883

907-
// -11 1 -27
908-
// -15 5 -40
909-
EXPECT_EQ(v2[0], T{-11});
910-
EXPECT_EQ(v2[1], T{1});
911-
EXPECT_EQ(v2[2], T{-27});
912-
EXPECT_EQ(v2[3], T{-15});
913-
EXPECT_EQ(v2[4], T{5});
914-
EXPECT_EQ(v2[5], T{-40});
884+
auto expected =
885+
this->mtx->multiply_add(alpha, this->mtx3_sorted, beta, this->mtx2);
886+
ASSERT_TRUE(result->is_sorted_by_column_index());
887+
GKO_ASSERT_MTX_EQ_SPARSITY(result, expected);
888+
GKO_ASSERT_MTX_NEAR(result, expected, 0);
915889
}
916890

917891

@@ -979,6 +953,32 @@ TYPED_TEST(Csr, AddsScaleCsrMatrices)
979953

980954

981955
TYPED_TEST(Csr, AddsScaleReuseCsrMatrices)
956+
{
957+
using T = typename TestFixture::value_type;
958+
using Vec = typename TestFixture::Vec;
959+
using Mtx = typename TestFixture::Mtx;
960+
auto alpha = gko::initialize<Vec>({-3.0}, this->exec);
961+
auto beta = gko::initialize<Vec>({2.0}, this->exec);
962+
auto a = gko::initialize<Mtx>(
963+
{I<T>{2.0, 0.0, 3.0}, I<T>{0.0, 1.0, -1.5}, I<T>{0.0, -2.0, 0.0},
964+
I<T>{5.0, 0.0, 0.0}, I<T>{1.0, 0.0, 4.0}, I<T>{2.0, -2.0, 0.0},
965+
I<T>{0.0, 0.0, 0.0}},
966+
this->exec);
967+
auto b = gko::initialize<Mtx>(
968+
{I<T>{2.0, -2.0, 0.0}, I<T>{1.0, 0.0, 4.0}, I<T>{2.0, 0.0, 3.0},
969+
I<T>{0.0, 1.0, -1.5}, I<T>{1.0, 0.0, 0.0}, I<T>{0.0, 0.0, 0.0},
970+
I<T>{0.0, 0.0, 0.0}},
971+
this->exec);
972+
973+
auto [result, reuse] = a->add_scale_reuse(alpha, beta, b);
974+
975+
auto expected = a->add_scale(alpha, beta, b);
976+
GKO_ASSERT_MTX_EQ_SPARSITY(expected, result);
977+
ASSERT_TRUE(result->is_sorted_by_column_index());
978+
}
979+
980+
981+
TYPED_TEST(Csr, AddsScaleReuseUpdateCsrMatrices)
982982
{
983983
using T = typename TestFixture::value_type;
984984
using Vec = typename TestFixture::Vec;
@@ -1000,20 +1000,12 @@ TYPED_TEST(Csr, AddsScaleReuseCsrMatrices)
10001000
auto zero = gko::initialize<Vec>({0.0}, this->exec);
10011001
std::fill_n(zero_a->get_values(), a->get_num_stored_elements(), T{});
10021002
std::fill_n(zero_b->get_values(), b->get_num_stored_elements(), T{});
1003-
auto expect = gko::initialize<Mtx>(
1004-
{I<T>{-2.0, -4.0, -9.0}, I<T>{2.0, -3.0, 12.5}, I<T>{4.0, 6.0, 6.0},
1005-
I<T>{-15.0, 2.0, -3.0}, I<T>{-1.0, 0.0, -12.0}, I<T>{-6.0, 6.0, 0.0},
1006-
I<T>{0.0, 0.0, 0.0}},
1007-
this->exec);
1008-
1003+
auto expected = a->add_scale(alpha, beta, b);
10091004
auto [result, reuse] = zero_a->add_scale_reuse(zero, zero, zero_b);
10101005

1011-
GKO_ASSERT_MTX_EQ_SPARSITY(expect, result);
1012-
ASSERT_TRUE(result->is_sorted_by_column_index());
1013-
10141006
reuse.update_values(alpha, a, beta, b, result);
10151007

1016-
GKO_ASSERT_MTX_NEAR(result, expect, r<T>::value);
1008+
GKO_ASSERT_MTX_NEAR(result, expected, r<T>::value);
10171009
}
10181010

10191011

@@ -1029,8 +1021,9 @@ TYPED_TEST(Csr, MultiplyReuseFailsOnWrongDimensions)
10291021
auto& m1 = this->mtx;
10301022
auto& m2 = this->mtx3_sorted;
10311023
auto m2_nnz = Mtx::create(this->exec, m2->get_size());
1032-
ASSERT_THROW(m1->multiply_reuse(m1), gko::DimensionMismatch);
10331024
auto [m3, reuse] = m1->multiply_reuse(m2);
1025+
1026+
ASSERT_THROW(m1->multiply_reuse(m1), gko::DimensionMismatch);
10341027
// Check for mismatching dimensions or nnz for every parameter
10351028
ASSERT_THROW(reuse.update_values(m2, m2, m3), gko::DimensionMismatch);
10361029
ASSERT_THROW(reuse.update_values(m3, m2, m3), gko::ValueMismatch);
@@ -1049,6 +1042,7 @@ TYPED_TEST(Csr, MultiplyAddFailsOnWrongDimensions)
10491042
auto& m1 = this->mtx;
10501043
auto& m2 = this->mtx2;
10511044
auto& m3 = this->mtx3_sorted;
1045+
10521046
ASSERT_THROW(m1->multiply_add(v, m3, s, m2), gko::DimensionMismatch);
10531047
ASSERT_THROW(m1->multiply_add(s, m1, s, m2), gko::DimensionMismatch);
10541048
ASSERT_THROW(m1->multiply_add(s, m3, v, m2), gko::DimensionMismatch);
@@ -1066,11 +1060,12 @@ TYPED_TEST(Csr, MultiplyAddReuseFailsOnWrongDimensions)
10661060
auto& m2 = this->mtx2;
10671061
auto& m3 = this->mtx3_sorted;
10681062
auto m3_nnz = Mtx::create(this->exec, m3->get_size());
1063+
auto [m4, r] = m1->multiply_add_reuse(s, m3, s, m2);
1064+
10691065
ASSERT_THROW(m1->multiply_add_reuse(v, m3, s, m2), gko::DimensionMismatch);
10701066
ASSERT_THROW(m1->multiply_add_reuse(s, m2, s, m2), gko::DimensionMismatch);
10711067
ASSERT_THROW(m1->multiply_add_reuse(s, m3, v, m2), gko::DimensionMismatch);
10721068
ASSERT_THROW(m1->multiply_add_reuse(s, m3, s, m3), gko::DimensionMismatch);
1073-
auto [m4, r] = m1->multiply_add_reuse(s, m3, s, m2);
10741069
ASSERT_THROW(r.update_values(m3, s, m3, s, m2, m4), gko::DimensionMismatch);
10751070
ASSERT_THROW(r.update_values(m2, s, m3, s, m2, m4), gko::ValueMismatch);
10761071
ASSERT_THROW(r.update_values(m1, v, m3, s, m2, m4), gko::DimensionMismatch);
@@ -1092,6 +1087,7 @@ TYPED_TEST(Csr, AddScaleFailsOnWrongDimensions)
10921087
auto& m1 = this->mtx;
10931088
auto& m2 = this->mtx2;
10941089
auto& m3 = this->mtx3_sorted;
1090+
10951091
ASSERT_THROW(m1->add_scale(v, s, m2), gko::DimensionMismatch);
10961092
ASSERT_THROW(m1->add_scale(s, v, m2), gko::DimensionMismatch);
10971093
ASSERT_THROW(m1->add_scale(s, s, m3), gko::DimensionMismatch);
@@ -1107,10 +1103,11 @@ TYPED_TEST(Csr, AddScaleReuseFailsOnWrongDimensions)
11071103
auto& m1 = this->mtx;
11081104
auto& m2 = this->mtx2;
11091105
auto& m3 = this->mtx3_sorted;
1106+
auto [m4, r] = m1->add_scale_reuse(s, s, m2);
1107+
11101108
ASSERT_THROW(m1->add_scale_reuse(v, s, m2), gko::DimensionMismatch);
11111109
ASSERT_THROW(m1->add_scale_reuse(s, v, m2), gko::DimensionMismatch);
11121110
ASSERT_THROW(m1->add_scale_reuse(s, s, m3), gko::DimensionMismatch);
1113-
auto [m4, r] = m1->add_scale_reuse(s, s, m2);
11141111
ASSERT_THROW(r.update_values(v, m1, s, m2, m4), gko::DimensionMismatch);
11151112
ASSERT_THROW(r.update_values(s, m3, s, m2, m4), gko::DimensionMismatch);
11161113
ASSERT_THROW(r.update_values(s, m2, s, m2, m4), gko::ValueMismatch);

test/matrix/csr_kernels2.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,11 @@ TEST_F(Csr, MultiplyAddReuseUpdateCrossExecutor)
549549
alpha->scale(alpha);
550550
beta->scale(beta);
551551

552-
auto ref_result = dmtx->multiply_add(alpha, trans, beta, square_mtx);
552+
auto expected = dmtx->multiply_add(alpha, trans, beta, square_mtx);
553553
mtx = gko::clone(ref, dmtx);
554554
dreuse.update_values(mtx, alpha, trans, beta, square_mtx, result);
555555

556-
GKO_ASSERT_MTX_NEAR(result, ref_result, r<value_type>::value);
556+
GKO_ASSERT_MTX_NEAR(result, expected, r<value_type>::value);
557557
}
558558

559559

@@ -650,10 +650,10 @@ TEST_F(Csr, MultiplyReuseCrossExecutor)
650650
auto trans = gko::as<Mtx>(mtx->transpose());
651651

652652
auto [dresult, _dreuse] = dmtx->multiply_reuse(trans);
653-
auto ref_result = mtx->multiply(trans);
653+
auto expected = mtx->multiply(trans);
654654

655-
GKO_ASSERT_MTX_EQ_SPARSITY(dresult, ref_result);
656-
GKO_ASSERT_MTX_NEAR(dresult, ref_result, r<value_type>::value);
655+
GKO_ASSERT_MTX_EQ_SPARSITY(dresult, expected);
656+
GKO_ASSERT_MTX_NEAR(dresult, expected, r<value_type>::value);
657657
ASSERT_TRUE(dresult->is_sorted_by_column_index());
658658
ASSERT_EQ(dresult->get_executor(), exec);
659659
}
@@ -664,16 +664,16 @@ TEST_F(Csr, MultiplyReuseUpdateCrossExecutor)
664664
set_up_apply_data<Mtx::classical>();
665665
auto trans = gko::as<Mtx>(mtx->transpose());
666666
auto [dresult, dreuse] = dmtx->multiply_reuse(trans);
667-
auto ref_result = mtx->multiply(trans);
668-
auto result = ref_result->clone();
667+
auto expected = mtx->multiply(trans);
668+
auto result = expected->clone();
669669
// modify all involved matrices and scalars
670670
mtx->scale(alpha);
671671
trans->scale(beta);
672672

673673
dreuse.update_values(mtx, trans, result);
674-
ref_result = mtx->multiply(trans);
674+
expected = mtx->multiply(trans);
675675

676-
GKO_ASSERT_MTX_NEAR(result, ref_result, r<value_type>::value);
676+
GKO_ASSERT_MTX_NEAR(result, expected, r<value_type>::value);
677677
}
678678

679679

@@ -737,11 +737,11 @@ TEST_F(Csr, AddScaleReuseCrossExecutor)
737737
dmtx = gko::clone(exec, mtx);
738738

739739
auto [dresult, _dreuse] = dmtx->add_scale_reuse(alpha, beta, mtx2);
740-
auto ref_result = dmtx->add_scale(alpha, beta, mtx2);
741-
auto result = ref_result->clone();
740+
auto expected = dmtx->add_scale(alpha, beta, mtx2);
741+
auto result = expected->clone();
742742

743-
GKO_ASSERT_MTX_EQ_SPARSITY(dresult, ref_result);
744-
GKO_ASSERT_MTX_NEAR(dresult, ref_result, r<value_type>::value);
743+
GKO_ASSERT_MTX_EQ_SPARSITY(dresult, expected);
744+
GKO_ASSERT_MTX_NEAR(dresult, expected, r<value_type>::value);
745745
ASSERT_TRUE(dresult->is_sorted_by_column_index());
746746
ASSERT_EQ(dresult->get_executor(), exec);
747747
}
@@ -754,19 +754,19 @@ TEST_F(Csr, AddScaleReuseUpdateCrossExecutor)
754754
mtx2 = gen_mtx<Mtx>(mtx_size[0], mtx_size[1], 0);
755755
dmtx = gko::clone(exec, mtx);
756756
auto [dresult, dreuse] = dmtx->add_scale_reuse(alpha, beta, mtx2);
757-
auto ref_result = dmtx->add_scale(alpha, beta, mtx2);
758-
auto result = ref_result->clone();
757+
auto expected = dmtx->add_scale(alpha, beta, mtx2);
758+
auto result = expected->clone();
759759
// modify all involved matrices and scalars
760760
dmtx->scale(beta);
761761
mtx2->scale(alpha);
762762
alpha->scale(alpha);
763763
beta->scale(beta);
764764

765-
ref_result = dmtx->add_scale(alpha, beta, mtx2);
765+
expected = dmtx->add_scale(alpha, beta, mtx2);
766766
mtx = gko::clone(ref, dmtx);
767767
dreuse.update_values(alpha, mtx, beta, mtx2, result);
768768

769-
GKO_ASSERT_MTX_NEAR(result, ref_result, r<value_type>::value);
769+
GKO_ASSERT_MTX_NEAR(result, expected, r<value_type>::value);
770770
}
771771

772772

0 commit comments

Comments
 (0)