@@ -44,16 +44,16 @@ def test_isin_basic(dtype):
4444 skip_if_dtype_not_supported (dtype , q )
4545
4646 n = 100
47- x = dpt .arange (n , dtype = dtype )
48- test = dpt .arange (n - 1 , dtype = dtype )
47+ x = dpt .arange (n , dtype = dtype , sycl_queue = q )
48+ test = dpt .arange (n - 1 , dtype = dtype , sycl_queue = q )
4949 r1 = dpt .isin (x , test )
5050 assert dpt .all (r1 [:- 1 ])
5151 assert not r1 [- 1 ]
5252 assert r1 .shape == x .shape
5353
5454 # test with invert keyword
5555 r2 = dpt .isin (x , test , invert = True )
56- assert not dpt .all (r2 [:- 1 ])
56+ assert not dpt .any (r2 [:- 1 ])
5757 assert r2 [- 1 ]
5858 assert r2 .shape == x .shape
5959
@@ -70,7 +70,7 @@ def test_isin_basic_bool():
7070 assert r1 .shape == x .shape
7171
7272 r2 = dpt .isin (x , test , invert = True )
73- assert not dpt .all (r2 [:- 1 ])
73+ assert not dpt .any (r2 [:- 1 ])
7474 assert r2 [- 1 ]
7575 assert r2 .shape == x .shape
7676
@@ -98,37 +98,44 @@ def test_isin_strided(dtype):
9898 skip_if_dtype_not_supported (dtype , q )
9999
100100 n , m = 100 , 20
101- x = dpt .zeros ((n , m ), dtype = dtype , order = "F" )
102- x [:, ::2 ] = dpt .arange (1 , (m / 2 ) + 1 , dtype = dtype )
103- test = dpt .arange (1 , (m / 2 ) + 1 , dtype = dtype )
104- r1 = dpt .isin (x , test )
105- assert dpt .all (r1 [:, ::2 ])
106- assert not dpt .all (r1 [:, 1 ::2 ])
107- assert r1 .shape == x .shape
101+ x = dpt .zeros ((n , m ), dtype = dtype , order = "F" , sycl_queue = q )
102+ x [:, ::2 ] = dpt .arange (1 , (m / 2 ) + 1 , dtype = dtype , sycl_queue = q )
103+ x_s = x [:, ::2 ]
104+ test = dpt .arange (1 , (m / 2 ), dtype = dtype , sycl_queue = q )
105+ r1 = dpt .isin (x_s , test )
106+ assert dpt .all (r1 [:, :- 1 ])
107+ assert not dpt .any (r1 [:, - 1 ])
108+ assert not dpt .any (x [:, 1 ::2 ])
109+ assert r1 .shape == x_s .shape
108110
109111 # test with invert keyword
110- r2 = dpt .isin (x , test , invert = True )
111- assert not dpt .all (r2 [:, ::2 ])
112- assert dpt .all (r2 [:, 1 ::2 ])
113- assert r2 .shape == x .shape
112+ r2 = dpt .isin (x_s , test , invert = True )
113+ assert not dpt .any (r2 [:, :- 1 ])
114+ assert dpt .all (r2 [:, - 1 ])
115+ assert not dpt .any (x [:, 1 :2 ])
116+ assert r2 .shape == x_s .shape
114117
115118
116119def test_isin_strided_bool ():
117120 dt = dpt .bool
121+
118122 n , m = 100 , 20
119- x = dpt .ones ((n , m ), dtype = dt , order = "F" )
120- x [:, ::2 ] = False
121- test = dpt .zeros ((), dtype = dt )
122- r1 = dpt .isin (x , test )
123- assert dpt .all (r1 [:, ::2 ])
124- assert not dpt .all (r1 [:, 1 ::2 ])
125- assert r1 .shape == x .shape
123+ x = dpt .zeros ((n , m ), dtype = dt , order = "F" )
124+ x [:, :- 2 :2 ] = True
125+ x_s = x [:, ::2 ]
126+ test = dpt .ones ((), dtype = dt )
127+ r1 = dpt .isin (x_s , test )
128+ assert dpt .all (r1 [:, :- 1 ])
129+ assert not dpt .any (r1 [:, - 1 ])
130+ assert not dpt .any (x [:, 1 ::2 ])
131+ assert r1 .shape == x_s .shape
126132
127133 # test with invert keyword
128- r2 = dpt .isin (x , test , invert = True )
129- assert not dpt .all (r2 [:, ::2 ])
130- assert dpt .all (r2 [:, 1 ::2 ])
131- assert r2 .shape == x .shape
134+ r2 = dpt .isin (x_s , test , invert = True )
135+ assert not dpt .any (r2 [:, :- 1 ])
136+ assert dpt .all (r2 [:, - 1 ])
137+ assert not dpt .any (x [:, 1 :2 ])
138+ assert r2 .shape == x_s .shape
132139
133140
134141def test_isin_empty_inputs ():
0 commit comments