Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 90 additions & 97 deletions src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,74 @@
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple
from typing import Any, Tuple

from pytorch_ie.core.document import Annotation


def _post_init_single_label(self):
if not isinstance(self.label, str):
raise ValueError("label must be a single string.")
@dataclass(frozen=True)
class WithPostInit:
"""Base class for annotations that require post-initialization checks."""

if not isinstance(self.score, float):
raise ValueError("score must be a single float.")
def __post_init__(self) -> None:
pass


def _post_init_multi_label(self):
if self.score is None:
score = tuple([1.0] * len(self.label))
object.__setattr__(self, "score", score)
@dataclass(frozen=True)
class WithLabel(WithPostInit):
label: str
score: float = field(default=1.0, compare=False)

if not isinstance(self.label, tuple):
object.__setattr__(self, "label", tuple(self.label))
def __post_init__(self) -> None:
if not isinstance(self.label, str):
raise ValueError("label must be a single string.")

if not isinstance(self.score, tuple):
object.__setattr__(self, "score", tuple(self.score))
if not isinstance(self.score, float):
raise ValueError("score must be a single float.")

if len(self.label) != len(self.score):
raise ValueError(
f"Number of labels ({len(self.label)}) and scores ({len(self.score)}) must be equal."
)
super().__post_init__()


def _post_init_multi_span(self):
if isinstance(self.slices, list):
object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices))
@dataclass(frozen=True)
class WithMultiLabel(WithPostInit):
label: Tuple[str, ...]
score: Tuple[float, ...] = field(default=(), compare=False)

def __post_init__(self) -> None:
if not isinstance(self.label, tuple):
object.__setattr__(self, "label", tuple(self.label))

if len(self.score) == 0:
if len(self.label) == 0:
raise ValueError("label and score cannot be empty.")
score = tuple([1.0] * len(self.label))
object.__setattr__(self, "score", score)
else:
if not isinstance(self.score, tuple):
object.__setattr__(self, "score", tuple(self.score))

def _post_init_arguments_and_roles(self):
if len(self.arguments) != len(self.roles):
raise ValueError(
f"Number of arguments ({len(self.arguments)}) and roles ({len(self.roles)}) must be equal"
)
if not isinstance(self.arguments, tuple):
object.__setattr__(self, "arguments", tuple(self.arguments))
if not isinstance(self.roles, tuple):
object.__setattr__(self, "roles", tuple(self.roles))
if len(self.label) != len(self.score):
raise ValueError(
f"Number of labels ({len(self.label)}) and scores "
f"({len(self.score)}) must be equal."
)

super().__post_init__()

@dataclass(eq=True, frozen=True)
class Label(Annotation):
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
_post_init_single_label(self)
@dataclass(frozen=True)
class Label(WithLabel, Annotation):

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
class MultiLabel(Annotation):
label: Tuple[str, ...]
score: Optional[Tuple[float, ...]] = field(default=None, compare=False)

def __post_init__(self) -> None:
_post_init_multi_label(self)
@dataclass(frozen=True)
class MultiLabel(WithMultiLabel, Annotation):

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
@dataclass(frozen=True)
class Span(Annotation):
start: int
end: int
Expand All @@ -86,41 +85,35 @@ def resolve(self) -> Any:
raise ValueError(f"{self} is not attached to a target.")


@dataclass(eq=True, frozen=True)
class LabeledSpan(Span):
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
_post_init_single_label(self)
@dataclass(frozen=True)
class LabeledSpan(WithLabel, Span):

def resolve(self) -> Any:
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class MultiLabeledSpan(Span):
label: Tuple[str, ...]
score: Optional[Tuple[float, ...]] = field(default=None, compare=False)

def __post_init__(self) -> None:
_post_init_multi_label(self)
@dataclass(frozen=True)
class MultiLabeledSpan(WithMultiLabel, Span):

def resolve(self) -> Any:
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class MultiSpan(Annotation):
@dataclass(frozen=True)
class MultiSpan(WithPostInit, Annotation):
slices: Tuple[Tuple[int, int], ...]

def __post_init__(self) -> None:
_post_init_multi_span(self)
if isinstance(self.slices, list):
object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices))

super().__post_init__()

def __str__(self) -> str:
if not self.is_attached:
return super().__str__()
return str(tuple(self.target[start:end] for start, end in self.slices))
else:
return str(self.resolve())

def resolve(self) -> Any:
if self.is_attached:
Expand All @@ -129,60 +122,60 @@ def resolve(self) -> Any:
raise ValueError(f"{self} is not attached to a target.")


@dataclass(eq=True, frozen=True)
class LabeledMultiSpan(MultiSpan):
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
super().__post_init__()
_post_init_single_label(self)
@dataclass(frozen=True)
class LabeledMultiSpan(WithLabel, MultiSpan):

def resolve(self) -> Any:
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class BinaryRelation(Annotation):
@dataclass(frozen=True)
class AnnotationWithHeadAndTail(Annotation):
head: Annotation
tail: Annotation
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.label, (self.head.resolve(), self.tail.resolve())
return self.head.resolve(), self.tail.resolve()


@dataclass(eq=True, frozen=True)
class MultiLabeledBinaryRelation(Annotation):
head: Annotation
tail: Annotation
label: Tuple[str, ...]
score: Optional[Tuple[float, ...]] = field(default=None, compare=False)
@dataclass(frozen=True)
class BinaryRelation(WithLabel, AnnotationWithHeadAndTail):

def resolve(self) -> Any:
return self.label, super().resolve()

def __post_init__(self) -> None:
_post_init_multi_label(self)

@dataclass(frozen=True)
class MultiLabeledBinaryRelation(WithMultiLabel, AnnotationWithHeadAndTail):

def resolve(self) -> Any:
return self.label, (self.head.resolve(), self.tail.resolve())
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class NaryRelation(Annotation):
@dataclass(frozen=True)
class AnnotationWithArgumentsAndRoles(WithPostInit, Annotation):
arguments: Tuple[Annotation, ...]
roles: Tuple[str, ...]
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
_post_init_arguments_and_roles(self)
_post_init_single_label(self)
if len(self.arguments) != len(self.roles):
raise ValueError(
f"Number of arguments ({len(self.arguments)}) and roles "
f"({len(self.roles)}) must be equal"
)
if not isinstance(self.arguments, tuple):
object.__setattr__(self, "arguments", tuple(self.arguments))
if not isinstance(self.roles, tuple):
object.__setattr__(self, "roles", tuple(self.roles))

super().__post_init__()

def resolve(self) -> Any:
return tuple((role, arg.resolve()) for role, arg in zip(self.roles, self.arguments))


@dataclass(frozen=True)
class NaryRelation(WithLabel, AnnotationWithArgumentsAndRoles):

def resolve(self) -> Any:
return (
self.label,
tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)),
)
return self.label, super().resolve()
Loading