Skip to content

Commit 481157c

Browse files
authored
fix: use documents and annotations directly from pie_documents (#499)
that was missed in #476. This also moves the helper function `_config_to_str` to the tests package root and fixes an import bug in an *unused* test fixture in tests/taskmodules/test_simple_transformer_text_classification.py.
1 parent c156b3a commit 481157c

12 files changed

+38
-48
lines changed

src/pytorch_ie/metrics/statistics/token_count_collector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
from typing import Any, Dict, Optional, Type, Union
33

44
from pie_core import Document, DocumentStatistic
5+
from pie_documents.documents import TextBasedDocument
56
from transformers import AutoTokenizer, PreTrainedTokenizer
67

7-
from pytorch_ie.documents import TextBasedDocument
8-
98
logger = logging.getLogger(__name__)
109

1110

src/pytorch_ie/taskmodules/simple_transformer_text_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
import numpy as np
1414
import torch
1515
from pie_core import TaskEncoding, TaskModule
16+
from pie_documents.annotations import Label
17+
from pie_documents.documents import TextDocumentWithLabel
1618
from transformers import AutoTokenizer
1719
from transformers.file_utils import PaddingStrategy
1820
from transformers.tokenization_utils_base import TruncationStrategy
1921
from typing_extensions import TypeAlias
2022

21-
from pytorch_ie.annotations import Label
22-
from pytorch_ie.documents import TextDocumentWithLabel
2323
from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType
2424

2525
logger = logging.getLogger(__name__)

src/pytorch_ie/taskmodules/transformer_re_text_classification.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
import numpy as np
2626
import torch
2727
from pie_core import AnnotationLayer, Document, TaskEncoding, TaskModule
28+
from pie_documents.documents import (
29+
TextDocument,
30+
TextDocumentWithLabeledSpansAndBinaryRelations,
31+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
32+
)
2833
from transformers import AutoTokenizer
2934
from transformers.file_utils import PaddingStrategy
3035
from transformers.tokenization_utils_base import TruncationStrategy
@@ -37,11 +42,6 @@
3742
NaryRelation,
3843
Span,
3944
)
40-
from pytorch_ie.documents import (
41-
TextDocument,
42-
TextDocumentWithLabeledSpansAndBinaryRelations,
43-
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
44-
)
4545
from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType
4646
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
4747
from pytorch_ie.utils.span import get_token_slice, is_contained_in

src/pytorch_ie/taskmodules/transformer_seq2seq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union
1313

1414
from pie_core import Annotation, TaskEncoding, TaskModule
15+
from pie_documents.annotations import BinaryRelation, LabeledSpan
16+
from pie_documents.documents import TextDocument, TextDocumentWithLabeledSpansAndBinaryRelations
1517
from transformers import AutoTokenizer
1618
from transformers.file_utils import PaddingStrategy
1719
from transformers.tokenization_utils_base import TruncationStrategy
1820
from typing_extensions import TypeAlias
1921

20-
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
21-
from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpansAndBinaryRelations
2222
from pytorch_ie.models.transformer_seq2seq import ModelOutputType, ModelStepInputType
2323

2424
InputEncodingType: TypeAlias = Dict[str, Sequence[int]]

src/pytorch_ie/taskmodules/transformer_span_classification.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414
import torch
1515
import torch.nn.functional as F
1616
from pie_core import TaskEncoding, TaskModule
17-
from transformers import AutoTokenizer
18-
from transformers.file_utils import PaddingStrategy
19-
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
20-
from typing_extensions import TypeAlias
21-
22-
from pytorch_ie.annotations import LabeledSpan, MultiLabeledSpan, Span
23-
from pytorch_ie.documents import (
17+
from pie_documents.annotations import LabeledSpan, MultiLabeledSpan, Span
18+
from pie_documents.documents import (
2419
TextDocument,
2520
TextDocumentWithLabeledSpans,
2621
TextDocumentWithLabeledSpansAndLabeledPartitions,
2722
TextDocumentWithLabeledSpansAndSentences,
2823
)
24+
from transformers import AutoTokenizer
25+
from transformers.file_utils import PaddingStrategy
26+
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
27+
from typing_extensions import TypeAlias
28+
2929
from pytorch_ie.models.transformer_span_classification import ModelOutputType, ModelStepInputType
3030

3131
InputEncodingType: TypeAlias = BatchEncoding

src/pytorch_ie/taskmodules/transformer_text_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import numpy as np
2626
import torch
2727
from pie_core import TaskEncoding, TaskModule
28+
from pie_documents.annotations import Label, MultiLabel
29+
from pie_documents.documents import TextDocument, TextDocumentWithLabel, TextDocumentWithMultiLabel
2830
from transformers import AutoTokenizer
2931
from transformers.file_utils import PaddingStrategy
3032
from transformers.tokenization_utils_base import TruncationStrategy
3133
from typing_extensions import TypeAlias
3234

33-
from pytorch_ie.annotations import Label, MultiLabel
34-
from pytorch_ie.documents import TextDocument, TextDocumentWithLabel, TextDocumentWithMultiLabel
3535
from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType
3636

3737
logger = logging.getLogger(__name__)

src/pytorch_ie/taskmodules/transformer_token_classification.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
import torch
1515
import torch.nn.functional as F
1616
from pie_core import TaskEncoding, TaskModule
17+
from pie_documents.annotations import LabeledSpan, Span
18+
from pie_documents.documents import (
19+
TextDocument,
20+
TextDocumentWithLabeledSpans,
21+
TextDocumentWithLabeledSpansAndLabeledPartitions,
22+
)
1723
from transformers import AutoTokenizer
1824
from transformers.file_utils import PaddingStrategy
1925
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
2026
from typing_extensions import TypeAlias
2127

22-
from pytorch_ie.annotations import LabeledSpan, Span
23-
from pytorch_ie.documents import (
24-
TextDocument,
25-
TextDocumentWithLabeledSpans,
26-
TextDocumentWithLabeledSpansAndLabeledPartitions,
27-
)
2828
from pytorch_ie.models.transformer_token_classification import ModelOutputType, ModelStepInputType
2929
from pytorch_ie.utils.span import (
3030
bio_tags_to_spans,

src/pytorch_ie/utils/document.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
from pie_core.document import BaseAnnotationList
1010
from pie_documents.annotations import Span
1111
from pie_documents.document.processing import text_based_document_to_token_based
12-
from pie_documents.documents import TextBasedDocument, TokenBasedDocument
12+
from pie_documents.documents import TextBasedDocument, TokenBasedDocument, WithMetadata
1313
from transformers import PreTrainedTokenizer
1414

15-
from pytorch_ie.documents import WithMetadata
16-
1715
logger = logging.getLogger(__name__)
1816

1917

tests/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from pathlib import Path
2+
from typing import Any, Dict
23

34
TESTS_ROOT = Path(__file__).parent
45
FIXTURES_ROOT = TESTS_ROOT / "fixtures"
6+
7+
8+
def _config_to_str(cfg: Dict[str, Any]) -> str:
9+
result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)])
10+
return result

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import pytest
55
from pie_core import AnnotationLayer, annotation_field
6+
from pie_documents.annotations import BinaryRelation, LabeledSpan, Span
7+
from pie_documents.documents import TextDocument
68

7-
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
8-
from pytorch_ie.documents import TextDocument
99
from tests import FIXTURES_ROOT
1010

1111

0 commit comments

Comments
 (0)