1
1
from __future__ import annotations
2
2
3
+ from contextlib import contextmanager
3
4
from typing import (
4
5
TYPE_CHECKING ,
5
6
Any ,
@@ -162,7 +163,7 @@ def __call__(self, _fn: ValueFn[IT]) -> Self:
162
163
raise TypeError ("Value function must be callable" )
163
164
164
165
# Set value function with extra meta information
165
- self .fn = AsyncValueFn (_fn )
166
+ self .fn = AsyncValueFn (_fn , self )
166
167
167
168
# Copy over function name as it is consistent with how Session and Output
168
169
# retrieve function names
@@ -350,6 +351,7 @@ class AsyncValueFn(Generic[IT]):
350
351
def __init__ (
351
352
self ,
352
353
fn : Callable [[], IT | None ] | Callable [[], Awaitable [IT | None ]],
354
+ renderer : Renderer [Any ],
353
355
):
354
356
if isinstance (fn , AsyncValueFn ):
355
357
raise TypeError (
@@ -358,12 +360,14 @@ def __init__(
358
360
self ._is_async = is_async_callable (fn )
359
361
self ._fn = wrap_async (fn )
360
362
self ._orig_fn = fn
363
+ self ._renderer = renderer
361
364
362
365
async def __call__ (self ) -> IT | None :
363
366
"""
364
367
Call the asynchronous function.
365
368
"""
366
- return await self ._fn ()
369
+ with self ._current_output_id ():
370
+ return await self ._fn ()
367
371
368
372
def is_async (self ) -> bool :
369
373
"""
@@ -404,3 +408,13 @@ def get_sync_fn(self) -> Callable[[], IT | None]:
404
408
)
405
409
sync_fn = cast (Callable [[], IT ], self ._orig_fn )
406
410
return sync_fn
411
+
412
+ @contextmanager
413
+ def _current_output_id (self ):
414
+ from ...session import get_current_session
415
+
416
+ session = get_current_session ()
417
+ if session is not None :
418
+ session .current_output_id = self ._renderer .output_id
419
+ yield
420
+ session .current_output_id = None
0 commit comments