diff --git a/.fernignore b/.fernignore index c10518a6e..a1e62a28d 100644 --- a/.fernignore +++ b/.fernignore @@ -12,6 +12,7 @@ src/label_studio_sdk/client.py src/label_studio_sdk/tasks/client_ext.py src/label_studio_sdk/projects/client_ext.py src/label_studio_sdk/projects/exports/client_ext.py +src/label_studio_sdk/projects/stats/client_ext.py src/label_studio_sdk/tokens/client_ext.py src/label_studio_sdk/core/client_wrapper.py diff --git a/src/label_studio_sdk/projects/client_ext.py b/src/label_studio_sdk/projects/client_ext.py index 37c197d5e..702c7c642 100644 --- a/src/label_studio_sdk/projects/client_ext.py +++ b/src/label_studio_sdk/projects/client_ext.py @@ -1,11 +1,13 @@ import typing from typing_extensions import Annotated + from .client import ProjectsClient, AsyncProjectsClient from pydantic import model_validator, validator, Field, ConfigDict from label_studio_sdk._extensions.pager_ext import SyncPagerExt, AsyncPagerExt, T from label_studio_sdk.types.project import Project from label_studio_sdk.label_interface import LabelInterface from .exports.client_ext import ExportsClientExt, AsyncExportsClientExt +from .stats.client_ext import StatsClientExt, AsyncStatsClientExt from ..core import RequestOptions @@ -21,6 +23,7 @@ class ProjectsClientExt(ProjectsClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exports = ExportsClientExt(client_wrapper=self._client_wrapper) + self.stats = StatsClientExt(client_wrapper=self._client_wrapper) def list(self, **kwargs) -> SyncPagerExt[T]: return SyncPagerExt.from_sync_pager(super().list(**kwargs)) @@ -38,6 +41,7 @@ class AsyncProjectsClientExt(AsyncProjectsClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exports = AsyncExportsClientExt(client_wrapper=self._client_wrapper) + self.stats = AsyncStatsClientExt(client_wrapper=self._client_wrapper) async def get(self, id: int, *, request_options: typing.Optional[RequestOptions] = None) -> ProjectExt: return ProjectExt(**dict(await super().get(id, request_options=request_options))) diff --git a/src/label_studio_sdk/projects/stats/client_ext.py b/src/label_studio_sdk/projects/stats/client_ext.py new file mode 100644 index 000000000..42aa8645f --- /dev/null +++ b/src/label_studio_sdk/projects/stats/client_ext.py @@ -0,0 +1,157 @@ +import typing +from json.decoder import JSONDecodeError + +from label_studio_sdk.projects.stats.client import AsyncStatsClient, StatsClient + +from ...core.api_error import ApiError +from ...core.jsonable_encoder import jsonable_encoder +from ...core.request_options import RequestOptions +from ...core.unchecked_base_model import construct_type +from .types.stats_total_agreement_response import StatsTotalAgreementResponse + +class StatsClientExt(StatsClient): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + + def total_agreement( + self, + id: int, + *, + per_label: typing.Optional[bool] = None, + request_options: typing.Optional[RequestOptions] = None, + ) -> typing.Optional[StatsTotalAgreementResponse]: + """ + Overall or per-label total agreement across the project. + + Parameters + ---------- + id : int + + per_label : typing.Optional[bool] + Return agreement per label + + request_options : typing.Optional[RequestOptions] + Request-specific configuration. + + Returns + ------- + StatsTotalAgreementResponse + Total agreement + None + No data to compute agreement + + Examples + -------- + from label_studio_sdk import LabelStudio + + client = LabelStudio( + api_key="YOUR_API_KEY", + ) + client.projects.stats.total_agreement( + id=1, + ) + """ + _response = self._client_wrapper.httpx_client.request( + f"api/projects/{jsonable_encoder(id)}/stats/total_agreement", + method="GET", + params={ + "per_label": per_label, + }, + request_options=request_options, + ) + try: + if _response.status_code == 204: + return None + if 200 <= _response.status_code < 300: + return typing.cast( + StatsTotalAgreementResponse, + construct_type( + type_=StatsTotalAgreementResponse, # type: ignore + object_=_response.json(), + ), + ) + _response_json = _response.json() + except JSONDecodeError: + raise ApiError(status_code=_response.status_code, body=_response.text) + raise ApiError(status_code=_response.status_code, body=_response_json) + + + + +class AsyncStatsClientExt(AsyncStatsClient): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + + async def total_agreement( + self, + id: int, + *, + per_label: typing.Optional[bool] = None, + request_options: typing.Optional[RequestOptions] = None, + ) -> typing.Optional[StatsTotalAgreementResponse]: + """ + Overall or per-label total agreement across the project. + + Parameters + ---------- + id : int + + per_label : typing.Optional[bool] + Return agreement per label + + request_options : typing.Optional[RequestOptions] + Request-specific configuration. + + Returns + ------- + StatsTotalAgreementResponse + Total agreement + None + No data to compute agreement + + Examples + -------- + import asyncio + + from label_studio_sdk import AsyncLabelStudio + + client = AsyncLabelStudio( + api_key="YOUR_API_KEY", + ) + + + async def main() -> None: + await client.projects.stats.total_agreement( + id=1, + ) + + + asyncio.run(main()) + """ + _response = await self._client_wrapper.httpx_client.request( + f"api/projects/{jsonable_encoder(id)}/stats/total_agreement", + method="GET", + params={ + "per_label": per_label, + }, + request_options=request_options, + ) + try: + if _response.status_code == 204: + return None + if 200 <= _response.status_code < 300: + return typing.cast( + StatsTotalAgreementResponse, + construct_type( + type_=StatsTotalAgreementResponse, # type: ignore + object_=_response.json(), + ), + ) + _response_json = _response.json() + except JSONDecodeError: + raise ApiError(status_code=_response.status_code, body=_response.text) + raise ApiError(status_code=_response.status_code, body=_response_json)