Skip to content
Open
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"bagelnet"`, `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
5 changes: 5 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from huggingface_hub.utils import logging

from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
from .bagelnet import BagelNetConversationalTask
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .cohere import CohereConversationalTask
Expand Down Expand Up @@ -45,6 +46,7 @@


PROVIDER_T = Literal[
"bagelnet",
"black-forest-labs",
"cerebras",
"cohere",
Expand All @@ -66,6 +68,9 @@
PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
"bagelnet": {
"conversational": BagelNetConversationalTask(),
},
"black-forest-labs": {
"text-to-image": BlackForestLabsTextToImageTask(),
},
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# provider_id="Qwen2.5-Coder-32B-Instruct",
# task="conversational",
# status="live")
"bagelnet": {},
"cerebras": {},
"cohere": {},
"fal-ai": {},
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/inference/_providers/bagelnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._common import BaseConversationalTask


class BagelNetConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider="bagelnet", base_url="https://api.bagel.net")
31 changes: 31 additions & 0 deletions tests/test_bagelnet_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from huggingface_hub.inference._providers.bagelnet import BagelNetConversationalTask


class TestBagelNetConversationalTask:
def test_init(self):
"""Test BagelNet provider initialization."""
task = BagelNetConversationalTask()
assert task.provider == "bagelnet"
assert task.base_url == "https://api.bagel.net"
assert task.task == "conversational"

def test_inheritance(self):
"""Test BagelNet inherits from BaseConversationalTask."""
from huggingface_hub.inference._providers._common import BaseConversationalTask

task = BagelNetConversationalTask()
assert isinstance(task, BaseConversationalTask)

def test_no_method_overrides(self):
"""Test that BagelNet uses default implementations (no overrides needed)."""
task = BagelNetConversationalTask()

# Should use default route
route = task._prepare_route("test_model", "test_key")
assert route == "/v1/chat/completions"

# Should use default base URL behavior
direct_url = task._prepare_base_url("sk-test-key") # Non-HF key
assert direct_url == "https://api.bagel.net"