Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ class FileUrl(ABC):
url: str
"""The URL of the file."""

_media_type: str | None = field(default=None, repr=False)
"""Optional override for the media type of the file, in case it cannot be inferred from the URL."""

force_download: bool = False
"""If the model supports it:

Expand All @@ -106,11 +109,18 @@ class FileUrl(ABC):
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
"""

@property
@abstractmethod
def media_type(self) -> str:
def _infer_media_type_from_url(self) -> str:
"""Return the media type of the file, based on the url."""

@property
def media_type(self) -> str:
"""Return the media type of the file, based on the url or the provided `_media_type`."""
if self._media_type is not None:
return self._media_type
else:
return self._infer_media_type_from_url()

@property
@abstractmethod
def format(self) -> str:
Expand All @@ -129,8 +139,7 @@ class VideoUrl(FileUrl):
kind: Literal['video-url'] = 'video-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> VideoMediaType:
def _infer_media_type_from_url(self) -> VideoMediaType:
"""Return the media type of the video, based on the url."""
if self.url.endswith('.mkv'):
return 'video/x-matroska'
Expand Down Expand Up @@ -180,8 +189,7 @@ class AudioUrl(FileUrl):
kind: Literal['audio-url'] = 'audio-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> AudioMediaType:
def _infer_media_type_from_url(self) -> AudioMediaType:
"""Return the media type of the audio file, based on the url.

References:
Expand Down Expand Up @@ -218,8 +226,7 @@ class ImageUrl(FileUrl):
kind: Literal['image-url'] = 'image-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> ImageMediaType:
def _infer_media_type_from_url(self) -> ImageMediaType:
"""Return the media type of the image, based on the url."""
if self.url.endswith(('.jpg', '.jpeg')):
return 'image/jpeg'
Expand Down Expand Up @@ -251,8 +258,7 @@ class DocumentUrl(FileUrl):
kind: Literal['document-url'] = 'document-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> str:
def _infer_media_type_from_url(self) -> str:
"""Return the media type of the document, based on the url."""
type_, _ = guess_type(self.url)
if type_ is None:
Expand Down
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,12 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
file_data_dict['video_metadata'] = item.vendor_metadata
content.append(file_data_dict) # type: ignore
elif isinstance(item, FileUrl):
if self.system == 'google-gla' or item.force_download:
if item.force_download or (
# google-gla does not support passing file urls directly, except for youtube videos
# (see above) and files uploaded to the file API (which cannot be downloaded anyway)
self.system == 'google-gla'
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
):
downloaded_item = await download_item(item, data_format='base64')
inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
content.append({'inline_data': inline_data}) # type: ignore
Expand Down
21 changes: 13 additions & 8 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ def test_image_url():
assert image_url.media_type == 'image/jpeg'
assert image_url.format == 'jpeg'

image_url = ImageUrl(url='https://example.com/image', _media_type='image/jpeg')
assert image_url.media_type == 'image/jpeg'
assert image_url.format == 'jpeg'

def test_video_url():
with pytest.raises(ValueError, match='Unknown video file extension: https://example.com/video.potato'):
video_url = VideoUrl(url='https://example.com/video.potato')
video_url.media_type

def test_video_url():
video_url = VideoUrl(url='https://example.com/video.mp4')
assert video_url.media_type == 'video/mp4'
assert video_url.format == 'mp4'

video_url = VideoUrl(url='https://example.com/video', _media_type='video/mp4')
assert video_url.media_type == 'video/mp4'
assert video_url.format == 'mp4'


@pytest.mark.parametrize(
'url,is_youtube',
Expand All @@ -45,14 +49,14 @@ def test_youtube_video_url(url: str, is_youtube: bool):


def test_document_url():
with pytest.raises(ValueError, match='Unknown document file extension: https://example.com/document.potato'):
document_url = DocumentUrl(url='https://example.com/document.potato')
document_url.media_type

document_url = DocumentUrl(url='https://example.com/document.pdf')
assert document_url.media_type == 'application/pdf'
assert document_url.format == 'pdf'

document_url = DocumentUrl(url='https://example.com/document', _media_type='application/pdf')
assert document_url.media_type == 'application/pdf'
assert document_url.format == 'pdf'


@pytest.mark.parametrize(
'media_type, format',
Expand Down Expand Up @@ -129,6 +133,7 @@ def test_binary_content_document(media_type: str, format: str):
pytest.param(AudioUrl('foobar.flac'), 'audio/flac', 'flac', id='flac'),
pytest.param(AudioUrl('foobar.aiff'), 'audio/aiff', 'aiff', id='aiff'),
pytest.param(AudioUrl('foobar.aac'), 'audio/aac', 'aac', id='aac'),
pytest.param(AudioUrl('foobar', 'audio/mpeg'), 'audio/mpeg', 'mp3', id='mp3'),
],
)
def test_audio_url(audio_url: AudioUrl, media_type: str, format: str):
Expand Down