@@ -290,17 +290,13 @@ public void TestTensorDefaultPrint()
290290 Tensor t = torch . zeros ( 2 , 2 ) ;
291291 string expectedOutput = t . ToString ( TensorStringStyle . Default ) + Environment . NewLine ;
292292 var originalOut = Console . Out ;
293- using ( var sw = new StringWriter ( ) )
294- {
295- try
296- {
293+ using ( var sw = new StringWriter ( ) ) {
294+ try {
297295 Console . SetOut ( sw ) ;
298296 t . print ( ) ;
299297 var result = sw . ToString ( ) ;
300298 Assert . Equal ( expectedOutput , result ) ;
301- }
302- finally
303- {
299+ } finally {
304300 Console . SetOut ( originalOut ) ;
305301 }
306302 }
@@ -807,7 +803,7 @@ public void FromArrayFactory()
807803 ( ) => Assert . Equal ( 1 , t . ndim ) ,
808804 ( ) => Assert . Equal ( ScalarType . Byte , t . dtype ) ) ;
809805 }
810-
806+
811807 {
812808 var array = new Memory < long > ( new long [ 8 ] ) ;
813809 using var t = torch . tensor ( array , new long [ ] { 8 } , device : device ) ;
@@ -816,11 +812,11 @@ public void FromArrayFactory()
816812 ( ) => Assert . Equal ( 1 , t . ndim ) ,
817813 ( ) => Assert . Equal ( ScalarType . Int64 , t . dtype ) ) ;
818814 }
819-
815+
820816 {
821817 var array = new long [ 18 ] ;
822818 array [ 5 ] = 17 ;
823- var mem = new Memory < long > ( array , 4 , 10 ) ;
819+ var mem = new Memory < long > ( array , 4 , 10 ) ;
824820 using var t = torch . tensor ( mem , new long [ ] { 8 } , device : device ) ;
825821 Assert . Multiple (
826822 ( ) => Assert . Equal ( device . type , t . device_type ) ,
@@ -3165,6 +3161,86 @@ public void IndexFill2()
31653161 ( ) => Assert . Equal ( 1.0 , x [ 2 , 2 ] . ToSingle ( ) ) ) ;
31663162 }
31673163
3164+ [ Fact ]
3165+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3166+ public void IndexPutOneValueOneIndex ( )
3167+ {
3168+ using var _ = NewDisposeScope ( ) ;
3169+
3170+ var tensor = ones ( 5 ) ;
3171+ var indices = new TensorIndex [ ] { TensorIndex . Tensor ( 1 ) } ;
3172+ var values = torch . tensor ( 5.0f ) ;
3173+
3174+ // default accumulate value is false, should only replace value at index 1 with 5
3175+ tensor . index_put_ ( values , indices ) ;
3176+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 5.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3177+
3178+ tensor = ones ( 5 ) ;
3179+ // accumulate value is false, explicitly set, should only replace value at index 1 with 5
3180+ tensor . index_put_ ( values , indices , accumulate : false ) ;
3181+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 5.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3182+
3183+ tensor = ones ( 5 ) ;
3184+ // accumulate value is true, should add value to index 1, 1 + 5 = 6
3185+ tensor . index_put_ ( values , indices , accumulate : true ) ;
3186+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 6.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3187+ }
3188+
3189+ [ Fact ]
3190+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3191+ public void IndexPutOneValueMultipleIndexes ( )
3192+ {
3193+ using var _ = NewDisposeScope ( ) ;
3194+
3195+ var tensor = ones ( 5 ) ;
3196+ var indices = new TensorIndex [ ] { TensorIndex . Tensor ( new long [ ] { 1 , 2 } ) } ;
3197+ var values = torch . tensor ( 10.0f ) ;
3198+
3199+ // default accumulate value is false, should only replace value at given indexes
3200+ tensor . index_put_ ( values , indices ) ;
3201+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 10.0f , 10.0f , 1.0f , 1.0f } ) ) ) ;
3202+
3203+ tensor = ones ( 5 ) ;
3204+ // accumulate value is true, should add value to given indexes
3205+ tensor . index_put_ ( values , indices , true ) ;
3206+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 11.0f , 11.0f , 1.0f , 1.0f } ) ) ) ;
3207+
3208+ // accumulate value is false, explicitly set, should replace value at given indexes
3209+ tensor . index_put_ ( values , indices , false ) ;
3210+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 10.0f , 10.0f , 1.0f , 1.0f } ) ) ) ;
3211+ }
3212+
3213+ [ Fact ]
3214+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3215+ public void IndexPutMultipleValuesMultipleIndexes ( )
3216+ {
3217+ using var _ = NewDisposeScope ( ) ;
3218+
3219+ var tensor = ones ( 5 , 2 ) ;
3220+ var indices = new TensorIndex [ ]
3221+ {
3222+ TensorIndex . Tensor ( new long [ ] { 1 , 2 , 0 , 3 } ) , // for first tensor dimension (row)
3223+ TensorIndex . Tensor ( new long [ ] { 0 , 1 , 0 , 0 } ) // for second tensor dimension (column)
3224+ } ;
3225+ var values = torch . tensor ( new float [ ] { 3.0f , 4.0f , 5.0f , 10f } ) ;
3226+
3227+ // default accumulate value is false, should only replace values at given indices with 3, 4, 5, 10
3228+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
3229+ tensor . index_put_ ( values , indices ) ;
3230+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 5.0f , 1.0f } , { 3.0f , 1.0f } , { 1.0f , 4.0f } , { 10.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3231+
3232+ tensor = ones ( 5 , 2 ) ;
3233+ // accumulate value is true, should perform addition at given indices, 1 + 3 = 4, 1 + 4 = 5, 1 + 5 = 6, 1 + 10 = 11
3234+ // Indexes to be replaced: (1, 0) -> 4.0, (2, 1) -> 5.0, (0, 0) -> 6.0, (3, 0) -> 11.0
3235+ tensor . index_put_ ( values , indices , true ) ;
3236+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 6.0f , 1.0f } , { 4.0f , 1.0f } , { 1.0f , 5.0f } , { 11.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3237+
3238+ // accumulate value is false, explicitly set, should only replace values at given indices with 3, 4, 5, 10
3239+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
3240+ tensor . index_put_ ( values , indices , false ) ;
3241+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 5.0f , 1.0f } , { 3.0f , 1.0f } , { 1.0f , 4.0f } , { 10.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3242+ }
3243+
31683244 [ Fact ]
31693245 [ TestOf ( nameof ( TensorExtensionMethods . ToTensor ) ) ]
31703246 public void ScalarToTensor ( )
@@ -3257,7 +3333,7 @@ public void ScalarToTensor3()
32573333 [ TestOf ( nameof ( Tensor ) ) ]
32583334 public void ScalarToTensorDoesNotLeakMemory ( )
32593335 {
3260- AssertTensorDoesNotLeak ( ( ) => {
3336+ AssertTensorDoesNotLeak ( ( ) => {
32613337 Tensor tensor = 1 ;
32623338 return tensor ;
32633339 } ) ;
@@ -3273,20 +3349,20 @@ public void ScalarToTensorDoesNotLeakMemory()
32733349 [ TestOf ( nameof ( Tensor ) ) ]
32743350 public void ScalarArrayToTensorDoesNotLeakMemory ( )
32753351 {
3276- AssertTensorDoesNotLeak ( ( ) => ( new byte [ ] { 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3277- AssertTensorDoesNotLeak ( ( ) => ( new sbyte [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3278- AssertTensorDoesNotLeak ( ( ) => ( new short [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3279- AssertTensorDoesNotLeak ( ( ) => ( new long [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3280- AssertTensorDoesNotLeak ( ( ) => ( new float [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3281- AssertTensorDoesNotLeak ( ( ) => ( new double [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3352+ AssertTensorDoesNotLeak ( ( ) => ( new byte [ ] { 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3353+ AssertTensorDoesNotLeak ( ( ) => ( new sbyte [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3354+ AssertTensorDoesNotLeak ( ( ) => ( new short [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3355+ AssertTensorDoesNotLeak ( ( ) => ( new long [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3356+ AssertTensorDoesNotLeak ( ( ) => ( new float [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3357+ AssertTensorDoesNotLeak ( ( ) => ( new double [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
32823358 }
32833359
32843360 [ Fact ]
32853361 [ TestOf ( nameof ( Tensor ) ) ]
32863362 public void ComplexNumberOfDoubleDoesNotLeakMemory ( )
32873363 {
3288- AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( double ) - 1 , ( double ) - 2 ) ) ) ;
3289- AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( ( double ) - 1 , ( double ) - 2 ) ) ) ) ;
3364+ AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( double ) - 1 , ( double ) - 2 ) ) ) ;
3365+ AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( ( double ) - 1 , ( double ) - 2 ) ) ) ) ;
32903366 }
32913367
32923368 [ Fact ]
@@ -4106,7 +4182,7 @@ public void CastMoveAndDisposeAfter()
41064182 Assert . True ( input . IsInvalid ) ;
41074183 Assert . False ( cast . IsInvalid ) ;
41084184 // make sure we can access the values
4109- Assert . Equal ( 1 , cast [ 0 ] . ToInt32 ( ) ) ;
4185+ Assert . Equal ( 1 , cast [ 0 ] . ToInt32 ( ) ) ;
41104186 }
41114187 if ( torch . cuda . is_available ( ) ) {
41124188 {
@@ -8517,28 +8593,27 @@ public void DefaultDTypeCreation()
85178593 {
85188594 var dt = torch . get_default_dtype ( ) ;
85198595
8520- var t = torch . zeros ( 5 , 5 ) ;
8596+ var t = torch . zeros ( 5 , 5 ) ;
85218597 Assert . Equal ( torch . float32 , t . dtype ) ;
85228598
85238599 try {
8524- torch . set_default_dtype ( torch . float64 ) ;
8525-
8526- t = torch . zeros ( 5 , 5 ) ;
8600+ torch . set_default_dtype ( torch . float64 ) ;
8601+
8602+ t = torch . zeros ( 5 , 5 ) ;
85278603 Assert . Equal ( torch . float64 , t . dtype ) ;
85288604
8529- t = torch . ones ( 5 , 5 ) ;
8605+ t = torch . ones ( 5 , 5 ) ;
85308606 Assert . Equal ( torch . float64 , t . dtype ) ;
85318607
8532- t = torch . rand ( 5 , 5 ) ;
8608+ t = torch . rand ( 5 , 5 ) ;
85338609 Assert . Equal ( torch . float64 , t . dtype ) ;
85348610
8535- t = torch . randn ( 5 , 5 ) ;
8611+ t = torch . randn ( 5 , 5 ) ;
85368612 Assert . Equal ( torch . float64 , t . dtype ) ;
85378613
85388614 t = torch . logspace ( 5 , 15 , 20 ) ;
85398615 Assert . Equal ( torch . float64 , t . dtype ) ;
8540- }
8541- finally {
8616+ } finally {
85428617 torch . set_default_dtype ( dt ) ;
85438618 }
85448619 }
@@ -8548,28 +8623,27 @@ public void DefaultDeviceCreation()
85488623 {
85498624 var dt = torch . get_default_device ( ) ;
85508625
8551- var t = torch . zeros ( 5 , 5 ) ;
8626+ var t = torch . zeros ( 5 , 5 ) ;
85528627 Assert . Equal ( DeviceType . CPU , t . device_type ) ;
85538628
85548629 try {
8555- torch . set_default_device ( torch . META ) ;
8556-
8557- t = torch . zeros ( 5 , 5 ) ;
8630+ torch . set_default_device ( torch . META ) ;
8631+
8632+ t = torch . zeros ( 5 , 5 ) ;
85588633 Assert . Equal ( DeviceType . META , t . device_type ) ;
85598634
8560- t = torch . ones ( 5 , 5 ) ;
8635+ t = torch . ones ( 5 , 5 ) ;
85618636 Assert . Equal ( DeviceType . META , t . device_type ) ;
85628637
8563- t = torch . rand ( 5 , 5 ) ;
8638+ t = torch . rand ( 5 , 5 ) ;
85648639 Assert . Equal ( DeviceType . META , t . device_type ) ;
85658640
8566- t = torch . randn ( 5 , 5 ) ;
8641+ t = torch . randn ( 5 , 5 ) ;
85678642 Assert . Equal ( DeviceType . META , t . device_type ) ;
85688643
85698644 t = torch . logspace ( 5 , 15 , 20 ) ;
85708645 Assert . Equal ( DeviceType . META , t . device_type ) ;
8571- }
8572- finally {
8646+ } finally {
85738647 torch . set_default_device ( dt ) ;
85748648 }
85758649 }
0 commit comments