@@ -12,31 +12,63 @@ class TestStarlettePatch(TestCase):
12
12
@patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" )
13
13
def test_starlette_patch_applied_successfully (self , mock_logger ):
14
14
"""Test that the Starlette instrumentation patch is applied successfully."""
15
- # Create a mock StarletteInstrumentor class
16
- mock_instrumentor_class = MagicMock ()
17
- mock_instrumentor_class .__name__ = "StarletteInstrumentor"
18
-
19
- # Create a mock module
20
- mock_starlette_module = MagicMock ()
21
- mock_starlette_module .StarletteInstrumentor = mock_instrumentor_class
22
-
23
- # Mock the import
24
- with patch .dict ("sys.modules" , {"opentelemetry.instrumentation.starlette" : mock_starlette_module }):
25
- # Apply the patch
26
- _apply_starlette_instrumentation_patches ()
27
-
28
- # Verify the instrumentation_dependencies method was replaced
29
- self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
30
-
31
- # Test the patched method returns the expected value
32
- mock_instance = MagicMock ()
33
- result = mock_instrumentor_class .instrumentation_dependencies (mock_instance )
34
- self .assertEqual (result , ("starlette >= 0.13" ,))
35
-
36
- # Verify logging
37
- mock_logger .debug .assert_called_once_with (
38
- "Successfully patched Starlette instrumentation_dependencies method"
39
- )
15
+ for agent_enabled in [True , False ]:
16
+ with self .subTest (agent_enabled = agent_enabled ):
17
+ with patch .dict ("os.environ" , {"AGENT_OBSERVABILITY_ENABLED" : "true" if agent_enabled else "false" }):
18
+ # Create a mock StarletteInstrumentor class
19
+ mock_instrumentor_class = MagicMock ()
20
+ mock_instrumentor_class .__name__ = "StarletteInstrumentor"
21
+
22
+ def create_middleware_class ():
23
+ class MockMiddleware :
24
+ def __init__ (self , app , ** kwargs ):
25
+ pass
26
+
27
+ return MockMiddleware
28
+
29
+ mock_middleware_class = create_middleware_class ()
30
+
31
+ mock_starlette_module = MagicMock ()
32
+ mock_starlette_module .StarletteInstrumentor = mock_instrumentor_class
33
+
34
+ mock_asgi_module = MagicMock ()
35
+ mock_asgi_module .OpenTelemetryMiddleware = mock_middleware_class
36
+
37
+ with patch .dict (
38
+ "sys.modules" ,
39
+ {
40
+ "opentelemetry.instrumentation.starlette" : mock_starlette_module ,
41
+ "opentelemetry.instrumentation.asgi" : mock_asgi_module ,
42
+ },
43
+ ):
44
+ # Apply the patch
45
+ _apply_starlette_instrumentation_patches ()
46
+
47
+ # Verify the instrumentation_dependencies method was replaced
48
+ self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
49
+
50
+ # Test the patched method returns the expected value
51
+ mock_instance = MagicMock ()
52
+ result = mock_instrumentor_class .instrumentation_dependencies (mock_instance )
53
+ self .assertEqual (result , ("starlette >= 0.13" ,))
54
+
55
+ mock_middleware_instance = MagicMock ()
56
+ mock_middleware_instance .exclude_receive_span = False
57
+ mock_middleware_instance .exclude_send_span = False
58
+ mock_middleware_class .__init__ (mock_middleware_instance , "app" )
59
+
60
+ # Test middleware patching sets exclude flags
61
+ if agent_enabled :
62
+ self .assertTrue (mock_middleware_instance .exclude_receive_span )
63
+ self .assertTrue (mock_middleware_instance .exclude_send_span )
64
+ else :
65
+ self .assertFalse (mock_middleware_instance .exclude_receive_span )
66
+ self .assertFalse (mock_middleware_instance .exclude_send_span )
67
+
68
+ # Verify logging
69
+ mock_logger .debug .assert_called_with (
70
+ "Successfully patched Starlette instrumentation_dependencies method"
71
+ )
40
72
41
73
@patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" )
42
74
def test_starlette_patch_handles_import_error (self , mock_logger ):
0 commit comments