Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 73 additions & 15 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class FileUrl(ABC):
"""Abstract base class for any URL-based file."""

Expand All @@ -106,11 +106,29 @@ class FileUrl(ABC):
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
"""

@property
_media_type: str | None = field(init=False, repr=False)

def __init__(
self,
url: str,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
) -> None:
self.url = url
self.vendor_metadata = vendor_metadata
self.force_download = force_download
self._media_type = media_type

@abstractmethod
def media_type(self) -> str:
def _infer_media_type(self) -> str:
"""Return the media type of the file, based on the url."""

@property
def media_type(self) -> str:
"""Return the media type of the file, based on the url or the provided `_media_type`."""
return self._media_type or self._infer_media_type()

@property
@abstractmethod
def format(self) -> str:
Expand All @@ -119,7 +137,7 @@ def format(self) -> str:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class VideoUrl(FileUrl):
"""A URL to a video."""

Expand All @@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
kind: Literal['video-url'] = 'video-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> VideoMediaType:
def __init__(
self,
url: str,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['video-url'] = 'video-url',
) -> None:
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
self.kind = kind

def _infer_media_type(self) -> VideoMediaType:
"""Return the media type of the video, based on the url."""
if self.url.endswith('.mkv'):
return 'video/x-matroska'
Expand Down Expand Up @@ -170,7 +198,7 @@ def format(self) -> VideoFormat:
return _video_format_lookup[self.media_type]


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class AudioUrl(FileUrl):
"""A URL to an audio file."""

Expand All @@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
kind: Literal['audio-url'] = 'audio-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> AudioMediaType:
def __init__(
self,
url: str,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['audio-url'] = 'audio-url',
) -> None:
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
self.kind = kind

def _infer_media_type(self) -> AudioMediaType:
"""Return the media type of the audio file, based on the url.

References:
Expand All @@ -208,7 +246,7 @@ def format(self) -> AudioFormat:
return _audio_format_lookup[self.media_type]


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class ImageUrl(FileUrl):
"""A URL to an image."""

Expand All @@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
kind: Literal['image-url'] = 'image-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> ImageMediaType:
def __init__(
self,
url: str,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['image-url'] = 'image-url',
) -> None:
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
self.kind = kind

def _infer_media_type(self) -> ImageMediaType:
"""Return the media type of the image, based on the url."""
if self.url.endswith(('.jpg', '.jpeg')):
return 'image/jpeg'
Expand All @@ -241,7 +289,7 @@ def format(self) -> ImageFormat:
return _image_format_lookup[self.media_type]


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class DocumentUrl(FileUrl):
"""The URL of the document."""

Expand All @@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
kind: Literal['document-url'] = 'document-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> str:
def __init__(
self,
url: str,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['document-url'] = 'document-url',
) -> None:
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
self.kind = kind

def _infer_media_type(self) -> str:
"""Return the media type of the document, based on the url."""
type_, _ = guess_type(self.url)
if type_ is None:
Expand Down
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,12 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
file_data_dict['video_metadata'] = item.vendor_metadata
content.append(file_data_dict) # type: ignore
elif isinstance(item, FileUrl):
if self.system == 'google-gla' or item.force_download:
if item.force_download or (
# google-gla does not support passing file urls directly, except for youtube videos
# (see above) and files uploaded to the file API (which cannot be downloaded anyway)
self.system == 'google-gla'
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
):
downloaded_item = await download_item(item, data_format='base64')
inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
content.append({'inline_data': inline_data}) # type: ignore
Expand Down
21 changes: 13 additions & 8 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ def test_image_url():
assert image_url.media_type == 'image/jpeg'
assert image_url.format == 'jpeg'

image_url = ImageUrl(url='https://example.com/image', media_type='image/jpeg')
assert image_url.media_type == 'image/jpeg'
assert image_url.format == 'jpeg'

def test_video_url():
with pytest.raises(ValueError, match='Unknown video file extension: https://example.com/video.potato'):
video_url = VideoUrl(url='https://example.com/video.potato')
video_url.media_type

def test_video_url():
video_url = VideoUrl(url='https://example.com/video.mp4')
assert video_url.media_type == 'video/mp4'
assert video_url.format == 'mp4'

video_url = VideoUrl(url='https://example.com/video', media_type='video/mp4')
assert video_url.media_type == 'video/mp4'
assert video_url.format == 'mp4'


@pytest.mark.parametrize(
'url,is_youtube',
Expand All @@ -45,14 +49,14 @@ def test_youtube_video_url(url: str, is_youtube: bool):


def test_document_url():
with pytest.raises(ValueError, match='Unknown document file extension: https://example.com/document.potato'):
document_url = DocumentUrl(url='https://example.com/document.potato')
document_url.media_type

document_url = DocumentUrl(url='https://example.com/document.pdf')
assert document_url.media_type == 'application/pdf'
assert document_url.format == 'pdf'

document_url = DocumentUrl(url='https://example.com/document', media_type='application/pdf')
assert document_url.media_type == 'application/pdf'
assert document_url.format == 'pdf'


@pytest.mark.parametrize(
'media_type, format',
Expand Down Expand Up @@ -129,6 +133,7 @@ def test_binary_content_document(media_type: str, format: str):
pytest.param(AudioUrl('foobar.flac'), 'audio/flac', 'flac', id='flac'),
pytest.param(AudioUrl('foobar.aiff'), 'audio/aiff', 'aiff', id='aiff'),
pytest.param(AudioUrl('foobar.aac'), 'audio/aac', 'aac', id='aac'),
pytest.param(AudioUrl('foobar', media_type='audio/mpeg'), 'audio/mpeg', 'mp3', id='mp3'),
],
)
def test_audio_url(audio_url: AudioUrl, media_type: str, format: str):
Expand Down