2020 split_batch_concatenated_tensor ,
2121)
2222from pytorch_sparse_utils .validation import validate_atleast_nd
23+ from .. import random_sparse_tensor
2324
2425
2526@pytest .fixture
@@ -504,6 +505,14 @@ def test_scalar_feature(self, device):
504505 assert torch .equal (result , expected_result )
505506 assert torch .equal (batch_offsets , expected_batch_offsets )
506507
508+ def test_empty_tensor (self , device ):
509+ tensor = torch .randn (0 , 10 , 32 , device = device )
510+ result , batch_offsets = padded_to_concatenated (tensor )
511+
512+ assert result .numel () == 0
513+ assert result .shape == (0 , 32 )
514+ assert torch .equal (batch_offsets , torch .tensor ([0 ], device = device ))
515+
507516 def test_error_handling (self , device ):
508517 """Test error handling."""
509518 # Test with tensor with less than 3 dimensions
@@ -523,6 +532,26 @@ def test_error_handling(self, device):
523532 ):
524533 padded_to_concatenated (tensor , padding_mask_wrong_batch )
525534
535+ # Wrong padding mask dim
536+ padding_mask_3d = torch .zeros (3 , 4 , 5 , device = device , dtype = torch .bool )
537+ padding_mask_3d [0 , - 1 ] = True
538+ with pytest .raises (
539+ (ValueError , torch .jit .Error ), # pyright: ignore[reportArgumentType]
540+ match = "Expected padding_mask to be 2D" ,
541+ ):
542+ padded_to_concatenated (tensor , padding_mask_3d )
543+
544+ # Sequence length mismatch
545+ padding_mask_wrong_seq_length = torch .zeros (
546+ 3 , 5 , device = device , dtype = torch .bool
547+ )
548+ padding_mask_wrong_seq_length [0 , - 1 ] = True
549+ with pytest .raises (
550+ (ValueError , torch .jit .Error ), # pyright: ignore[reportArgumentType]
551+ match = "Sequence length mismatch" ,
552+ ):
553+ padded_to_concatenated (tensor , padding_mask_wrong_seq_length )
554+
526555
527556@pytest .mark .cpu_and_cuda
528557class TestBatchDimToLeadingIndex :
@@ -739,3 +768,25 @@ def test_error_not_sparse(self, device):
739768 match = "Received non-sparse tensor" ,
740769 ):
741770 sparse_tensor_to_concatenated (tensor )
771+
772+
773+ class TestConcatenatedToSparseTensor :
774+ def test_basic_functionality (self , device ):
775+ """Test basic functions"""
776+ sparse_tensor = random_sparse_tensor ([4 , 5 , 5 ], [8 ], 0.5 , seed = 0 , device = device )
777+
778+ values , indices , batch_offsets = sparse_tensor_to_concatenated (sparse_tensor )
779+
780+ out = concatenated_to_sparse_tensor (values , indices , sparse_tensor .shape )
781+
782+ assert isinstance (out , Tensor )
783+ assert out .is_sparse
784+
785+ assert torch .equal (sparse_tensor .indices (), out .indices ())
786+ assert torch .equal (sparse_tensor .values (), out .values ())
787+ assert sparse_tensor .shape == out .shape
788+
789+ # Test without shape param
790+ out_2 = concatenated_to_sparse_tensor (values , indices )
791+ assert torch .equal (sparse_tensor .indices (), out_2 .indices ())
792+ assert torch .equal (sparse_tensor .values (), out_2 .values ())
0 commit comments