Skip to content

Commit 382b150

Browse files
kiberguscopybara-github
authored andcommitted
feat: Make genai.Part constructible from PartUnionDict.
PiperOrigin-RevId: 813307548
1 parent 29262e1 commit 382b150

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

google/genai/tests/types/test_types.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
import typing
2020
from typing import Optional, assert_never
21+
import PIL.Image
2122
import pydantic
2223
import pytest
2324
from ... import types
@@ -262,6 +263,58 @@ def test_factory_method_from_mcp_call_tool_function_response_embedded_resource()
262263
assert isinstance(my_function_response, types.FunctionResponse)
263264

264265

266+
def test_part_constructor_with_string_value():
267+
part = types.Part('hello')
268+
assert part.text == 'hello'
269+
assert part.file_data is None
270+
assert part.inline_data is None
271+
272+
273+
def test_part_constructor_with_part_value():
274+
other_part = types.Part(text='hello from other part')
275+
part = types.Part(other_part)
276+
assert part.text == 'hello from other part'
277+
278+
279+
def test_part_constructor_with_part_dict_value():
280+
part = types.Part({'text': 'hello from dict'})
281+
assert part.text == 'hello from dict'
282+
283+
284+
def test_part_constructor_with_file_data_dict_value():
285+
part = types.Part(
286+
{'file_uri': 'gs://my-bucket/file-data', 'mime_type': 'text/plain'}
287+
)
288+
assert part.file_data.file_uri == 'gs://my-bucket/file-data'
289+
assert part.file_data.mime_type == 'text/plain'
290+
291+
292+
def test_part_constructor_with_kwargs_and_value_fails():
293+
with pytest.raises(
294+
ValueError, match='Positional and keyword arguments can not be combined'
295+
):
296+
types.Part('hello', text='world')
297+
298+
299+
def test_part_constructor_with_file_value():
300+
f = types.File(
301+
uri='gs://my-bucket/my-file',
302+
mime_type='text/plain',
303+
display_name='test file',
304+
)
305+
part = types.Part(f)
306+
assert part.file_data.file_uri == 'gs://my-bucket/my-file'
307+
assert part.file_data.mime_type == 'text/plain'
308+
assert part.file_data.display_name == 'test file'
309+
310+
311+
def test_part_constructor_with_pil_image():
312+
img = PIL.Image.new('RGB', (1, 1), color='red')
313+
part = types.Part(img)
314+
assert part.inline_data.mime_type == 'image/jpeg'
315+
assert isinstance(part.inline_data.data, bytes)
316+
317+
265318
class FakeClient:
266319

267320
def __init__(self, vertexai=False) -> None:

google/genai/types.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
import datetime
2020
from enum import Enum, EnumMeta
2121
import inspect
22+
import io
2223
import json
2324
import logging
2425
import sys
2526
import types as builtin_types
2627
import typing
27-
from typing import Any, Callable, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore
28+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore
2829
import pydantic
2930
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
3031
from typing_extensions import Self, TypedDict
@@ -1260,6 +1261,81 @@ class Part(_common.BaseModel):
12601261
default=None, description="""Optional. Text part (can be code)."""
12611262
)
12621263

1264+
def __init__(
1265+
self,
1266+
value: Optional['PartUnionDict'] = None,
1267+
/,
1268+
*,
1269+
video_metadata: Optional[VideoMetadata] = None,
1270+
thought: Optional[bool] = None,
1271+
inline_data: Optional[Blob] = None,
1272+
file_data: Optional[FileData] = None,
1273+
thought_signature: Optional[bytes] = None,
1274+
function_call: Optional[FunctionCall] = None,
1275+
code_execution_result: Optional[CodeExecutionResult] = None,
1276+
executable_code: Optional[ExecutableCode] = None,
1277+
function_response: Optional[FunctionResponse] = None,
1278+
text: Optional[str] = None,
1279+
# Pydantic allows CamelCase in addition to snake_case attribute
1280+
# names. kwargs here catch these aliases.
1281+
**kwargs: Any,
1282+
):
1283+
part_dict = dict(
1284+
video_metadata=video_metadata,
1285+
thought=thought,
1286+
inline_data=inline_data,
1287+
file_data=file_data,
1288+
thought_signature=thought_signature,
1289+
function_call=function_call,
1290+
code_execution_result=code_execution_result,
1291+
executable_code=executable_code,
1292+
function_response=function_response,
1293+
text=text,
1294+
)
1295+
part_dict = {k: v for k, v in part_dict.items() if v is not None}
1296+
1297+
if part_dict and value is not None:
1298+
raise ValueError(
1299+
'Positional and keyword arguments can not be combined when '
1300+
'initializing a Part.'
1301+
)
1302+
1303+
if value is None:
1304+
pass
1305+
elif isinstance(value, str):
1306+
part_dict['text'] = value
1307+
elif isinstance(value, File):
1308+
if not value.uri or not value.mime_type:
1309+
raise ValueError('file uri and mime_type are required.')
1310+
part_dict['file_data'] = FileData(
1311+
file_uri=value.uri,
1312+
mime_type=value.mime_type,
1313+
display_name=value.display_name,
1314+
)
1315+
elif isinstance(value, dict):
1316+
try:
1317+
Part.model_validate(value)
1318+
part_dict.update(value) # type: ignore[arg-type]
1319+
except pydantic.ValidationError:
1320+
part_dict['file_data'] = FileData.model_validate(value)
1321+
elif isinstance(value, Part):
1322+
part_dict.update(value.dict())
1323+
elif 'image' in value.__class__.__name__.lower():
1324+
# PIL.Image case.
1325+
1326+
suffix = value.format.lower() if value.format else 'jpeg'
1327+
mimetype = f'image/{suffix}'
1328+
bytes_io = io.BytesIO()
1329+
value.save(bytes_io, suffix.upper())
1330+
1331+
part_dict['inline_data'] = Blob(
1332+
data=bytes_io.getvalue(), mime_type=mimetype
1333+
)
1334+
else:
1335+
raise ValueError(f'Unsupported content part type: {type(value)}')
1336+
1337+
super().__init__(**part_dict, **kwargs)
1338+
12631339
def as_image(self) -> Optional['Image']:
12641340
"""Returns the part as a PIL Image, or None if the part is not an image."""
12651341
if not self.inline_data:

0 commit comments

Comments
 (0)