9
9
import textwrap
10
10
import threading
11
11
import traceback
12
+ from contextlib import asynccontextmanager
12
13
from datetime import datetime , timezone
13
14
from enum import Enum , auto , unique
14
- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional , Type
15
+ from typing import (
16
+ TYPE_CHECKING ,
17
+ Any ,
18
+ AsyncGenerator ,
19
+ Awaitable ,
20
+ Callable ,
21
+ Dict ,
22
+ Optional ,
23
+ Type ,
24
+ )
15
25
16
26
import structlog
17
27
import uvicorn
@@ -120,9 +130,32 @@ def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-s
120
130
is_build : bool = False ,
121
131
await_explicit_shutdown : bool = False , # pylint: disable=redefined-outer-name
122
132
) -> MyFastAPI :
133
+ started_at = datetime .now (tz = timezone .utc )
134
+
135
+ @asynccontextmanager
136
+ async def lifespan (app : MyFastAPI ) -> AsyncGenerator [None , None ]:
137
+ # Startup code (was previously in @app.on_event("startup"))
138
+ # check for early setup failures
139
+ if (
140
+ app .state .setup_result
141
+ and app .state .setup_result .status == schema .Status .FAILED
142
+ ):
143
+ # signal shutdown if interactive run
144
+ if shutdown_event and not await_explicit_shutdown :
145
+ shutdown_event .set ()
146
+ else :
147
+ setup_task = runner .setup ()
148
+ setup_task .add_done_callback (_handle_setup_done )
149
+
150
+ yield
151
+
152
+ # Shutdown code (was previously in @app.on_event("shutdown"))
153
+ worker .terminate ()
154
+
123
155
app = MyFastAPI ( # pylint: disable=redefined-outer-name
124
156
title = "Cog" , # TODO: mention model name?
125
157
# version=None # TODO
158
+ lifespan = lifespan ,
126
159
)
127
160
128
161
def custom_openapi () -> Dict [str , Any ]:
@@ -149,7 +182,6 @@ def custom_openapi() -> Dict[str, Any]:
149
182
150
183
app .state .health = Health .STARTING
151
184
app .state .setup_result = None
152
- started_at = datetime .now (tz = timezone .utc )
153
185
154
186
# shutdown is needed no matter what happens
155
187
@app .post ("/shutdown" )
@@ -318,24 +350,6 @@ def cancel_training(
318
350
add_setup_failed_routes (app , started_at , msg )
319
351
return app
320
352
321
- @app .on_event ("startup" )
322
- def startup () -> None :
323
- # check for early setup failures
324
- if (
325
- app .state .setup_result
326
- and app .state .setup_result .status == schema .Status .FAILED
327
- ):
328
- # signal shutdown if interactive run
329
- if shutdown_event and not await_explicit_shutdown :
330
- shutdown_event .set ()
331
- else :
332
- setup_task = runner .setup ()
333
- setup_task .add_done_callback (_handle_setup_done )
334
-
335
- @app .on_event ("shutdown" )
336
- def shutdown () -> None :
337
- worker .terminate ()
338
-
339
353
@app .get ("/" )
340
354
async def root () -> Any :
341
355
return index_document
0 commit comments