Skip to content

Commit 5785931

Browse files
zhangjyrJingyuan Zhang
authored andcommitted
[Feat] Batch API Service, working with temporary existing local batch driver (vllm-project#1298)
* Basic metadata server with new batchjob work with legacy local batch driver. * Connect job_manager with batch API * Support async create_job * Finish e2e batch server test. * lint-fix * Review fix * Adjust unit tests' folder. * Add __init__.py in batch test * Add missing package. * Lint fix * Restore mis-merged code to pass unit tests. * Remove JobCache from this PR, which will be major refactored in later PR. * Remove k8s_transformer from this PR since JobCache was removed * Unify import format. Signed-off-by: Jingyuan Zhang <[email protected]> Signed-off-by: Jingyuan <[email protected]> Co-authored-by: Jingyuan Zhang <[email protected]> Signed-off-by: ChethanUK <[email protected]>
1 parent cb34f87 commit 5785931

34 files changed

+3489
-415
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ development/simulator/cache
5656

5757
# python build artifacts
5858
python/aibrix/dist/
59+
python/aibrix/aibrix/batch/storage/data
5960

6061
# setuptools-scm generated version files
6162
**/_version.py

python/aibrix/aibrix/batch/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from .driver import BatchDriver
1415

15-
16-
__all__ = ["create_batch_input", "retrieve_batch_job_content"]
16+
__all__ = ["BatchDriver"]

python/aibrix/aibrix/batch/driver.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,47 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
from typing import Optional
1617

1718
import aibrix.batch.storage as _storage
1819
from aibrix.batch.constant import DEFAULT_JOB_POOL_SIZE
20+
from aibrix.batch.job_entity import JobEntityManager
1921
from aibrix.batch.job_manager import JobManager
2022
from aibrix.batch.request_proxy import RequestProxy
2123
from aibrix.batch.scheduler import JobScheduler
24+
from aibrix.metadata.logger import init_logger
25+
26+
logger = init_logger(__name__)
2227

2328

2429
class BatchDriver:
25-
def __init__(self):
30+
def __init__(self, job_entity_manager: Optional[JobEntityManager] = None):
2631
"""
2732
This is main entrance to bind all components to serve job requests.
2833
"""
2934
_storage.initialize_storage()
3035
self._storage = _storage
31-
self._job_manager = JobManager()
32-
self._scheduler = JobScheduler(self._job_manager, DEFAULT_JOB_POOL_SIZE)
33-
self._proxy = RequestProxy(self._storage, self._job_manager)
34-
asyncio.create_task(self.jobs_running_loop())
35-
36-
def upload_batch_data(self, input_file_name):
37-
job_id = self._storage.submit_job_input(input_file_name)
38-
return job_id
39-
40-
def create_job(self, job_id, endpoint, window_due_time):
41-
self._job_manager.create_job(job_id, endpoint, window_due_time)
36+
self._job_manager: JobManager = JobManager(job_entity_manager)
37+
self._scheduler: Optional[JobScheduler] = None
38+
self._scheduling_task: Optional[asyncio.Task] = None
39+
self._proxy: RequestProxy = RequestProxy(self._storage, self._job_manager)
40+
# Only create jobs_running_loop if JobEntityManager does not have its own sched
41+
if not job_entity_manager or not job_entity_manager.is_scheduler_enabled():
42+
self._scheduler = JobScheduler(self._job_manager, DEFAULT_JOB_POOL_SIZE)
43+
self._job_manager.set_scheduler(self._scheduler)
44+
self._scheduling_task = asyncio.create_task(self.jobs_running_loop())
4245

43-
due_time = self._job_manager.get_job_window_due(job_id)
44-
self._scheduler.append_job(job_id, due_time)
46+
@property
47+
def job_manager(self) -> JobManager:
48+
return self._job_manager
4549

46-
def get_job_status(self, job_id):
47-
return self._job_manager.get_job_status(job_id)
50+
def upload_batch_data(self, input_file_name):
51+
file_id = self._storage.submit_job_input(input_file_name)
52+
return file_id
4853

49-
def retrieve_job_result(self, job_id):
50-
num_requests = _storage.get_job_num_request(job_id)
51-
req_results = _storage.get_job_results(job_id, 0, num_requests)
54+
def retrieve_job_result(self, file_id):
55+
num_requests = _storage.get_job_num_request(file_id)
56+
req_results = _storage.get_job_results(file_id, 0, num_requests)
5257
return req_results
5358

5459
async def jobs_running_loop(self):
@@ -57,11 +62,37 @@ async def jobs_running_loop(self):
5762
For now, the executing unit is one request. Later if necessary,
5863
we can support a batch size of request per execution.
5964
"""
65+
logger.info("Starting scheduling...")
6066
while True:
61-
one_job = self._scheduler.round_robin_get_job()
67+
one_job = await self._scheduler.round_robin_get_job()
6268
if one_job:
6369
await self._proxy.execute_queries(one_job)
6470
await asyncio.sleep(0)
6571

72+
async def close(self):
73+
"""Properly shutdown the driver and cancel running tasks"""
74+
if self._scheduling_task and not self._scheduling_task.done():
75+
self._scheduling_task.cancel()
76+
try:
77+
await self._scheduling_task
78+
except (asyncio.CancelledError, RuntimeError) as e:
79+
if isinstance(e, RuntimeError) and "different loop" in str(e):
80+
logger.warning(
81+
"Task cancellation from different event loop, forcing cancellation"
82+
)
83+
pass
84+
if self._scheduler:
85+
await self._scheduler.close()
86+
6687
def clear_job(self, job_id):
67-
self._storage.delete_job(job_id)
88+
job = self._job_manager.get_job(job_id)
89+
if job is None:
90+
return
91+
92+
self._job_manager.job_deleted_handler(job)
93+
if self._job_manager.get_job(job_id) is None:
94+
self._storage.delete_job(job.spec.input_file_id)
95+
if job.status.output_file_id is not None:
96+
self._storage.delete_job(job.status.output_file_id)
97+
if job.status.error_file_id is not None:
98+
self._storage.delete_job(job.status.error_file_id)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024 The Aibrix Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .batch_job import (
16+
BatchJob,
17+
BatchJobEndpoint,
18+
BatchJobError,
19+
BatchJobErrorCode,
20+
BatchJobSpec,
21+
BatchJobState,
22+
BatchJobStatus,
23+
CompletionWindow,
24+
Condition,
25+
ConditionStatus,
26+
ConditionType,
27+
ObjectMeta,
28+
RequestCountStats,
29+
TypeMeta,
30+
)
31+
from .job_entity_manager import JobEntityManager
32+
33+
__all__ = [
34+
"BatchJob",
35+
"BatchJobEndpoint",
36+
"BatchJobSpec",
37+
"BatchJobState",
38+
"BatchJobErrorCode",
39+
"BatchJobError",
40+
"BatchJobStatus",
41+
"CompletionWindow",
42+
"Condition",
43+
"ConditionStatus",
44+
"ConditionType",
45+
"JobEntityManager",
46+
"ObjectMeta",
47+
"RequestCountStats",
48+
"TypeMeta",
49+
]

0 commit comments

Comments
 (0)