11use ndarray:: array;
2- use numpy:: { get_array_module, AllowTypeChange , PyArrayLike1 , PyArrayLike2 , PyArrayLikeDyn } ;
2+ use numpy:: {
3+ get_array_module, AllowTypeChange , PyArrayLike1 , PyArrayLike2 , PyArrayLikeDyn ,
4+ PyUntypedArrayMethods ,
5+ } ;
36use pyo3:: {
47 ffi:: c_str,
58 types:: { IntoPyDict , PyAnyMethods , PyDict } ,
@@ -105,7 +108,9 @@ fn convert_1d_list_on_extract() {
105108 Python :: with_gil ( |py| {
106109 let py_list = py. eval ( c_str ! ( "[1,2,3,4]" ) , None , None ) . unwrap ( ) ;
107110 let extracted_array_1d = py_list. extract :: < PyArrayLike1 < ' _ , u32 > > ( ) . unwrap ( ) ;
108- let extracted_array_dyn = py_list. extract :: < PyArrayLikeDyn < ' _ , f64 > > ( ) . unwrap ( ) ;
111+ let extracted_array_dyn = py_list
112+ . extract :: < PyArrayLikeDyn < ' _ , f64 , AllowTypeChange > > ( )
113+ . unwrap ( ) ;
109114
110115 assert_eq ! ( array![ 1 , 2 , 3 , 4 ] , extracted_array_1d. as_array( ) ) ;
111116 assert_eq ! (
@@ -115,6 +120,25 @@ fn convert_1d_list_on_extract() {
115120 } ) ;
116121}
117122
123+ #[ test]
124+ fn preserve_trailing_singleton_dims ( ) {
125+ Python :: with_gil ( |py| {
126+ let locals = get_np_locals ( py) ;
127+ let py_array = py
128+ . eval (
129+ c_str ! ( "np.array([[1], [2], [3]], dtype='int32')" ) ,
130+ Some ( & locals) ,
131+ None ,
132+ )
133+ . unwrap ( ) ;
134+ let extracted_array = py_array
135+ . extract :: < PyArrayLikeDyn < ' _ , f64 , AllowTypeChange > > ( )
136+ . unwrap ( ) ;
137+
138+ assert_eq ! ( extracted_array. shape( ) , & [ 3 , 1 ] ) ;
139+ } )
140+ }
141+
118142#[ test]
119143fn unsafe_cast_shall_fail ( ) {
120144 Python :: with_gil ( |py| {
0 commit comments