From 3e3f2aa17d1f30cbef4363a91222af589f0f36b0 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 10 Dec 2024 16:38:06 +0100 Subject: [PATCH] deduplicate annotations in pipeline --- src/pytorch_ie/pipeline.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py index e80d2658..c872c5f8 100644 --- a/src/pytorch_ie/pipeline.py +++ b/src/pytorch_ie/pipeline.py @@ -21,6 +21,7 @@ TaskModule, TaskOutput, ) +from pytorch_ie.utils.document import deduplicate_annotations logger = logging.getLogger(__name__) @@ -243,6 +244,7 @@ def postprocess( self, model_inputs: Sequence[TaskEncoding], model_outputs: Sequence[TaskOutput], + deduplicate_annotations: bool = False, **postprocess_parameters, ) -> Sequence[Document]: """ @@ -250,11 +252,14 @@ def postprocess( something more friendly. Generally it will output a list of documents. """ # This creates annotations from the model outputs and attaches them to the correct documents. - return self.taskmodule.decode( + result = self.taskmodule.decode( task_encodings=model_inputs, task_outputs=model_outputs, **postprocess_parameters, ) + if deduplicate_annotations: + result = [document.deduplicate_annotations() for document in result] + return result def get_inference_context(self): inference_context = ( @@ -308,12 +313,6 @@ def __call__( postprocess_params, ) = self._sanitize_parameters(**kwargs) - in_place: bool = postprocess_params.get("inplace", True) - if in_place and not isinstance(documents, (MutableSequence, Document)): - raise InplaceNotSupportedException( - "Immutable sequences of Documents (such as Datasets) can't be modified in place. Please set inplace=False." - ) - if "TOKENIZERS_PARALLELISM" not in os.environ: logger.info( "Disabling tokenizer parallelism, we're using DataLoader multithreading already" @@ -326,6 +325,16 @@ def __call__( forward_params = {**self._forward_params, **forward_params} postprocess_params = {**self._postprocess_params, **postprocess_params} + in_place: bool = postprocess_params.get("inplace", True) + if in_place and not isinstance(documents, (MutableSequence, Document)): + raise InplaceNotSupportedException( + "Immutable sequences of Documents (such as Datasets) can't be modified in place. Please set inplace=False." + ) + if postprocess_params.get("deduplicate_annotations", False) and in_place: + raise ValueError( + "Deduplicating annotations requires inplace=False. Please set inplace=False." + ) + single_document = False if isinstance(documents, Document): single_document = True