Skip to content

Commit 3abf441

Browse files
yinghsienwu authored and copybara-github committed
fix: Fix the bug to support Gemini Batch inlined requests system instruction
PiperOrigin-RevId: 795249849
1 parent a3c46f5 commit 3abf441

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed

google/genai/batches.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,6 +2257,17 @@ def create(
22572257
)
22582258
print(batch_job.state)
22592259
"""
2260+
parameter_model = types._CreateBatchJobParameters(
2261+
model=model,
2262+
src=src,
2263+
config=config,
2264+
)
2265+
http_options: Optional[types.HttpOptions] = None
2266+
if (
2267+
parameter_model.config is not None
2268+
and parameter_model.config.http_options is not None
2269+
):
2270+
http_options = parameter_model.config.http_options
22602271
if self._api_client.vertexai:
22612272
if isinstance(src, list):
22622273
raise ValueError(
@@ -2265,6 +2276,65 @@ def create(
22652276
)
22662277

22672278
config = _extra_utils.format_destination(src, config)
2279+
else:
2280+
if isinstance(parameter_model.src, list) or (
2281+
not isinstance(parameter_model.src, str)
2282+
and parameter_model.src
2283+
and parameter_model.src.inlined_requests
2284+
):
2285+
# Handle system instruction in InlinedRequests.
2286+
request_url_dict: Optional[dict[str, str]]
2287+
request_dict: dict[str, Any] = _CreateBatchJobParameters_to_mldev(
2288+
self._api_client, parameter_model
2289+
)
2290+
request_url_dict = request_dict.get('_url')
2291+
if request_url_dict:
2292+
path = '{model}:batchGenerateContent'.format_map(request_url_dict)
2293+
else:
2294+
path = '{model}:batchGenerateContent'
2295+
query_params = request_dict.get('_query')
2296+
if query_params:
2297+
path = f'{path}?{urlencode(query_params)}'
2298+
request_dict.pop('config', None)
2299+
2300+
request_dict = _common.convert_to_dict(request_dict)
2301+
request_dict = _common.encode_unserializable_types(request_dict)
2302+
# Move system instruction to 'request':
2303+
# {'systemInstruction': system_instruction}
2304+
requests = []
2305+
batch_dict = request_dict.get('batch')
2306+
if batch_dict and isinstance(batch_dict, dict):
2307+
input_config_dict = batch_dict.get('inputConfig')
2308+
if input_config_dict and isinstance(input_config_dict, dict):
2309+
requests_dict = input_config_dict.get('requests')
2310+
if requests_dict and isinstance(requests_dict, dict):
2311+
requests = requests_dict.get('requests')
2312+
new_requests = []
2313+
if requests:
2314+
for req in requests:
2315+
if req.get('systemInstruction'):
2316+
value = req.pop('systemInstruction')
2317+
req['request'].update({'systemInstruction': value})
2318+
new_requests.append(req)
2319+
request_dict['batch']['inputConfig']['requests'][ # type: ignore
2320+
'requests'
2321+
] = new_requests
2322+
2323+
response = self._api_client.request(
2324+
'post', path, request_dict, http_options
2325+
)
2326+
2327+
response_dict = '' if not response.body else json.loads(response.body)
2328+
2329+
response_dict = _BatchJob_from_mldev(response_dict)
2330+
2331+
return_value = types.BatchJob._from_response(
2332+
response=response_dict, kwargs=parameter_model.model_dump()
2333+
)
2334+
2335+
self._api_client._verify_response(return_value)
2336+
return return_value
2337+
22682338
return self._create(model=model, src=src, config=config)
22692339

22702340
def list(
@@ -2691,6 +2761,17 @@ async def create(
26912761
src="gs://path/to/input/data",
26922762
)
26932763
"""
2764+
parameter_model = types._CreateBatchJobParameters(
2765+
model=model,
2766+
src=src,
2767+
config=config,
2768+
)
2769+
http_options: Optional[types.HttpOptions] = None
2770+
if (
2771+
parameter_model.config is not None
2772+
and parameter_model.config.http_options is not None
2773+
):
2774+
http_options = parameter_model.config.http_options
26942775
if self._api_client.vertexai:
26952776
if isinstance(src, list):
26962777
raise ValueError(
@@ -2699,6 +2780,65 @@ async def create(
26992780
)
27002781

27012782
config = _extra_utils.format_destination(src, config)
2783+
else:
2784+
if isinstance(parameter_model.src, list) or (
2785+
not isinstance(parameter_model.src, str)
2786+
and parameter_model.src
2787+
and parameter_model.src.inlined_requests
2788+
):
2789+
# Handle system instruction in InlinedRequests.
2790+
request_url_dict: Optional[dict[str, str]]
2791+
request_dict: dict[str, Any] = _CreateBatchJobParameters_to_mldev(
2792+
self._api_client, parameter_model
2793+
)
2794+
request_url_dict = request_dict.get('_url')
2795+
if request_url_dict:
2796+
path = '{model}:batchGenerateContent'.format_map(request_url_dict)
2797+
else:
2798+
path = '{model}:batchGenerateContent'
2799+
query_params = request_dict.get('_query')
2800+
if query_params:
2801+
path = f'{path}?{urlencode(query_params)}'
2802+
request_dict.pop('config', None)
2803+
2804+
request_dict = _common.convert_to_dict(request_dict)
2805+
request_dict = _common.encode_unserializable_types(request_dict)
2806+
# Move system instruction to 'request':
2807+
# {'systemInstruction': system_instruction}
2808+
requests = []
2809+
batch_dict = request_dict.get('batch')
2810+
if batch_dict and isinstance(batch_dict, dict):
2811+
input_config_dict = batch_dict.get('inputConfig')
2812+
if input_config_dict and isinstance(input_config_dict, dict):
2813+
requests_dict = input_config_dict.get('requests')
2814+
if requests_dict and isinstance(requests_dict, dict):
2815+
requests = requests_dict.get('requests')
2816+
new_requests = []
2817+
if requests:
2818+
for req in requests:
2819+
if req.get('systemInstruction'):
2820+
value = req.pop('systemInstruction')
2821+
req['request'].update({'systemInstruction': value})
2822+
new_requests.append(req)
2823+
request_dict['batch']['inputConfig']['requests'][ # type: ignore
2824+
'requests'
2825+
] = new_requests
2826+
2827+
response = await self._api_client.async_request(
2828+
'post', path, request_dict, http_options
2829+
)
2830+
2831+
response_dict = '' if not response.body else json.loads(response.body)
2832+
2833+
response_dict = _BatchJob_from_mldev(response_dict)
2834+
2835+
return_value = types.BatchJob._from_response(
2836+
response=response_dict, kwargs=parameter_model.model_dump()
2837+
)
2838+
2839+
self._api_client._verify_response(return_value)
2840+
return return_value
2841+
27022842
return await self._create(model=model, src=src, config=config)
27032843

27042844
async def list(

google/genai/tests/batches/test_create_with_inlined_requests.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,17 @@
3939
_INLINED_TEXT_REQUEST = {
4040
'contents': [{
4141
'parts': [{
42-
'text': 'What is the QQQ stock price?',
42+
'text': 'high',
4343
}],
4444
'role': 'user',
4545
}],
4646
'config': {
4747
'response_modalities': ['TEXT'],
48+
'system_instruction': 'I say high, you say low',
49+
'thinking_config': {
50+
'include_thoughts': True,
51+
'thinking_budget': 4000,
52+
},
4853
},
4954
}
5055
_INLINED_IMAGE_REQUEST = {
@@ -130,6 +135,17 @@
130135
),
131136
pytest_helper.TestTableItem(
132137
name='test_with_inlined_request',
138+
parameters=types._CreateBatchJobParameters(
139+
model=_MLDEV_GEMINI_MODEL,
140+
src={'inlined_requests': [_INLINED_REQUEST]},
141+
config={
142+
'display_name': _DISPLAY_NAME,
143+
},
144+
),
145+
exception_if_vertex='not supported',
146+
),
147+
pytest_helper.TestTableItem(
148+
name='test_union_with_inlined_request_system_instruction',
133149
parameters=types._CreateBatchJobParameters(
134150
model=_MLDEV_GEMINI_MODEL,
135151
src={'inlined_requests': [_INLINED_TEXT_REQUEST]},

0 commit comments

Comments
 (0)