Merged
Commits (36)
642b803
add mypy pydantic plugin settings
voorhs Sep 3, 2025
20e155a
implement base class and interface class
voorhs Sep 3, 2025
db9d0f1
refactor embedder config
voorhs Sep 3, 2025
29981f6
add sentence transformer embedding backend
voorhs Sep 3, 2025
89b313d
add openai embedding backend
voorhs Sep 3, 2025
c5b08e1
re-refactor embedder configs
voorhs Sep 3, 2025
1bc59b8
re-refactor dump/load
voorhs Sep 3, 2025
1091044
add proper dump/load to Embedder
voorhs Sep 3, 2025
7176660
handle default embedder config usage
voorhs Sep 3, 2025
a0bb255
fix some typing errors
voorhs Sep 3, 2025
266b45c
fix a couple more
voorhs Sep 3, 2025
2cd16bb
fix some more typing errors
voorhs Sep 3, 2025
b0a48b6
one more error
voorhs Sep 3, 2025
3160621
is it all?
voorhs Sep 3, 2025
913987b
Update optimizer_config.schema.json
github-actions[bot] Sep 3, 2025
1b4a3ee
bug fix
voorhs Sep 4, 2025
51f6d0e
fix some tests
voorhs Sep 4, 2025
f911ab1
temporary way to fix tests
voorhs Sep 4, 2025
489c182
refactor embedder tests
voorhs Sep 4, 2025
67872c2
fix some tests
voorhs Sep 4, 2025
cb24d83
Update optimizer_config.schema.json
github-actions[bot] Sep 4, 2025
7c20546
try to fix dynamic schema issues
voorhs Sep 4, 2025
5a6917d
Update optimizer_config.schema.json
github-actions[bot] Sep 4, 2025
6b80fb2
upd vector index tests
voorhs Sep 9, 2025
1534056
upd inference test
voorhs Sep 9, 2025
5c3f81a
upd tutorials
voorhs Sep 9, 2025
813c179
ignore ds store
voorhs Sep 20, 2025
5ce3e5d
set similarity_fn default to None
voorhs Sep 20, 2025
d5442dc
upd callback test
voorhs Sep 20, 2025
7263359
remove unnecessary import
voorhs Sep 20, 2025
4ecb061
run code formatter
voorhs Sep 20, 2025
a4217da
remove unnecessary import
voorhs Sep 20, 2025
181aac9
add openai base url option
voorhs Sep 21, 2025
77f0287
remove openai api key everywhere for security reasons
voorhs Sep 21, 2025
88b339a
ignore extra envs in mcp server
voorhs Sep 21, 2025
b9e4994
add typed marker
voorhs Sep 21, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -184,3 +184,4 @@ vector_db*
/wandb
model_output/
my.py
.DS_store
83 changes: 2 additions & 81 deletions docs/optimizer_config.schema.json
@@ -123,54 +123,8 @@
},
"EmbedderConfig": {
"additionalProperties": false,
"description": "Base class for embedder configurations.",
"properties": {
"model_name": {
"default": "sentence-transformers/all-MiniLM-L6-v2",
"description": "Name of the hugging face model.",
"title": "Model Name",
"type": "string"
},
"batch_size": {
"default": 32,
"description": "Batch size for model inference.",
"exclusiveMinimum": 0,
"title": "Batch Size",
"type": "integer"
},
"device": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Torch notation for CPU or CUDA.",
"title": "Device"
},
"bf16": {
"default": false,
"description": "Whether to use mixed precision training (not all devices support this).",
"title": "Bf16",
"type": "boolean"
},
"fp16": {
"default": false,
"description": "Whether to use mixed precision training (not all devices support this).",
"title": "Fp16",
"type": "boolean"
},
"tokenizer_config": {
"$ref": "#/$defs/TokenizerConfig"
},
"trust_remote_code": {
"default": false,
"description": "Whether to trust the remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
},
"default_prompt": {
"anyOf": [
{
@@ -249,18 +203,6 @@
"description": "Prompt for passage.",
"title": "Passage Prompt"
},
"similarity_fn_name": {
"default": "cosine",
"description": "Name of the similarity function to use.",
"enum": [
"cosine",
"dot",
"euclidean",
"manhattan"
],
"title": "Similarity Fn Name",
"type": "string"
},
"use_cache": {
"default": true,
"description": "Whether to use embeddings caching.",
@@ -552,28 +494,7 @@
}
},
"embedder_config": {
"$ref": "#/$defs/EmbedderConfig",
"default": {
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
"batch_size": 32,
"device": null,
"bf16": false,
"fp16": false,
"tokenizer_config": {
"max_length": null,
"padding": true,
"truncation": true
},
"trust_remote_code": false,
"default_prompt": null,
"classification_prompt": null,
"cluster_prompt": null,
"sts_prompt": null,
"query_prompt": null,
"passage_prompt": null,
"similarity_fn_name": "cosine",
"use_cache": true
}
"$ref": "#/$defs/EmbedderConfig"
},
"cross_encoder_config": {
"$ref": "#/$defs/CrossEncoderConfig",
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -240,6 +240,10 @@ plugins = [
mypy_path = "src/autointent"
disable_error_code = ["override"]

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true

[[tool.mypy.overrides]]
module = [
"scipy",
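
Note on the new `[tool.pydantic-mypy]` block: `init_typed` makes mypy check the synthesized `__init__` against the declared field types, and `init_forbid_extra` makes it reject unknown keyword arguments. A minimal sketch of what this catches at type-check time (the model below is hypothetical, not from this repository):

```python
from pydantic import BaseModel


class ExampleConfig(BaseModel):
    # Hypothetical model, used only to illustrate the plugin flags.
    batch_size: int = 32


# init_typed = true: mypy now flags the str argument against the declared int type.
bad_type = ExampleConfig(batch_size="32")

# init_forbid_extra = true: mypy now flags the unknown keyword argument.
bad_kwarg = ExampleConfig(batch_size=32, unknown_option=1)
```
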
20 changes: 17 additions & 3 deletions src/autointent/_optimization_config.py
@@ -1,8 +1,16 @@
from typing import Any

from pydantic import BaseModel, PositiveInt
from pydantic import BaseModel, Field, PositiveInt, field_validator

from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, HPOConfig, LoggingConfig
from .configs import (
CrossEncoderConfig,
DataConfig,
EmbedderConfig,
HFModelConfig,
HPOConfig,
LoggingConfig,
initialize_embedder_config,
)


class OptimizationConfig(BaseModel):
@@ -20,7 +28,13 @@ class OptimizationConfig(BaseModel):
logging_config: LoggingConfig = LoggingConfig()
"""See tutorial on logging configuration."""

embedder_config: EmbedderConfig = EmbedderConfig()
embedder_config: EmbedderConfig = Field(default_factory=lambda: initialize_embedder_config(None))

@field_validator("embedder_config", mode="before")
@classmethod
def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401
"""Validate and convert embedder config to proper type."""
return initialize_embedder_config(v)

cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

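
The combination above, a `default_factory` plus a `mode="before"` validator that both delegate to `initialize_embedder_config`, means the field accepts `None`, a raw dict, or an already-built config and always ends up holding a concrete `EmbedderConfig`. A self-contained toy mirroring that pattern (the toy names are illustrative, not the project's API):

```python
from typing import Any

from pydantic import BaseModel, Field, field_validator


class ToyEmbedderConfig(BaseModel):
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"


def toy_initialize(v: Any) -> ToyEmbedderConfig:
    # Stand-in for initialize_embedder_config: normalize None / dict / instance.
    if v is None:
        return ToyEmbedderConfig()
    if isinstance(v, ToyEmbedderConfig):
        return v
    return ToyEmbedderConfig(**v)


class ToyOptimizationConfig(BaseModel):
    embedder_config: ToyEmbedderConfig = Field(default_factory=lambda: toy_initialize(None))

    @field_validator("embedder_config", mode="before")
    @classmethod
    def _validate_embedder_config(cls, v: Any) -> ToyEmbedderConfig:
        return toy_initialize(v)


# Missing value, raw dict, and instance all normalize to ToyEmbedderConfig.
assert isinstance(ToyOptimizationConfig().embedder_config, ToyEmbedderConfig)
assert ToyOptimizationConfig(embedder_config={"model_name": "x"}).embedder_config.model_name == "x"
```
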
22 changes: 18 additions & 4 deletions src/autointent/_pipeline/_pipeline.py
@@ -19,6 +19,7 @@
InferenceNodeConfig,
LoggingConfig,
VectorIndexConfig,
get_default_embedder_config,
get_default_vector_index_config,
)
from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
@@ -56,7 +57,7 @@ def __init__(

if isinstance(nodes[0], NodeOptimizer):
self.logging_config = LoggingConfig()
self.embedder_config = EmbedderConfig()
self.embedder_config = get_default_embedder_config()
self.cross_encoder_config = CrossEncoderConfig()
self.data_config = DataConfig()
self.transformer_config = HFModelConfig()
@@ -111,7 +112,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
return cls(nodes=nodes, seed=seed)

@classmethod
def from_preset(cls, name: SearchSpacePreset, seed: int | None = 42) -> "Pipeline":
def from_preset(cls, name: SearchSpacePreset, seed: int = 42) -> "Pipeline":
"""Instantiate pipeline optimizer from a preset."""
optimization_config = load_preset(name)
config = OptimizationConfig(seed=seed, **optimization_config)
@@ -395,6 +396,19 @@ def _refit(self, context: Context) -> None:
decision_module.clear_cache()
decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags)

def _convert_score_to_float_list(self, score: Any) -> list[float]: # noqa: ANN401
"""Convert score to list of floats for InferencePipelineUtteranceOutput."""
if hasattr(score, "tolist"):
result = score.tolist()
return result if isinstance(result, list) else [float(result)]
if score is None:
return []
if isinstance(score, int | float):
return [float(score)]
if hasattr(score, "__iter__") and not isinstance(score, str):
return [float(x) for x in score]
return [float(score)]

def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
"""Predict the labels for the utterances with metadata.

@@ -422,13 +436,13 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
regex_prediction_metadata=regex_predictions_metadata[idx]
if regex_predictions_metadata is not None
else None,
score=scores[idx],
score=self._convert_score_to_float_list(scores[idx]),
score_metadata=scores_metadata[idx] if scores_metadata is not None else None,
)
outputs.append(output)

return InferencePipelineOutput(
predictions=predictions,
predictions=predictions, # type: ignore[arg-type]
regex_predictions=regex_predictions,
utterances=outputs,
)
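
For readability, the branches of the new `_convert_score_to_float_list` helper can be checked in isolation; the standalone copy below mirrors the method above so the expected conversions are easy to verify (illustrative only, not part of the diff):

```python
from typing import Any

import numpy as np


def convert_score_to_float_list(score: Any) -> list[float]:
    # Standalone copy of the branching in Pipeline._convert_score_to_float_list,
    # so the cases below can be checked without building a pipeline.
    if hasattr(score, "tolist"):
        result = score.tolist()
        return result if isinstance(result, list) else [float(result)]
    if score is None:
        return []
    if isinstance(score, int | float):
        return [float(score)]
    if hasattr(score, "__iter__") and not isinstance(score, str):
        return [float(x) for x in score]
    return [float(score)]


assert convert_score_to_float_list(np.array([0.1, 0.9])) == [0.1, 0.9]  # array -> .tolist()
assert convert_score_to_float_list(np.float32(0.5)) == [0.5]            # numpy scalar -> wrapped
assert convert_score_to_float_list(None) == []                          # missing score -> empty
assert convert_score_to_float_list(0.7) == [0.7]                        # plain float -> wrapped
assert convert_score_to_float_list((0.25, 0.75)) == [0.25, 0.75]        # generic iterable -> cast
```
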
13 changes: 13 additions & 0 deletions src/autointent/_wrappers/embedder/__init__.py
@@ -0,0 +1,13 @@
"""Embedder module with multiple backend support."""

from .base import BaseEmbeddingBackend
from .embedder import Embedder
from .openai import OpenaiEmbeddingBackend
from .sentence_transformers import SentenceTransformerEmbeddingBackend

__all__ = [
"BaseEmbeddingBackend",
"Embedder",
"OpenaiEmbeddingBackend",
"SentenceTransformerEmbeddingBackend",
]
102 changes: 102 additions & 0 deletions src/autointent/_wrappers/embedder/base.py
@@ -0,0 +1,102 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, overload

import numpy as np
import numpy.typing as npt
import torch

from autointent.configs import EmbedderConfig, TaskTypeEnum


class BaseEmbeddingBackend(ABC):
"""Abstract base class for embedding backends."""

supports_training: bool = False

@abstractmethod
def __init__(self, config: EmbedderConfig) -> None:
"""Initialize the embedding backend with configuration."""
...

@abstractmethod
def clear_ram(self) -> None:
"""Clear the backend from RAM."""
...

@overload
@abstractmethod
def embed(
self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[True]
) -> torch.Tensor: ...

@overload
@abstractmethod
def embed(
self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[False] = False
) -> npt.NDArray[np.float32]: ...

@abstractmethod
def embed(
self,
utterances: list[str],
task_type: TaskTypeEnum | None = None,
return_tensors: bool = False,
) -> npt.NDArray[np.float32] | torch.Tensor:
"""Calculate embeddings for a list of utterances.

Args:
utterances: List of input texts to calculate embeddings for.
task_type: Type of task for which embeddings are calculated.
return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.

Returns:
A numpy array or PyTorch tensor of embeddings.
"""
...

@abstractmethod
def similarity(
self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
) -> npt.NDArray[np.float32]:
"""Calculate similarity between two sets of embeddings.

Args:
embeddings1: First set of embeddings (size n).
embeddings2: Second set of embeddings (size m).

Returns:
A numpy array of similarities (size n x m).
"""
...

@abstractmethod
def get_hash(self) -> int:
"""Compute a hash value for the backend configuration and model state.

Returns:
The hash value of the backend.
"""
...

@abstractmethod
def dump(self, path: Path) -> None:
"""Save the backend state to disk.

Args:
path: Path to the directory where the backend will be saved.
"""
...

@classmethod
@abstractmethod
def load(cls, path: Path) -> "BaseEmbeddingBackend":
"""Load the backend state from disk.

Args:
path: Path to the directory where the backend is stored.

Returns:
Loaded backend instance.
"""
...
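
To make the abstract contract concrete, here is a minimal hypothetical backend that satisfies `BaseEmbeddingBackend`. The bag-of-characters embedding, cosine similarity, and config (de)serialization are placeholders written under the assumption that `EmbedderConfig` remains a plain pydantic model; they do not reflect the real SentenceTransformer or OpenAI backends added in this PR:

```python
from pathlib import Path

import numpy as np
import numpy.typing as npt
import torch

from autointent._wrappers.embedder import BaseEmbeddingBackend
from autointent.configs import EmbedderConfig, TaskTypeEnum


class BagOfCharsBackend(BaseEmbeddingBackend):
    """Toy backend: fixed-size bag-of-characters vectors, for illustration only."""

    def __init__(self, config: EmbedderConfig) -> None:
        self._config = config
        self._dim = 64

    def clear_ram(self) -> None:
        pass  # nothing cached beyond the config

    def embed(
        self,
        utterances: list[str],
        task_type: TaskTypeEnum | None = None,
        return_tensors: bool = False,
    ) -> npt.NDArray[np.float32] | torch.Tensor:
        vectors = np.zeros((len(utterances), self._dim), dtype=np.float32)
        for i, text in enumerate(utterances):
            for char in text:
                vectors[i, ord(char) % self._dim] += 1.0
        return torch.from_numpy(vectors) if return_tensors else vectors

    def similarity(
        self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
    ) -> npt.NDArray[np.float32]:
        # Cosine similarity as a placeholder (returns an n x m matrix).
        a = embeddings1 / (np.linalg.norm(embeddings1, axis=1, keepdims=True) + 1e-12)
        b = embeddings2 / (np.linalg.norm(embeddings2, axis=1, keepdims=True) + 1e-12)
        return (a @ b.T).astype(np.float32)

    def get_hash(self) -> int:
        # Placeholder: derive the hash from the serialized config only.
        return hash(self._config.model_dump_json())

    def dump(self, path: Path) -> None:
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(self._config.model_dump_json())

    @classmethod
    def load(cls, path: Path) -> "BagOfCharsBackend":
        config = EmbedderConfig.model_validate_json((path / "config.json").read_text())
        return cls(config)
```
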