@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
33
33
delay : float = 0
34
34
"""The delay before the agent generates an event."""
35
35
36
- @override
37
- async def _run_async_impl (
38
- self , ctx : InvocationContext
39
- ) -> AsyncGenerator [Event , None ]:
40
- await asyncio .sleep (self .delay )
41
- yield Event (
36
+ def event (self , ctx : InvocationContext ):
37
+ return Event (
42
38
author = self .name ,
43
39
branch = ctx .branch ,
44
40
invocation_id = ctx .invocation_id ,
@@ -47,6 +43,13 @@ async def _run_async_impl(
47
43
),
48
44
)
49
45
46
+ @override
47
+ async def _run_async_impl (
48
+ self , ctx : InvocationContext
49
+ ) -> AsyncGenerator [Event , None ]:
50
+ await asyncio .sleep (self .delay )
51
+ yield self .event (ctx )
52
+
50
53
51
54
async def _create_parent_invocation_context (
52
55
test_name : str , agent : BaseAgent
@@ -135,3 +138,102 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135
138
# Sub-agents should have different branches.
136
139
assert events [2 ].branch != events [1 ].branch
137
140
assert events [2 ].branch != events [0 ].branch
141
+
142
+
143
+ class _TestingAgentWithMultipleEvents (_TestingAgent ):
144
+ """Mock agent for testing."""
145
+
146
+ @override
147
+ async def _run_async_impl (
148
+ self , ctx : InvocationContext
149
+ ) -> AsyncGenerator [Event , None ]:
150
+ for _ in range (0 , 3 ):
151
+ event = self .event (ctx )
152
+ yield event
153
+ # Check that the event was processed by the consumer.
154
+ assert event .custom_metadata is not None
155
+ assert event .custom_metadata ['processed' ]
156
+
157
+
158
+ @pytest .mark .asyncio
159
+ async def test_generating_one_event_per_agent_at_once (
160
+ request : pytest .FixtureRequest ,
161
+ ):
162
+ # This test is to verify that the parallel agent won't generate more than one
163
+ # event per agent at a time.
164
+ agent1 = _TestingAgentWithMultipleEvents (
165
+ name = f'{ request .function .__name__ } _test_agent_1'
166
+ )
167
+ agent2 = _TestingAgentWithMultipleEvents (
168
+ name = f'{ request .function .__name__ } _test_agent_2'
169
+ )
170
+ parallel_agent = ParallelAgent (
171
+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
172
+ sub_agents = [
173
+ agent1 ,
174
+ agent2 ,
175
+ ],
176
+ )
177
+ parent_ctx = await _create_parent_invocation_context (
178
+ request .function .__name__ , parallel_agent
179
+ )
180
+
181
+ agen = parallel_agent .run_async (parent_ctx )
182
+ async for event in agen :
183
+ event .custom_metadata = {'processed' : True }
184
+ # Asserts on event are done in _TestingAgentWithMultipleEvents.
185
+
186
+
187
+ class _TestingAgentWithException (_TestingAgent ):
188
+ """Mock agent for testing."""
189
+
190
+ @override
191
+ async def _run_async_impl (
192
+ self , ctx : InvocationContext
193
+ ) -> AsyncGenerator [Event , None ]:
194
+ yield self .event (ctx )
195
+ raise Exception ()
196
+
197
+
198
+ class _TestingAgentInfiniteEvents (_TestingAgent ):
199
+ """Mock agent for testing."""
200
+
201
+ @override
202
+ async def _run_async_impl (
203
+ self , ctx : InvocationContext
204
+ ) -> AsyncGenerator [Event , None ]:
205
+ while True :
206
+ yield self .event (ctx )
207
+
208
+
209
+ @pytest .mark .asyncio
210
+ async def test_stop_agent_if_sub_agent_fails (
211
+ request : pytest .FixtureRequest ,
212
+ ):
213
+ # This test is to verify that the parallel agent and subagents will all stop
214
+ # processing and throw exception to top level runner in case of exception.
215
+ agent1 = _TestingAgentWithException (
216
+ name = f'{ request .function .__name__ } _test_agent_1'
217
+ )
218
+ agent2 = _TestingAgentInfiniteEvents (
219
+ name = f'{ request .function .__name__ } _test_agent_2'
220
+ )
221
+ parallel_agent = ParallelAgent (
222
+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
223
+ sub_agents = [
224
+ agent1 ,
225
+ agent2 ,
226
+ ],
227
+ )
228
+ parent_ctx = await _create_parent_invocation_context (
229
+ request .function .__name__ , parallel_agent
230
+ )
231
+
232
+ agen = parallel_agent .run_async (parent_ctx )
233
+ # We expect to receive an exception from one of subagents.
234
+ # The exception should be propagated to root agent and other subagents.
235
+ # Otherwise we'll have an infinite loop.
236
+ with pytest .raises (Exception ):
237
+ async for _ in agen :
238
+ # The infinite agent could iterate a few times depending on scheduling.
239
+ pass
0 commit comments