Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class AsyncInferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .openai import OpenAIConversationalTask
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask


Expand All @@ -61,6 +62,7 @@
"replicate",
"sambanova",
"together",
"scaleway",
]

PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
Expand Down Expand Up @@ -159,6 +161,10 @@
"conversational": TogetherConversationalTask(),
"text-generation": TogetherTextGenerationTask(),
},
"scaleway": {
"conversational": ScalewayConversationalTask(),
"feature-extraction": ScalewayFeatureExtractionTask(),
},
}


Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"replicate": {},
"sambanova": {},
"together": {},
"scaleway": {},
}


Expand Down
28 changes: 28 additions & 0 deletions src/huggingface_hub/inference/_providers/scaleway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict

from ._common import BaseConversationalTask, InferenceProviderMapping, TaskProviderHelper, filter_none


class ScalewayConversationalTask(BaseConversationalTask):
    """Conversational (chat-completion) task helper for the Scaleway provider.

    Only sets the provider name and base URL; request routing and payload
    construction are inherited from `BaseConversationalTask` (which, per the
    tests, resolves to the `/v1/chat/completions` route on direct calls).
    """

    def __init__(self):
        # Scaleway exposes an OpenAI-compatible API at this base URL.
        super().__init__(provider="scaleway", base_url="https://api.scaleway.ai")


class ScalewayFeatureExtractionTask(TaskProviderHelper):
    """Feature-extraction (embeddings) task helper for the Scaleway provider."""

    def __init__(self):
        super().__init__(provider="scaleway", base_url="https://api.scaleway.ai", task="feature-extraction")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # Scaleway serves embeddings on an OpenAI-compatible endpoint; the
        # route does not depend on the model or the key.
        return "/v1/embeddings"

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # Base payload first, then overlay the non-None user parameters.
        payload: Dict[str, Any] = {"input": inputs, "model": provider_mapping_info.provider_id}
        payload.update(filter_none(parameters))
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        # Unwrap the OpenAI-style envelope: {"data": [{"embedding": [...]}, ...]}
        items = _as_dict(response)["data"]
        return [item["embedding"] for item in items]
70 changes: 70 additions & 0 deletions tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
ReplicateTextToSpeechTask,
)
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from huggingface_hub.inference._providers.scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from huggingface_hub.inference._providers.together import TogetherTextToImageTask

from .testing_utils import assert_in_logs
Expand Down Expand Up @@ -1077,6 +1078,75 @@ def test_prepare_url_conversational(self):
assert url == "https://api.novita.ai/v3/openai/chat/completions"


class TestScalewayProvider:
    """Unit tests for the Scaleway conversational and feature-extraction helpers."""

    def test_prepare_hf_url_conversational(self):
        # With an HF token, requests are routed through the HF router.
        task = ScalewayConversationalTask()
        assert (
            task._prepare_url("hf_token", "username/repo_name")
            == "https://router.huggingface.co/scaleway/v1/chat/completions"
        )

    def test_prepare_url_conversational(self):
        # With a provider token, requests go directly to Scaleway's API.
        task = ScalewayConversationalTask()
        assert task._prepare_url("scw_token", "username/repo_name") == "https://api.scaleway.ai/v1/chat/completions"

    def test_prepare_payload_as_dict(self):
        task = ScalewayConversationalTask()
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello!"},
        ]
        parameters = {
            "max_tokens": 512,
            "temperature": 0.15,
            "top_p": 1,
            "presence_penalty": 0,
            "stream": True,
        }
        mapping = InferenceProviderMapping(
            provider="scaleway",
            hf_model_id="meta-llama/Llama-3.1-8B-Instruct",
            providerId="meta-llama/llama-3.1-8B-Instruct",
            task="conversational",
            status="live",
        )
        payload = task._prepare_payload_as_dict(messages, parameters, mapping)
        # The payload carries the provider-side model id plus all parameters.
        assert payload == {
            "messages": messages,
            "model": "meta-llama/llama-3.1-8B-Instruct",
            "max_tokens": 512,
            "temperature": 0.15,
            "top_p": 1,
            "presence_penalty": 0,
            "stream": True,
        }

    def test_prepare_url_feature_extraction(self):
        task = ScalewayFeatureExtractionTask()
        url = task._prepare_url("hf_token", "username/repo_name")
        assert url == "https://router.huggingface.co/scaleway/v1/embeddings"

    def test_prepare_payload_as_dict_feature_extraction(self):
        task = ScalewayFeatureExtractionTask()
        mapping = InferenceProviderMapping(
            provider="scaleway",
            hf_model_id="username/repo_name",
            providerId="provider-id",
            task="feature-extraction",
            status="live",
        )
        payload = task._prepare_payload_as_dict("Example text to embed", {"truncate": True}, mapping)
        assert payload == {"input": "Example text to embed", "model": "provider-id", "truncate": True}


class TestNscaleProvider:
def test_prepare_route_text_to_image(self):
helper = NscaleTextToImageTask()
Expand Down
Loading