Commit 39a25be

add taskmodules and models from pie-modules (#498)
This PR implements #459, i.e., it adds the models and taskmodules originally implemented in [pie-modules](https://github.com/ArneBinder/pie-modules) (except for QA and span-pair based RE, see potential follow-ups below).

- added models:
  - `SequenceClassificationModelWithPooler`
  - `SequencePairSimilarityModelWithPooler`
  - `SimpleTokenClassificationModel`
  - `SimpleGenerativeModel`
  - `SimpleSequenceClassificationModel`
  - `TokenClassificationModelWithSeq2SeqEncoderAndCrf`
- added taskmodules:
  - `RETextClassificationWithIndicesTaskModule`
  - `TextToTextTaskModule`
  - `LabeledSpanExtractionByTokenClassificationTaskModule`
  - `PointerNetworkTaskModuleForEnd2EndRE`
  - `CrossTextBinaryCorefTaskModule`

**IMPORTANT: This restricts the version of transformers to `>=4.35.0,<4.37.0`! So, this is breaking.**

requires:
- #482
- #499

Additional changes:
- add `tabulate` and `pytorch-crf` to dev dependencies
- set dependency `torchmetrics[text] >=1.5, <2` to solve conflicts with `nltk`: the `text` extra loads the required additional dependencies, and `>=1.5` ensures that no deprecated nltk models are loaded. Note that we already use the modern nltk models in [`pie_documents.document.processing.NltkSentenceSplitter`](https://github.com/ArneBinder/pie-documents/blob/main/src/pie_documents/document/processing/sentence_splitter.py).
- add `SpanNotAlignedWithTokenException` and `get_aligned_token_span` to `utils.document`
- add `RequiresMaxInputLength` and `RequiresTaskmoduleConfig` to `models.interface`

potential follow-ups:
- [ ] add remaining models (SimpleExtractiveQuestionAnsweringModel and SpanTupleClassificationModel)
- [ ] add remaining taskmodules (ExtractiveQuestionAnsweringTaskModule and RESpanPairClassificationTaskModule)

---------

Co-authored-by: Danylo Mysak <[email protected]>
1 parent a6bb91d commit 39a25be

77 files changed: +22270 -167 lines changed

codecov.yml

Lines changed: 2 additions & 2 deletions

@@ -2,8 +2,8 @@ coverage:
   status:
     project:
       default:
-        target: 68% # TODO: switch back to auto
+        target: 80% # TODO: switch back to auto
         threshold: 1% # the leniency in hitting the target
     patch:
       default:
-        target: 75% # TODO: switch back to 100%
+        target: 85% # TODO: switch back to 100%
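
To make the interaction of `target` and `threshold` in the hunk above concrete, here is a minimal sketch. It assumes that the threshold acts as an allowed shortfall below the target, which is how the inline comment ("the leniency in hitting the target") reads; `status_passes` is a hypothetical helper for illustration, not part of Codecov or of this repository.

```python
def status_passes(coverage: float, target: float, threshold: float = 0.0) -> bool:
    """Return True if coverage reaches the target minus the allowed shortfall.

    All values are percentages. This is a simplified model (assumption),
    not Codecov's actual implementation.
    """
    return coverage >= target - threshold


# Project status after this commit: target 80%, threshold 1%.
print(status_passes(79.4, target=80.0, threshold=1.0))
print(status_passes(78.9, target=80.0, threshold=1.0))

# Patch status after this commit: target 85%, no threshold in the hunk shown.
print(status_passes(84.0, target=85.0))
```

Under this reading, raising the project target from 68% to 80% means a total coverage below 79% would fail the project status.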

poetry.lock

Lines changed: 396 additions & 158 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 14 additions & 3 deletions

@@ -21,8 +21,11 @@ dependencies = [
     "pie-documents >=0.1.0, <0.2.0",
     "torch >=1.10",
     "pytorch-lightning >=2, <3",
-    "torchmetrics >1, <2",
-    "transformers >=4.18, <5",
+    "torchmetrics[text] >=1.5, <2",
+    # >=4.35 because of BartModelWithDecoderPositionIds,
+    # <4.37 because of generation config created from model config in BartAsPointerNetwork
+    # TODO: check the upper bound, since this should be already fixed (https://github.com/ArneBinder/pie-modules/pull/205)
+    "transformers >=4.35.0,<4.37.0",
 ]
 
 [project.urls]
@@ -51,14 +54,22 @@ classifiers = [
 optional = true
 
 [tool.poetry.group.dev.dependencies]
-# testing (with coverage and run-time type checking)
+
+# lazily imported packges
+# for taskmodule tests with collect statistics
+tabulate = "^0.9"
+# for TokenClassificationModelWithSeq2SeqEncoderAndCrf
+pytorch-crf = ">=0.7.2"
+
+# testing utilities (with coverage and run-time type checking)
 pytest = ">=6.2.5"
 pytest-xdist = "^3.8.0"
 pytest-cov = "^6.2.1"
 typeguard = ">=2.13.3"
 sh = "^2"
 types-requests = "^2.27.7"
 python-dotenv = "^0.20.0"
+
 # code quality and static type checking (via pre-commit)
 pre-commit = ">=2.16.0"
 pre-commit-hooks = ">=4.1.0"
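
Since the new `transformers >=4.35.0,<4.37.0` pin is the breaking part of this commit, a quick way to check an environment against it may be useful. The sketch below uses only the standard library and compares numeric release components; `parse_release` and `within_transformers_pin` are hypothetical helpers, and pre-release suffixes are ignored, so this is a simplification of full PEP 440 matching.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(version_string: str) -> tuple:
    """Extract the leading numeric release, padded to three components."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"cannot parse version: {version_string!r}")
    release = tuple(int(part) for part in match.group().split("."))
    # Pad e.g. (4, 35) to (4, 35, 0) so tuple comparison behaves as expected.
    return release + (0,) * max(0, 3 - len(release))


def within_transformers_pin(version_string: str) -> bool:
    """Check the range >=4.35.0,<4.37.0 on release numbers only (simplified)."""
    return (4, 35, 0) <= parse_release(version_string) < (4, 37, 0)


try:
    installed = version("transformers")
except PackageNotFoundError:
    print("transformers is not installed in this environment")
else:
    print(installed, "satisfies pin:", within_transformers_pin(installed))
```

For exact specifier semantics, the `packaging` library's `SpecifierSet` is the more robust choice; the tuple comparison here is only meant to illustrate the bounds.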

src/pytorch_ie/models/__init__.py

Lines changed: 16 additions & 0 deletions

@@ -1,3 +1,13 @@
+from pytorch_ie.models.sequence_classification_with_pooler import (
+    SequenceClassificationModelWithPooler,
+    SequencePairSimilarityModelWithPooler,
+)
+from pytorch_ie.models.simple_generative import SimpleGenerativeModel
+from pytorch_ie.models.simple_sequence_classification import SimpleSequenceClassificationModel
+from pytorch_ie.models.simple_token_classification import SimpleTokenClassificationModel
+from pytorch_ie.models.token_classification_with_seq2seq_encoder_and_crf import (
+    TokenClassificationModelWithSeq2SeqEncoderAndCrf,
+)
 from pytorch_ie.models.transformer_seq2seq import TransformerSeq2SeqModel
 from pytorch_ie.models.transformer_span_classification import TransformerSpanClassificationModel
 from pytorch_ie.models.transformer_text_classification import TransformerTextClassificationModel
@@ -8,4 +18,10 @@
     "TransformerSpanClassificationModel",
     "TransformerTextClassificationModel",
     "TransformerTokenClassificationModel",
+    "SequenceClassificationModelWithPooler",
+    "SequencePairSimilarityModelWithPooler",
+    "SimpleTokenClassificationModel",
+    "SimpleGenerativeModel",
+    "SimpleSequenceClassificationModel",
+    "TokenClassificationModelWithSeq2SeqEncoderAndCrf",
 ]
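
The hunk above adds six names both as imports and as `__all__` entries. A small sanity check along the following lines could confirm that an installed version actually exposes them; `missing_exports` is a hypothetical helper written for this illustration, and the check prints a message instead of failing when `pytorch_ie` is not importable.

```python
import importlib

# Names added to pytorch_ie.models.__all__ by this commit (taken from the diff).
NEW_MODEL_EXPORTS = [
    "SequenceClassificationModelWithPooler",
    "SequencePairSimilarityModelWithPooler",
    "SimpleTokenClassificationModel",
    "SimpleGenerativeModel",
    "SimpleSequenceClassificationModel",
    "TokenClassificationModelWithSeq2SeqEncoderAndCrf",
]


def missing_exports(module_name: str, names: list) -> list:
    """Return the names that are absent from the module or from its __all__."""
    module = importlib.import_module(module_name)
    exported = set(getattr(module, "__all__", []))
    return [n for n in names if n not in exported or not hasattr(module, n)]


try:
    print("missing:", missing_exports("pytorch_ie.models", NEW_MODEL_EXPORTS))
except ImportError as exc:
    print(f"pytorch_ie.models is not importable here: {exc}")
```

An empty list would indicate that all six classes are importable directly from `pytorch_ie.models`.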
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .bart_as_pointer_network import BartAsPointerNetwork
+from .bart_with_decoder_position_ids import BartModelWithDecoderPositionIds
