1919from accelerate import init_empty_weights
2020
2121# Assuming the module is named "module_matching" - adjust import as needed
22- from compressed_tensors .utils .match import (
22+ from compressed_tensors .utils import (
23+ InternalModule ,
2324 is_match ,
24- match_class ,
2525 match_modules_set ,
26- match_name ,
2726 match_named_modules ,
2827 match_named_parameters ,
2928)
29+ from compressed_tensors .utils .match import _match_class , _match_name
3030
3131
3232class DummyModel (nn .Module ):
@@ -66,14 +66,14 @@ def __init__(self):
6666
6767
6868class TestMatchName :
69- """Test cases for match_name function"""
69+ """Test cases for _match_name function"""
7070
7171 def test_exact_match (self ):
7272 """Test exact string matching"""
73- assert match_name ("layer1" , "layer1" ) == True
74- assert match_name ("layer1" , "layer2" ) == False
73+ assert _match_name ("layer1" , "layer1" ) == True
74+ assert _match_name ("layer1" , "layer2" ) == False
7575 assert (
76- match_name (
76+ _match_name (
7777 "transformer.layers.0.self_attn.q_proj" ,
7878 "transformer.layers.0.self_attn.q_proj" ,
7979 )
@@ -82,14 +82,14 @@ def test_exact_match(self):
8282
8383 def test_regex_match (self ):
8484 """Test regex matching with "re:" prefix"""
85- assert match_name ("layer1" , "re:layer.*" ) == True
86- assert match_name ("layer1" , "re:^layer1$" ) == True
87- assert match_name ("layer1" , "re:layer2" ) == False
85+ assert _match_name ("layer1" , "re:layer.*" ) == True
86+ assert _match_name ("layer1" , "re:^layer1$" ) == True
87+ assert _match_name ("layer1" , "re:layer2" ) == False
8888 assert (
89- match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" ) == True
89+ _match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" ) == True
9090 )
9191 assert (
92- match_name (
92+ _match_name (
9393 "transformer.layers.0.self_attn.q_proj" ,
9494 "re:transformer\\ .layers\\ .\\ d+\\ .self_attn\\ ..*_proj$" ,
9595 )
@@ -98,49 +98,49 @@ def test_regex_match(self):
9898
9999 def test_empty_strings (self ):
100100 """Test edge cases with empty strings"""
101- assert match_name ("" , "" ) == True
102- assert match_name ("layer1" , "" ) == False
103- assert match_name ("" , "layer1" ) == False
101+ assert _match_name ("" , "" ) == True
102+ assert _match_name ("layer1" , "" ) == False
103+ assert _match_name ("" , "layer1" ) == False
104104
105105 def test_regex_special_characters (self ):
106106 """Test regex with special characters"""
107- assert match_name ("layer.1" , "re:layer\\ .1" ) == True
108- assert match_name ("layer.1" , "re:layer.1" ) == True # . matches any char
109- assert match_name ("layer_1" , "re:layer_1" ) == True
107+ assert _match_name ("layer.1" , "re:layer\\ .1" ) == True
108+ assert _match_name ("layer.1" , "re:layer.1" ) == True # . matches any char
109+ assert _match_name ("layer_1" , "re:layer_1" ) == True
110110
111111
112112class TestMatchClass :
113- """Test cases for match_class function"""
113+ """Test cases for _match_class function"""
114114
115115 def test_direct_class_match (self ):
116116 """Test matching direct class names"""
117117 linear = nn .Linear (10 , 20 )
118- assert match_class (linear , "Linear" ) == True
119- assert match_class (linear , "Conv2d" ) == False
118+ assert _match_class (linear , "Linear" ) == True
119+ assert _match_class (linear , "Conv2d" ) == False
120120
121121 norm = nn .LayerNorm (10 )
122- assert match_class (norm , "LayerNorm" ) == True
123- assert match_class (norm , "BatchNorm1d" ) == False
122+ assert _match_class (norm , "LayerNorm" ) == True
123+ assert _match_class (norm , "BatchNorm1d" ) == False
124124
125125 def test_parent_class_match (self ):
126126 """Test matching parent class names"""
127127 linear = nn .Linear (10 , 20 )
128- assert match_class (linear , "Module" ) == True
128+ assert _match_class (linear , "Module" ) == True
129129
130130 conv = nn .Conv2d (3 , 16 , 3 )
131- assert match_class (conv , "Module" ) == True
132- assert match_class (conv , "_ConvNd" ) == True
131+ assert _match_class (conv , "Module" ) == True
132+ assert _match_class (conv , "_ConvNd" ) == True
133133
134134 def test_non_torch_module (self ):
135135 """Test with non-torch modules"""
136136 regular_object = object ()
137- assert match_class (regular_object , "object" ) == False # not a torch.nn.Module
137+ assert _match_class (regular_object , "object" ) == False # not a torch.nn.Module
138138
139139 def test_custom_module (self ):
140140 """Test with custom module classes"""
141141 model = DummyModel ()
142- assert match_class (model , "DummyModel" ) == True
143- assert match_class (model , "Module" ) == True
142+ assert _match_class (model , "DummyModel" ) == True
143+ assert _match_class (model , "Module" ) == True
144144
145145
146146class TestIsMatch :
@@ -171,6 +171,15 @@ def test_regex_in_name_match(self):
171171 assert is_match ("layer1" , linear , "re:layer.*" ) == True
172172 assert is_match ("layer1" , linear , "re:conv.*" ) == False
173173
174+ def test_internal_module_match (self ):
175+ """Test not matching internal modules"""
176+
177+ class InternalLinear (InternalModule , nn .Linear ):
178+ pass
179+
180+ linear = InternalLinear (10 , 20 )
181+ assert is_match ("layer1" , linear , "re:layer.*" ) == False
182+
174183
175184class TestMatchNamedModules :
176185 """Test cases for match_named_modules function"""
@@ -236,6 +245,16 @@ def test_warn_on_fail(self, mock_logger):
236245 assert "Could not match" in warning_msg
237246 assert "nonexistent_module" in warning_msg
238247
248+ def test_internal_match (self ):
249+ """Test not matching internal modules"""
250+
251+ class InternalLinear (InternalModule , nn .Linear ):
252+ pass
253+
254+ linear = InternalLinear (10 , 20 )
255+ matches = list (match_named_modules (linear , ["re:.*" ]))
256+ assert len (matches ) == 0
257+
239258
240259class TestMatchNamedParameters :
241260 """Test cases for match_named_parameters function"""
@@ -298,6 +317,16 @@ def test_warn_on_fail_parameters(self, mock_logger):
298317 assert "Could not match" in warning_msg
299318 assert "nonexistent.param" in warning_msg
300319
320+ def test_internal_match (self ):
321+ """Test not matching internal modules"""
322+
323+ class InternalLinear (InternalModule , nn .Linear ):
324+ pass
325+
326+ linear = InternalLinear (10 , 20 )
327+ matches = list (match_named_parameters (linear , ["re:.*" ]))
328+ assert len (matches ) == 0
329+
301330
302331class TestMatchModulesSet :
303332 """Test cases for match_modules_set function"""
@@ -377,6 +406,16 @@ def test_module_set_with_ignore(self):
377406 # Should have 2 sets (layers 1 and 2, but not 0)
378407 assert len (matches ) == 2
379408
409+ def test_internal_match (self ):
410+ """Test not matching internal modules"""
411+
412+ class InternalLinear (InternalModule , nn .Linear ):
413+ pass
414+
415+ linear = InternalLinear (10 , 20 )
416+ matches = list (match_modules_set (linear , ["re:.*" ]))
417+ assert len (matches ) == 0
418+
380419
381420class TestIntegration :
382421 """Integration tests combining multiple functions"""
0 commit comments