File tree Expand file tree Collapse file tree 1 file changed +45
-0
lines changed Expand file tree Collapse file tree 1 file changed +45
-0
lines changed Original file line number Diff line number Diff line change 1+ import pytest
2+
3+ import arrayfire_wrapper .dtypes as dtypes
4+ import arrayfire_wrapper .lib as wrapper
5+
6+
7+ @pytest .mark .parametrize (
8+ "shape" ,
9+ [
10+ (3 , 3 ),
11+ (3 , 3 , 3 ),
12+ (3 , 3 , 3 , 3 ),
13+ ],
14+ )
15+ def test_diag_is_unit (shape : tuple ) -> None :
16+ """Test if when is_unit_diag in lower returns an array with a unit diagonal"""
17+ dtype = dtypes .s64
18+ constant_array = wrapper .constant (3 , shape , dtype )
19+
20+ lower_array = wrapper .upper (constant_array , True )
21+ diagonal = wrapper .diag_extract (lower_array , 0 )
22+ diagonal_value = wrapper .get_scalar (diagonal , dtype )
23+
24+ assert diagonal_value == 1
25+
26+
27+ @pytest .mark .parametrize (
28+ "shape" ,
29+ [
30+ (3 , 3 ),
31+ (3 , 3 , 3 ),
32+ (3 , 3 , 3 , 3 ),
33+ ],
34+ )
35+ def test_is_original (shape : tuple ) -> None :
36+ """Test if is_original keeps the diagonal the same as the original array"""
37+ dtype = dtypes .s64
38+ constant_array = wrapper .constant (3 , shape , dtype )
39+ original_value = wrapper .get_scalar (constant_array , dtype )
40+
41+ lower_array = wrapper .upper (constant_array , False )
42+ diagonal = wrapper .diag_extract (lower_array , 0 )
43+ diagonal_value = wrapper .get_scalar (diagonal , dtype )
44+
45+ assert original_value == diagonal_value
You can’t perform that action at this time.
0 commit comments