From e9d9d95a3106eab6ee32559f9a8afd67075315cb Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 6 Apr 2025 14:15:51 +0200 Subject: [PATCH 1/2] refactor annotations --- src/pytorch_ie/annotations.py | 189 ++++++++++++++++------------------ 1 file changed, 91 insertions(+), 98 deletions(-) diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index ac023e9d..cf023076 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -1,75 +1,74 @@ from dataclasses import dataclass, field -from typing import Any, Optional, Tuple +from typing import Any, Tuple -from pytorch_ie.core.document import Annotation +from pie_core.document import Annotation -def _post_init_single_label(self): - if not isinstance(self.label, str): - raise ValueError("label must be a single string.") +@dataclass(frozen=True) +class WithPostInit: + """Base class for annotations that require post-initialization checks.""" - if not isinstance(self.score, float): - raise ValueError("score must be a single float.") + def __post_init__(self) -> None: + pass -def _post_init_multi_label(self): - if self.score is None: - score = tuple([1.0] * len(self.label)) - object.__setattr__(self, "score", score) +@dataclass(frozen=True) +class WithLabel(WithPostInit): + label: str + score: float = field(default=1.0, compare=False) - if not isinstance(self.label, tuple): - object.__setattr__(self, "label", tuple(self.label)) + def __post_init__(self) -> None: + if not isinstance(self.label, str): + raise ValueError("label must be a single string.") - if not isinstance(self.score, tuple): - object.__setattr__(self, "score", tuple(self.score)) + if not isinstance(self.score, float): + raise ValueError("score must be a single float.") - if len(self.label) != len(self.score): - raise ValueError( - f"Number of labels ({len(self.label)}) and scores ({len(self.score)}) must be equal." - ) + super().__post_init__() -def _post_init_multi_span(self): - if isinstance(self.slices, list): - object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices)) +@dataclass(frozen=True) +class WithMultiLabel(WithPostInit): + label: Tuple[str, ...] + score: Tuple[float, ...] = field(default=(), compare=False) + def __post_init__(self) -> None: + if not isinstance(self.label, tuple): + object.__setattr__(self, "label", tuple(self.label)) + + if len(self.score) == 0: + if len(self.label) == 0: + raise ValueError("label and score cannot be empty.") + score = tuple([1.0] * len(self.label)) + object.__setattr__(self, "score", score) + else: + if not isinstance(self.score, tuple): + object.__setattr__(self, "score", tuple(self.score)) -def _post_init_arguments_and_roles(self): - if len(self.arguments) != len(self.roles): - raise ValueError( - f"Number of arguments ({len(self.arguments)}) and roles ({len(self.roles)}) must be equal" - ) - if not isinstance(self.arguments, tuple): - object.__setattr__(self, "arguments", tuple(self.arguments)) - if not isinstance(self.roles, tuple): - object.__setattr__(self, "roles", tuple(self.roles)) + if len(self.label) != len(self.score): + raise ValueError( + f"Number of labels ({len(self.label)}) and scores " + f"({len(self.score)}) must be equal." + ) + super().__post_init__() -@dataclass(eq=True, frozen=True) -class Label(Annotation): - label: str - score: float = field(default=1.0, compare=False) - def __post_init__(self) -> None: - _post_init_single_label(self) +@dataclass(frozen=True) +class Label(WithLabel, Annotation): def resolve(self) -> Any: return self.label -@dataclass(eq=True, frozen=True) -class MultiLabel(Annotation): - label: Tuple[str, ...] - score: Optional[Tuple[float, ...]] = field(default=None, compare=False) - - def __post_init__(self) -> None: - _post_init_multi_label(self) +@dataclass(frozen=True) +class MultiLabel(WithMultiLabel, Annotation): def resolve(self) -> Any: return self.label -@dataclass(eq=True, frozen=True) +@dataclass(frozen=True) class Span(Annotation): start: int end: int @@ -86,41 +85,35 @@ def resolve(self) -> Any: raise ValueError(f"{self} is not attached to a target.") -@dataclass(eq=True, frozen=True) -class LabeledSpan(Span): - label: str - score: float = field(default=1.0, compare=False) - - def __post_init__(self) -> None: - _post_init_single_label(self) +@dataclass(frozen=True) +class LabeledSpan(WithLabel, Span): def resolve(self) -> Any: return self.label, super().resolve() -@dataclass(eq=True, frozen=True) -class MultiLabeledSpan(Span): - label: Tuple[str, ...] - score: Optional[Tuple[float, ...]] = field(default=None, compare=False) - - def __post_init__(self) -> None: - _post_init_multi_label(self) +@dataclass(frozen=True) +class MultiLabeledSpan(WithMultiLabel, Span): def resolve(self) -> Any: return self.label, super().resolve() -@dataclass(eq=True, frozen=True) -class MultiSpan(Annotation): +@dataclass(frozen=True) +class MultiSpan(WithPostInit, Annotation): slices: Tuple[Tuple[int, int], ...] def __post_init__(self) -> None: - _post_init_multi_span(self) + if isinstance(self.slices, list): + object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices)) + + super().__post_init__() def __str__(self) -> str: if not self.is_attached: return super().__str__() - return str(tuple(self.target[start:end] for start, end in self.slices)) + else: + return str(self.resolve()) def resolve(self) -> Any: if self.is_attached: @@ -129,60 +122,60 @@ def resolve(self) -> Any: raise ValueError(f"{self} is not attached to a target.") -@dataclass(eq=True, frozen=True) -class LabeledMultiSpan(MultiSpan): - label: str - score: float = field(default=1.0, compare=False) - - def __post_init__(self) -> None: - super().__post_init__() - _post_init_single_label(self) +@dataclass(frozen=True) +class LabeledMultiSpan(WithLabel, MultiSpan): def resolve(self) -> Any: return self.label, super().resolve() -@dataclass(eq=True, frozen=True) -class BinaryRelation(Annotation): +@dataclass(frozen=True) +class AnnotationWithHeadAndTail(Annotation): head: Annotation tail: Annotation - label: str - score: float = field(default=1.0, compare=False) - - def __post_init__(self) -> None: - _post_init_single_label(self) def resolve(self) -> Any: - return self.label, (self.head.resolve(), self.tail.resolve()) + return self.head.resolve(), self.tail.resolve() -@dataclass(eq=True, frozen=True) -class MultiLabeledBinaryRelation(Annotation): - head: Annotation - tail: Annotation - label: Tuple[str, ...] - score: Optional[Tuple[float, ...]] = field(default=None, compare=False) +@dataclass(frozen=True) +class BinaryRelation(WithLabel, AnnotationWithHeadAndTail): + + def resolve(self) -> Any: + return self.label, super().resolve() - def __post_init__(self) -> None: - _post_init_multi_label(self) + +@dataclass(frozen=True) +class MultiLabeledBinaryRelation(WithMultiLabel, AnnotationWithHeadAndTail): def resolve(self) -> Any: - return self.label, (self.head.resolve(), self.tail.resolve()) + return self.label, super().resolve() -@dataclass(eq=True, frozen=True) -class NaryRelation(Annotation): +@dataclass(frozen=True) +class AnnotationWithArgumentsAndRoles(WithPostInit, Annotation): arguments: Tuple[Annotation, ...] roles: Tuple[str, ...] - label: str - score: float = field(default=1.0, compare=False) def __post_init__(self) -> None: - _post_init_arguments_and_roles(self) - _post_init_single_label(self) + if len(self.arguments) != len(self.roles): + raise ValueError( + f"Number of arguments ({len(self.arguments)}) and roles " + f"({len(self.roles)}) must be equal" + ) + if not isinstance(self.arguments, tuple): + object.__setattr__(self, "arguments", tuple(self.arguments)) + if not isinstance(self.roles, tuple): + object.__setattr__(self, "roles", tuple(self.roles)) + + super().__post_init__() + + def resolve(self) -> Any: + return tuple((role, arg.resolve()) for role, arg in zip(self.roles, self.arguments)) + + +@dataclass(frozen=True) +class NaryRelation(WithLabel, AnnotationWithArgumentsAndRoles): def resolve(self) -> Any: - return ( - self.label, - tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)), - ) + return self.label, super().resolve() From 031dd1da229591fd4d5a0109b56e414bcd78c183 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 6 Apr 2025 19:25:53 +0200 Subject: [PATCH 2/2] fix import --- src/pytorch_ie/annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index cf023076..5bda9726 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Tuple -from pie_core.document import Annotation +from pytorch_ie.core.document import Annotation @dataclass(frozen=True)