From 97fb80293b78f45701c8b3e7d5610d73b2c1ed73 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 12:44:09 +0100 Subject: [PATCH 01/21] new example --- config/config.yaml | 18 + config/data_config.json | 12 + config/experiment_config.yaml | 31 ++ docs/how-to/ConfigureExperiments.md | 208 ++++++++ docs/how-to/PromptsAndExtractionStrategies.md | 119 +++++ .../invoice_processing/components/predict.yml | 36 ++ mlops/invoice_processing/components/prep.yml | 25 + mlops/invoice_processing/components/score.yml | 34 ++ .../invoice_processing/environment/conda.yml | 26 + mlops/invoice_processing/src/__init__.py | 0 .../invoice_processing/src/mlops_pipeline.py | 284 +++++++++++ .../start_local_pipeline.py | 5 + src/invoice_processing/__init__.py | 0 .../predict_component/predict/__init__.py | 3 + .../predict/data_extraction/__init__.py | 0 .../data_extraction/config/__init__.py | 0 .../config/configuration_container.py | 24 + .../data_extraction/data_extractor_factory.py | 55 +++ .../data_extraction/extractors/__init__.py | 0 .../extractors/base_extractor.py | 25 + .../extractors/gpt_only_extractor.py | 154 ++++++ .../data_extraction/models/__init__.py | 0 .../models/extraction_response.py | 36 ++ .../data_extraction/prompts/__init__.py | 0 .../data_extraction/prompts/prompt_manager.py | 49 ++ .../templates/medical_claim_reimbursement.j2 | 14 + .../predict_component/predict/helpers.py | 33 ++ .../predict/mlflow_logger.py | 9 + .../predict_component/predict/predict.py | 246 ++++++++++ .../predict_component/predict/pyproject.toml | 25 + .../prep_component/prep/__init__.py | 3 + .../prep_component/prep/prep.py | 93 ++++ .../score_component/score/__init__.py | 3 + .../score/extraction_evaluator.py | 293 +++++++++++ .../score/matchers/__init__.py | 0 .../score/matchers/amount_exact_matcher.py | 51 ++ .../score/matchers/base_matcher.py | 17 + .../score/matchers/date_exact_matcher.py | 57 +++ .../score/matchers/levenshtein_matcher.py | 89 ++++ .../score/matchers/text_exact_matcher.py | 82 ++++ .../score_component/score/score.py | 439 +++++++++++++++++ .../score_component/score/utils.py | 114 +++++ .../data_extraction/assets/config.json | 10 + .../data_extraction/assets/mock_extractor.py | 29 ++ .../extractors/test_gpt_only_extractor.py | 202 ++++++++ .../models/test_extraction_response.py | 54 +++ .../prompts/test_prompt_manager.py | 45 ++ .../test_configuration_container.py | 29 ++ .../data_extraction/test_extractor_factory.py | 77 +++ .../predict_component/predict/test_helpers.py | 33 ++ .../predict_component/predict/test_predict.py | 107 ++++ .../test_experiment_config.yaml | 15 + .../test_extraction_evaluator.py | 119 +++++ .../score_component/test_score.py | 457 ++++++++++++++++++ .../score_component/test_utils.py | 22 + 55 files changed, 3911 insertions(+) create mode 100644 config/experiment_config.yaml create mode 100644 docs/how-to/ConfigureExperiments.md create mode 100644 docs/how-to/PromptsAndExtractionStrategies.md create mode 100644 mlops/invoice_processing/components/predict.yml create mode 100644 mlops/invoice_processing/components/prep.yml create mode 100644 mlops/invoice_processing/components/score.yml create mode 100644 mlops/invoice_processing/environment/conda.yml create mode 100644 mlops/invoice_processing/src/__init__.py create mode 100644 mlops/invoice_processing/src/mlops_pipeline.py create mode 100644 mlops/invoice_processing/start_local_pipeline.py create mode 100644 src/invoice_processing/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/config/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/extractors/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/models/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/prompts/__init__.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py create mode 100644 src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/medical_claim_reimbursement.j2 create mode 100644 src/invoice_processing/predict_component/predict/helpers.py create mode 100644 src/invoice_processing/predict_component/predict/mlflow_logger.py create mode 100644 src/invoice_processing/predict_component/predict/predict.py create mode 100644 src/invoice_processing/predict_component/predict/pyproject.toml create mode 100644 src/invoice_processing/prep_component/prep/__init__.py create mode 100644 src/invoice_processing/prep_component/prep/prep.py create mode 100644 src/invoice_processing/score_component/score/__init__.py create mode 100644 src/invoice_processing/score_component/score/extraction_evaluator.py create mode 100644 src/invoice_processing/score_component/score/matchers/__init__.py create mode 100644 src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py create mode 100644 src/invoice_processing/score_component/score/matchers/base_matcher.py create mode 100644 src/invoice_processing/score_component/score/matchers/date_exact_matcher.py create mode 100644 src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py create mode 100644 src/invoice_processing/score_component/score/matchers/text_exact_matcher.py create mode 100644 src/invoice_processing/score_component/score/score.py create mode 100644 src/invoice_processing/score_component/score/utils.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/assets/config.json create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py create mode 100644 test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py create mode 100644 test/invoice_processing/predict_component/predict/test_helpers.py create mode 100644 test/invoice_processing/predict_component/predict/test_predict.py create mode 100644 test/invoice_processing/score_component/test_experiment_config.yaml create mode 100644 test/invoice_processing/score_component/test_extraction_evaluator.py create mode 100644 test/invoice_processing/score_component/test_score.py create mode 100644 test/invoice_processing/score_component/test_utils.py diff --git a/config/config.yaml b/config/config.yaml index d4d15fba..1cc42682 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -41,6 +41,15 @@ pipeline_configs: docker_context_path: mlops/docker_taxi/environment aml_env_name: docker_taxi_env dataset_name: docker_taxi_pr_data + + invoice_processing_pr: + cluster_region: eastus + cluster_size: STANDARD_DS3_v2 + cluster_name: cpucluster + conda_path: mlops/invoice_processing/environment/conda.yml + aml_env_name: invoice_processing_env + dataset_name: invoice_processing_test + gt_name: invoice_processing_test_gt london_taxi_dev: cluster_region: eastus @@ -75,6 +84,15 @@ pipeline_configs: aml_env_name: docker_taxi_env dataset_name: docker_taxi_pr_data + invoice_processing_dev: + cluster_region: eastus + cluster_size: STANDARD_DS3_v2 + cluster_name: cpucluster + conda_path: mlops/invoice_processing/environment/conda.yml + aml_env_name: invoice_processing_env + dataset_name: validated_gt_images + gt_name: validated_gt_annotations + deployment_configs: london_taxi_batch_dev: score_file_name: score.py diff --git a/config/data_config.json b/config/data_config.json index b13e9d5f..78eefc24 100644 --- a/config/data_config.json +++ b/config/data_config.json @@ -53,6 +53,18 @@ "DATA_PATH":"mlops/docker_taxi/data", "DATASET_NAME":"docker_taxi_training", "DATASET_DESC":"this dataset is for training models" + }, + { + "DATA_PURPOSE": "test_data", + "DATA_PATH":"mlops/invoice_processing/data/raw_data", + "DATASET_NAME":"invoice_processing_test", + "DATASET_DESC":"this dataset is for pr validation only" + }, + { + "DATA_PURPOSE": "ground_truth", + "DATA_PATH":"mlops/invoice_processing/data/ground_truth", + "DATASET_NAME":"invoice_processing_test_gt", + "DATASET_DESC":"this dataset is for pr validation only" } ] } \ No newline at end of file diff --git a/config/experiment_config.yaml b/config/experiment_config.yaml new file mode 100644 index 00000000..96129f5c --- /dev/null +++ b/config/experiment_config.yaml @@ -0,0 +1,31 @@ +experiment_description: + user_name: + title: + hypothesis: + +prep_config: + samples_amount: 4 + sampling_seed: 42 + +predict_config: + strategy: gpt_only + gpt_deployment_name: gpt-4o + temperature: 0 + prompt_config: + prompt_name: medical_claim_reimbursement + line_item_instructions: complex + +score_config: + fuzzy_match_config: + field_match_threshold: 0.0 + exact_match_fields: + start_date_match: true + end_date_match: true + amount_match: true + find_best_matches_strategy: levenshtein + matchers_dict: + serviceStartDate: date_exact_match + serviceEndDate: date_exact_match + amount: amount_exact_match + description: description_levenshtein + diff --git a/docs/how-to/ConfigureExperiments.md b/docs/how-to/ConfigureExperiments.md new file mode 100644 index 00000000..50e09b94 --- /dev/null +++ b/docs/how-to/ConfigureExperiments.md @@ -0,0 +1,208 @@ +# Configure Repository + +This document describes how to configure the repository when running experiments. + +## .env File + +Before running any experiments, the user must create an empty file (`.env`) and copy [.env_sample](../../.env.sample) into the created file. This file contains the environment variables required to connect to AML as well as Azure OpenAI credentials. The file will contain secrets and should therefore not be pushed to the repo. + +```yaml +SUBSCRIPTION_ID="" +RESOURCE_GROUP_NAME="" +WORKSPACE_NAME="" +BUILD_BUILDID="local" +VNET_NAME="" +SUBNET_NAME="" +USER_ASSIGNED_IDENTITY_RESOURCE_ID="" +AZURE_OPENAI_API_KEY="" +AZURE_OPENAI_ENDPOINT="" +``` + +- subscription_id: The subscription id in Azure hosting the Azure Machine Learning workspace. +- resource_group_name: The name of the resource group hosting the Azure Machine Learning workspace. +- workspace_name: The name of the Azure Machine Learning workspace in which the models will be trained and served. +- vnet_name: The name of the existing virtual network for compute deployment. +- subnet_name: The name of the existing subnet for compute deployment. +- user_assigned_identity_resource_id: The resource id of the user assigned identity to assign to the compute instance. Formatted as "`/subscriptions//resourcegroups//providers/Microsoft.Manag + +To download the `SUBSCRIPTION_ID`, `RESOURCE_GROUP_NAME` and `WORKSPACE_NAME`: + +1. Sign in to AML studio +1. In the upper right Azure Machine Learning studio toolbar, select your workspace name +1. Select the Download config file link +1. Copy the relevant information from the file + +To find the `VNET_NAME` and `SUBNET_NAME`: + +1. go to the Azure Machine Learning studio +1. Select `Compute` on the left panel +1. Copy the information into the relevant fields + +To find the OpenAI credentials, go to the Azure AI Foundary. + +## config Folder + +The [config folder](/config) contains three configuration files: + +- [`config.yml`](/config/config.yaml) - configuration for Azure Machine Learning (AML) and the pipelines, +- [`experiment_config.yml`](/config/experiment_config.yaml) - experiment configurations, +- [`data_config.json`](/config/data_config.json) - configuration for registering data sets in AML from local. + +### `experiment_config.yml` file + +This file is only used by the invoice_processing pipeline and is used to configure the experiment that will be run in AML. It contains the parameters for data preparation, prediction, and evaluation steps of the pipeline. + +The file has several sections, each section configures a different component of the experiment. + +The `experiment_description` section enables the users to add more information about the experiment they are about to run. The user can provide their user name, give a title to the experiment and explain what hypothesis is being tested in this experiment. This information will be logged into AML's job description. + +``` yaml +experiment_description: + user_name: + title: + hypothesis: +``` + +Providing this information will make it easier to differentiate between the different AML runs. + +The next section in the config file configures the data preparation step of the pipeline. + +```yaml +prep_config: + samples_amount: 4 + sampling_seed: 42 +``` + +Adjust the `sample_amount` to be the number of samples on which you would like to run the pipeline (setting this value to zero means running on the entire data set). + +Adjust the `sampling_seed` if you would like the sample to be reproducible in future experiments (otherwise set it to be -1) + +Next, configure the prediction step: + +```yaml +predict_config: + strategy: gpt_only + gpt_deployment_name: gpt-4o + temperature: 0 + prompt_config: + prompt_name: medical_claim_reimbursement_implicit_dates + line_item_instructions: complex +``` + +See the [Prompt and Extraction Strategies](./PromptsAndExtractionStrategies.md) for more details. + +Lastly, to configure the evaluation step: + +```yaml +score_config: + fuzzy_match_config: + field_match_threshold: 0.0 + exact_match_fields: + start_date_match: true + end_date_match: true + amount_match: true + find_best_matches_strategy: levenshtein + matchers_dict: + serviceStartDate: date_exact_match + serviceEndDate: date_exact_match + amount: amount_exact_match + description: description_levenshtein +``` + +### `data_config.json` file + +The `data_config.json` file is used when registering a new data asset in AML. + +```yaml +"DATA_PURPOSE": "test_data", +"DATA_PATH": "mlops/invoice_processing/data/raw_data", +"DATASET_NAME": "invoice_processing_test", +"DATASET_DESC": "this dataset is for pr validation only" +``` + +`DATA_PURPOSE`: what is the purpose of the dataset? + +`DATA_PATH`: the local or Azure path from which to upload the data e.g. "azureml://subscriptions//resourcegroups//workspaces//datastores//paths/" + +`DATASET_NAME`: the name of the registered data asset + +`DATASET_DESC`: description of the dataset + +After configuring the parameters run: + +```bash +python -m mlops.common.register_data_asset --data_config_path= +``` + +The script will register the dataset in AML under data assets. + +### `config.yaml` file + +The [`/config/config.yaml`](/config/config.yaml) file contains a few sections, configuring different aspects of the AML pipeline. + +`aml_config`: the following values are extracted from the `.env` file (do not modify or replace values in this section!). + +```yaml +aml_config: + subscription_id: ${SUBSCRIPTION_ID} + resource_group_name: ${RESOURCE_GROUP_NAME} + workspace_name: ${WORKSPACE_NAME} + vnet_name: ${VNET_NAME} + subnet_name: ${SUBNET_NAME} + user_assigned_identity_resource_id: ${USER_ASSIGNED_IDENTITY_RESOURCE_ID} +``` + +`environment_configuration`: Set the properties for the environment when executing build validation or continuous integration pipelines. When choosing a base image for training and inferencing in Azure Machine Learning take into consideration compatibility with the libraries, dependencies, and performance characteristics of your model. Also consider image maintainability, size, and licensing. + +- env_base_image: Base image to be used for training and model execution +- build_reference: Name of the build to run in AML, by default will be `local` +- env_version: Env version to load. If -1, the latest version is loaded + +```yaml +environment_configuration: + env_base_image: mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu22.04 + build_reference: ${BUILD_BUILDID} + env_version: -1 +``` + +`pipeline_configs`: Stores the configuration for ci and dev pipelines for each model supported by the solution. + +- cluster_region: Azure region in which the AML compute cluster should be hosted. +- cluster_size: Set to an Azure VM Size according to the naming convention here: [Azure VM Sizes](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes). +- cluster_name: A string representing the name of the compute cluster. +- conda_path: The path within the solution to the conda file used to establish the dependencies needed by a given model. (Optional if using `dockerfile_path` and `docker_context_path`) +- aml_env_name: A string denoting the name of a given environment for a given model. +- dataset_name: The name of the data asset which contains the images we want to extract data from. +- gt_name: The name of the data asset which contains the corresponding ground truth (manually annotated image extractions). + +**Important note:** If you would like to use a different dataset for your experiments, please modify the invoice_processing_dev and leave invoice_processing_pr as-is to enable quick PR validation. + +```yaml +pipeline_configs: + invoice_processing_pr: + cluster_region: eastus + cluster_size: STANDARD_DS3_v2 + cluster_name: cpucluster + conda_path: mlops/invoice_processing/environment/conda.yml + aml_env_name: invoice_processing_env + dataset_name: invoice_processing_test + gt_name: invoice_processing_test_gt + + invoice_processing_dev: + cluster_region: eastus + cluster_size: STANDARD_DS3_v2 + cluster_name: cpucluster + conda_path: mlops/invoice_processing/environment/conda.yml + aml_env_name: invoice_processing_env + dataset_name: validated_gt_images + gt_name: validated_gt_annotations +``` + +`invoice_processing_pr` is for the CI pipeline, and runs the pipeline on a small dataset. `invoice_processing_dev` is used in development mode. +The choice between running in pr or dev mode is configured in [start_local_pipeline.py](../../mlops/invoice_processing/start_local_pipeline.py): + +```python +mlops_pipeline.prepare_and_execute("invoice_processing", "pr", "True", None, None) +``` + +`deployment_configs`: Stores online and batch configuration for deployments for each model. diff --git a/docs/how-to/PromptsAndExtractionStrategies.md b/docs/how-to/PromptsAndExtractionStrategies.md new file mode 100644 index 00000000..26872a8d --- /dev/null +++ b/docs/how-to/PromptsAndExtractionStrategies.md @@ -0,0 +1,119 @@ +# Prompts and Extraction Strategies + +This document describes how to configure prompts and extraction strategies for the invoice_processing example. + +## Hyperparameters + +The LLM's temperature can be modified through the [`experiment_config.yaml](config/experiment_config.yaml). + +```yaml +predict_config: + strategy: gpt_only + gpt_deployment_name: gpt-4o + temperature: 0 + +``` + +## Strategies + +The predict component in the experimentation framework can support multiple extraction strategies. Strategies are defined in [this folder](/src/invoice_processing/predict_component/predict/data_extraction). +The [predict.py](/src/invoice_processing/predict_component/predict/predict.py) chooses which strategy to use according to the definitions in the [experiment_config.yaml](/config/experiment_config.yaml) file, as seen below: + +``` yaml +predict_config: + strategy: gpt_only + gpt_deployment_name: gpt-4o + ... +``` + +Currently, the experimentation framework has only one strategy [GPT Only](/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py), defined and implemented. +The strategy parses the images as `.png`, `.jpg` or `.jpeg` to GPT model which provides an answer as `.json`. Different model deployment name can be specified in the same configuration file with `gpt_deployment_name`. Azure OpenAI API key and endpoint have to be set in `.env` to use this strategy. + +### Add a new strategy + +To add a new strategy for the images: + +1. In [extractors folder](/src/invoice_processing/predict_component/predict/data_extraction/extractors) add a new file with your class implementing [base_extractor.py](/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py), similar to [gpt_only_extractor.py](/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py) +1. Load the new strategy as part of the factory, add your new strategy in function `load_default_extractors` located at the [data_extractor_factory.py](/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py) + +## Cost estimation + +The estimated cost is calculated by the `estimate_cost` function within [`predict.py`](/src/invoice_processing/predict_component/predict/predict.py). The values are hard coded based on [this documentation](https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/?msockid=068db13b7d9c6c9a095ca4127cb76d73#pricing) and East US 2 region and need to be configured for any new model (or model in a new region) you want to run in the AML pipeline. + +## Prompts + +Prompts for data extraction from images are defined in `prompts/templates` folder as jinja files. The prompt is then loaded using the [`PromptManager`](/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py) from the [`experiment_config.yaml`](/config/experiment_config.yaml) file as follows (the [example](/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py) is from the GPT only extractor): + +```python +user_prompt = PromptManager.get_prompt( + self.config["prompt_name"], + line_item_instructions=self.config["line_item_instructions"], + structure={json.dumps(structure)} +) +``` + +```yaml +predict_config: + ... + prompt_config: + prompt_name: medical_claim_reimbursement + line_item_instructions: complex +``` + +The experiment will use [medical_claim_reimbursement.j2](../../src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/medical_claim_reimbursement.j2) as the prompt provided to the LLM. + +### Add a new prompt + +To add a new prompt, create a new jinja file within the [prompts/templates](/src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/) folder. If you add a variable or another required input, these will need to be added within the [`experiment_config.yaml` file](/config/experiment_config.yaml) and where the prompt is loaded within each strategy. + +For example, let's imagine I created a prompt called `pharmacy_charges_claims.j2` as follows: + +``` jinja +### Instructions ### +As a Pharmacy Charges Claim Reimbursement Processor, your primary responsibility involves examination of the provided pharmacy receipt in order to accurately extract key information necessary for reimbursement procedures. + +### Required Details ### +- Provider's Name +- Final Charges +{% if additional_fields == 'client_name' %} +- Client's Name +{% endif %} +``` + +I will need to change the [`experiment_config.yaml` file](/config/experiment_config.yaml) as follows: + +```yaml +predict_config: + ... + prompt_config: + prompt_name: pharmacy_charges_claims + additional_fields: client_name +``` + +and the [GPT only extractor](/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py) as follows: + +```python +user_prompt = PromptManager.get_prompt( + self.config["prompt_name"], + additional_fields=self.config["additional_fields"], + structure={json.dumps(structure)} +) +``` + +**Note:** If you are **NOT** using any additional variables within your prompt, remember to remove unnecessary parameters from [`experiment_config.yaml` file](/config/experiment_config.yaml) and [GPT only extractor](/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py): + +```yaml +predict_config: + ... + prompt_config: + prompt_name: pharmacy_charges_claims +``` + +```python +user_prompt = PromptManager.get_prompt( + self.config["prompt_name"], + structure={json.dumps(structure)} +) +``` + +For more information on Jinja, check out [their documentation](https://jinja.palletsprojects.com/en/stable/). diff --git a/mlops/invoice_processing/components/predict.yml b/mlops/invoice_processing/components/predict.yml new file mode 100644 index 00000000..89a76108 --- /dev/null +++ b/mlops/invoice_processing/components/predict.yml @@ -0,0 +1,36 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: predict_labels +version: 1 +display_name: PredictLabels +type: command +enable_caching: true +inputs: + strategy: + type: string + temperature: + type: number + gpt_deployment_name: + type: string + prompt_config: + type: string + test_data: + type: uri_folder + azure_openai_endpoint: + type: string + azure_openai_api_key: + type: string +outputs: + predictions: + type: uri_folder +environment: azureml:AzureML-sklearn-1.1-ubuntu20.04-py38-cpu@latest +code: ./../../../src/invoice_processing/predict_component +command: >- + python -m predict.predict + --strategy ${{inputs.strategy}} + --temperature ${{inputs.temperature}} + --gpt_deployment_name ${{inputs.gpt_deployment_name}} + --azure_openai_endpoint ${{inputs.azure_openai_endpoint}} + --azure_openai_api_key ${{inputs.azure_openai_api_key}} + --prompt_config "${{inputs.prompt_config}}" + --test_data ${{inputs.test_data}} + --predictions ${{outputs.predictions}} diff --git a/mlops/invoice_processing/components/prep.yml b/mlops/invoice_processing/components/prep.yml new file mode 100644 index 00000000..345b9aaf --- /dev/null +++ b/mlops/invoice_processing/components/prep.yml @@ -0,0 +1,25 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: prepare_invoices +display_name: PrepInvoices +version: 1 +type: command +enable_caching: true +inputs: + raw_data: + type: uri_folder + samples_amount: + type: integer + sampling_seed: + type: integer + +outputs: + prep_data: + type: uri_folder +code: ./../../../src/invoice_processing/prep_component +environment: azureml:AzureML-sklearn-1.1-ubuntu20.04-py38-cpu@latest +command: >- + python -m prep.prep + --raw_data ${{inputs.raw_data}} + --samples_amount ${{inputs.samples_amount}} + --sampling_seed ${{inputs.sampling_seed}} + --prep_data ${{outputs.prep_data}} diff --git a/mlops/invoice_processing/components/score.yml b/mlops/invoice_processing/components/score.yml new file mode 100644 index 00000000..328df80d --- /dev/null +++ b/mlops/invoice_processing/components/score.yml @@ -0,0 +1,34 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: score_model +version: 1 +display_name: ScoreModel +type: command +enable_caching: true +inputs: + predictions: + type: uri_folder + ground_truth: + type: uri_folder + score_config: + type: string +outputs: + score_report: + type: uri_file + missing_refs: + type: uri_file + all_unmatched_gt: + type: uri_file + all_unmatched_pred: + type: uri_file + +environment: azureml:AzureML-sklearn-1.1-ubuntu20.04-py38-cpu@latest +code: ./../../../src/invoice_processing/score_component +command: >- + python -m score.score + --predictions ${{inputs.predictions}} + --ground_truth ${{inputs.ground_truth}} + --score_report ${{outputs.score_report}} + --missing_refs ${{outputs.missing_refs}} + --all_unmatched_gt ${{outputs.all_unmatched_gt}} + --all_unmatched_pred ${{outputs.all_unmatched_pred}} + --score_config "${{inputs.score_config}}" diff --git a/mlops/invoice_processing/environment/conda.yml b/mlops/invoice_processing/environment/conda.yml new file mode 100644 index 00000000..aa4293e9 --- /dev/null +++ b/mlops/invoice_processing/environment/conda.yml @@ -0,0 +1,26 @@ +name: prs-env +channels: + - conda-forge +dependencies: + - python=3.9 + - pip + - pip: + - python-dotenv + - pandas + - numpy==1.23.5 + - scikit-learn==1.3.2 + - mlflow>=2.9.2 + - azureml-mlflow>=1.59 + - azure-ai-ml>=1.10.0 + - azureml-fsspec>=1.3.1 + - azure-identity>=1.15.0 + - azure-keyvault-secrets>=4.7.0 + - openai + - pydantic + - Levenshtein + - python-dateutil + - datetime + - python-frontmatter + - jinja2 + - tqdm + - python-retry diff --git a/mlops/invoice_processing/src/__init__.py b/mlops/invoice_processing/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mlops/invoice_processing/src/mlops_pipeline.py b/mlops/invoice_processing/src/mlops_pipeline.py new file mode 100644 index 00000000..6d7ab4ef --- /dev/null +++ b/mlops/invoice_processing/src/mlops_pipeline.py @@ -0,0 +1,284 @@ +""" +This module defines a machine learning pipeline for processing, training, and evaluating data. + +The pipeline executes the following steps in order: +1. Prepare Sample Data: Preprocesses raw data to make it suitable for further processing and analysis. +2. Predict with Sample Data: Uses the trained model to make predictions on new data. +3. Score with Sample Data: Evaluates the model's performance based on its predictions. +""" + +import argparse +from azure.ai.ml.dsl import pipeline +from azure.ai.ml import Input +from azure.ai.ml import load_component +import os +import yaml + +from mlops.common.config_utils import MLOpsConfig +from mlops.common.naming_utils import generate_model_name +from mlops.common.pipeline_job_config import PipelineJobConfig +from mlops.common.pipeline_utils import prepare_and_execute_pipeline + +gl_pipeline_components = [] + + +@pipeline() +def invoice_processing_data_regression( + pipeline_job_input: Input, + model_name: str, + build_reference: str, + strategy: str, + temperature: float, + gpt_deployment_name: str, + prompt_config: str, + ground_truth_data: Input, + score_config: str, + samples_amount: int, + sampling_seed: int, +): + """ + Run a pipeline for regression analysis on invoice data. + + Args: + pipeline_job_input (Input): The raw input data for the pipeline. + model_name (str): The name of the model to be used. + build_reference (str): A reference identifier for the build. + gpt_deployment_name(str): GPT Deployment name. + strategy(str): strategy for predict step e.g. prompt + prompt_config(str): config to use for predict step e.g. prompt + ground_truth_data(Input): ground truth input + score_config (str): dictionary loaded from file as string + samples_amount (int): amount of samples to randomly use from the data set, 0 means all + sampling_seed (int): seed for random sampling of dataset, -1 means no seed + + Returns: + dict: A dictionary containing the outputs of various stages of the pipeline: + """ + prepare_sample_data = gl_pipeline_components[0]( + raw_data=pipeline_job_input, + samples_amount=samples_amount, + sampling_seed=sampling_seed, + ) + predict_with_sample_data = gl_pipeline_components[1]( + strategy=strategy, + temperature=temperature, + gpt_deployment_name=gpt_deployment_name, + prompt_config=prompt_config, + test_data=prepare_sample_data.outputs.prep_data, + azure_openai_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"), + ) + score_with_sample_data = gl_pipeline_components[2]( + predictions=predict_with_sample_data.outputs.predictions, + ground_truth=ground_truth_data, + score_config=score_config, + ) + + pipeline_outputs = { + "pipeline_job_prepped_data": prepare_sample_data.outputs.prep_data, + "pipeline_job_predictions": predict_with_sample_data.outputs.predictions, + "pipeline_job_score_report": score_with_sample_data.outputs.score_report, + "pipeline_job_missing_refs": score_with_sample_data.outputs.missing_refs, + "pipeline_job_all_unmatched_gt": score_with_sample_data.outputs.all_unmatched_gt, + "pipeline_job_all_unmatched_pred": score_with_sample_data.outputs.all_unmatched_pred, + } + + return pipeline_outputs + + +@pipeline() +def invoice_processing_score_only( + pipeline_job_input: Input, + model_name: str, + build_reference: str, + strategy: str, + gpt_deployment_name: str, + prompt_config: str, + ground_truth_data: Input, + score_config: str, + samples_amount: int, + sampling_seed: int, + predictions_file: Input, + temperature: float, +): + """ + Run a pipeline for regression analysis on invoice data. + + Args: + pipeline_job_input (Input): The raw input data for the pipeline. + model_name (str): The name of the model to be used. + build_reference (str): A reference identifier for the build. + gpt_deployment_name(str): GPT Deployment name you used to generate the predictions. + strategy(str): strategy which was used for predictions generated. + prompt_config(str): prompt config which was used for predictions generated. + ground_truth_data(Input): ground truth input + score_config (str): dictionary loaded from file as string + samples_amount (int): amount of samples to randomly use from the data set, 0 means all + sampling_seed (int): seed for random sampling of dataset, -1 means no seed + predictions (Input): predictions generated previously. + + Returns: + dict: A dictionary containing the outputs of various stages of the pipeline: + """ + score_with_sample_data = gl_pipeline_components[2]( + predictions=predictions_file, + ground_truth=ground_truth_data, + score_config=score_config, + ) + + pipeline_outputs = { + "pipeline_job_score_report": score_with_sample_data.outputs.score_report, + "pipeline_job_missing_refs": score_with_sample_data.outputs.missing_refs, + "pipeline_job_all_unmatched_gt": score_with_sample_data.outputs.all_unmatched_gt, + "pipeline_job_all_unmatched_pred": score_with_sample_data.outputs.all_unmatched_pred, + } + + return pipeline_outputs + + +class InvoiceProcessing(PipelineJobConfig): + """ + Class for the invoice processing data Azure ML pipeline configuration and construction. + + This class extends the Pipeline class and provides specific implementations for the invoice processing data + regression pipeline. It includes methods for constructing the pipeline. + """ + + def construct_pipeline(self, ml_client): + """ + Construct a pipeline job for invoice data regression. + + Args: + ml_client: The Azure ML client to use for retrieving data assets and components. + + Returns: + pipeline_job: The constructed pipeline job components. + """ + + registered_data_asset = ml_client.data.get( + name=self.dataset_name, label="latest" + ) + + registered_gt_asset = ml_client.data.get(name=self.gt_name, label="latest") + + parent_dir = os.path.join(os.getcwd(), "mlops/invoice_processing/components") + + components = ["prep", "predict", "score"] + + for component in components: + comp = load_component(source=f"{parent_dir}/{component}.yml") + comp.environment = self.environment_name + gl_pipeline_components.append(comp) + + experiment_config = yaml.safe_load(open("config/experiment_config.yaml")) + + pipeline_inputs = { + "pipeline_job_input": Input( + type="uri_folder", path=registered_data_asset.id + ), + "model_name": self.model_name, + "build_reference": self.build_reference, + "strategy": (experiment_config["predict_config"])["strategy"], + "temperature": (experiment_config["predict_config"])["temperature"], + "gpt_deployment_name": (experiment_config["predict_config"])[ + "gpt_deployment_name" + ], + "prompt_config": str( + (experiment_config["predict_config"])["prompt_config"] + ), + "ground_truth_data": Input(type="uri_folder", path=registered_gt_asset.id), + "score_config": str(experiment_config["score_config"]), + "samples_amount": (experiment_config["prep_config"])["samples_amount"], + "sampling_seed": (experiment_config["prep_config"])["sampling_seed"], + } + + if self.predictions is not None: + prediction_file = ml_client.data.get(name=self.predictions, label="latest") + pipeline_inputs["predictions_file"] = Input( + type="uri_folder", path=prediction_file.id + ) + pipeline_job = invoice_processing_score_only(**pipeline_inputs) + else: + pipeline_job = invoice_processing_data_regression(**pipeline_inputs) + + # demo how to change pipeline output settings + # pipeline_job.outputs.pipeline_job_prepped_data.mode = "rw_mount" + + return pipeline_job + + +def prepare_and_execute( + model_name: str, + build_environment: str, + wait_for_completion: str, + output_file: str, + predictions: str = None, +): + """ + Prepare and execute the pipeline. + + Args: + model_name (str): The name of the model. + build_environment (str): The build environment configuration. + wait_for_completion (str): Whether to wait for the pipeline job to complete. + output_file (str): A file to save the run ID. + """ + config = MLOpsConfig(environment=build_environment) + + pipeline_config = config.get_pipeline_config(model_name) + published_model_name = generate_model_name(model_name) + experiment_description = config.get_experiment_description() + + pipeline_job_config = InvoiceProcessing( + environment_name=None, # will be set in prepare_and_execute_pipeline + build_reference=config.environment_configuration["build_reference"], + published_model_name=published_model_name, + dataset_name=pipeline_config["dataset_name"], + gt_name=pipeline_config["gt_name"], + build_environment=build_environment, + wait_for_completion=wait_for_completion, + output_file=output_file, + model_name=model_name, + predictions=predictions, + ) + + prepare_and_execute_pipeline(pipeline_job_config, experiment_description) + + +def main(): + """Parse the command line arguments and call the `prepare_and_execute` function.""" + parser = argparse.ArgumentParser("build_environment") + parser.add_argument( + "--model_name", type=str, help="name of the model", default="invoice_processing" + ) + parser.add_argument( + "--build_environment", + type=str, + help="configuration environment for the pipeline", + ) + parser.add_argument( + "--wait_for_completion", + type=str, + help="determine if pipeline should wait for job completion", + default="True", + ) + parser.add_argument( + "--output_file", type=str, required=False, help="A file to save run id" + ) + parser.add_argument( + "--predictions", type=str, required=False, help="Name of the predictions file" + ) + + args = parser.parse_args() + + prepare_and_execute( + args.model_name, + args.build_environment, + args.wait_for_completion, + args.output_file, + args.predictions, + ) + + +if __name__ == "__main__": + main() diff --git a/mlops/invoice_processing/start_local_pipeline.py b/mlops/invoice_processing/start_local_pipeline.py new file mode 100644 index 00000000..25aa5855 --- /dev/null +++ b/mlops/invoice_processing/start_local_pipeline.py @@ -0,0 +1,5 @@ +"""The script invokes prepare_and_execute to test it from a local computer.""" +from mlops.invoice_processing.src import mlops_pipeline + +if __name__ == "__main__": + mlops_pipeline.prepare_and_execute("invoice_processing", "pr", "True", None, None) diff --git a/src/invoice_processing/__init__.py b/src/invoice_processing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/__init__.py b/src/invoice_processing/predict_component/predict/__init__.py new file mode 100644 index 00000000..ca0e42ea --- /dev/null +++ b/src/invoice_processing/predict_component/predict/__init__.py @@ -0,0 +1,3 @@ +import logging + +logging.basicConfig(level=logging.INFO) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/__init__.py b/src/invoice_processing/predict_component/predict/data_extraction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/data_extraction/config/__init__.py b/src/invoice_processing/predict_component/predict/data_extraction/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py new file mode 100644 index 00000000..174871eb --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py @@ -0,0 +1,24 @@ +class ConfigurationContainer: + """A simple service container to store configurations for extractors.""" + _config_registry = {} + + @classmethod + def register_config(cls, extractor_name: str, config: dict): + """Register configuration for a specific extractor.""" + cls._config_registry[extractor_name] = config + + @classmethod + def get_config(cls, extractor_name: str) -> dict: + """Retrieve the configuration for a specific extractor.""" + if extractor_name not in cls._config_registry: + return {} + return cls._config_registry[extractor_name] + + @classmethod + def load_configs_from_file(cls, filepath: str): + """Load configurations from a JSON file.""" + import json + with open(filepath, 'r') as f: + configs = json.load(f) + for extractor_name, config in configs.items(): + cls.register_config(extractor_name, config) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py new file mode 100644 index 00000000..b46880ca --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py @@ -0,0 +1,55 @@ +from .config.configuration_container import ConfigurationContainer +from .extractors.base_extractor import Extractor, LoggerProxy + + +class DataExtractorFactory: + """Factory to dynamically load and manage data extractors by category.""" + + _registry = {} + + @classmethod + def load_default_extractors(cls) -> None: + from .extractors.gpt_only_extractor import GPTOnlyExtractor + cls.register("gpt_only", "invoice", GPTOnlyExtractor) + + @classmethod + def register(cls, name: str, category: str, extractor_cls: type) -> None: + """Register a data extractor class under a category.""" + if not issubclass(extractor_cls, Extractor): + raise ValueError( + f"{extractor_cls} is not a subclass of Extractor" + ) + if category not in cls._registry: + cls._registry[category] = {} + cls._registry[category][name] = extractor_cls + + @classmethod + def list_categories(cls) -> list[str]: + """List all available categories.""" + return list(cls._registry.keys()) + + @classmethod + def list_extractors(cls, category: str) -> list[str]: + """List all extractors in a specific category.""" + if category not in cls._registry: + raise ValueError(f"Category {category} is not registered") + return list(cls._registry[category].keys()) + + @classmethod + def create(cls, category: str, name: str, additional_config: dict, + logger_proxy: LoggerProxy) -> Extractor: + """Create an instance of extractor by category and name.""" + + if (category not in cls._registry or name not in cls._registry[category]): + raise ValueError( + f"Extractor {name} in category {category} is not registered" + ) + # Get extractor class + extractor_cls = cls._registry[category][name] + + # Get configuration from the ServiceContainer + config = ConfigurationContainer.get_config(name) + config.update(additional_config) + + # Instantiate the extractor with the configuration + return extractor_cls(config, logger_proxy) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/extractors/__init__.py b/src/invoice_processing/predict_component/predict/data_extraction/extractors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py b/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py new file mode 100644 index 00000000..699059c2 --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py @@ -0,0 +1,25 @@ +"""This class is an interface """ +from abc import ABC, abstractmethod + +from ..models.extraction_response import ( + ExtractionResponse +) + + +class LoggerProxy(ABC): + @abstractmethod + def log_metric(self, key: str, value: float) -> None: + pass + + +class Extractor(ABC): + """Abstract class to define extractor functionalities and data.""" + + def __init__(self, config: dict, logger_proxy: LoggerProxy): + self.config = config + self.logger_proxy = logger_proxy + + @abstractmethod + def extract_data(self, file) -> ExtractionResponse: + """Extract structured data from the given the input.""" + pass diff --git a/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py b/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py new file mode 100644 index 00000000..a98f7f5e --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py @@ -0,0 +1,154 @@ +import json +import logging +from python_retry import retry + +from openai import AzureOpenAI +from .base_extractor import ( + Extractor, + LoggerProxy +) +from ..models.extraction_response import ( + ExtractionResponse, + Invoice, + LineItem, + Provider, + ServiceFor +) +from ..prompts.prompt_manager import PromptManager + +log = logging.getLogger(__name__) + + +class GPTOnlyExtractor(Extractor): + """ + Extraction implementation use Azure OpenAI Model + Args: + config (dict): The configuration dictionary. the following values are expected: + - azure_openai_endpoint (str): Azure OpenAI endpoint. + - azure_openai_api_key (str): Azure OpenAI API key. + (optional, either provide azure_openai_api_key, or a token_provider). + - api_version (str): Azure OpenAI API version. + - token_provider (lambda: str): OAuth JWT token provider method. + (optional, either provide azure_openai_api_key, or a token_provider). + - custom_hdeaders (str): Custom headers to add to the request to OpenAI. + - gpt_deployment_name (str): Azure OpenAI API deployment name. + - prompt_config (dict): Prompt configuration. + logger_proxy (LoggerProxy): Metrics logger. + """ + + def __init__(self, config: dict, logger_proxy: LoggerProxy): + self.client = AzureOpenAI( + azure_endpoint=config.get("azure_openai_endpoint"), + api_key=config.get("azure_openai_api_key"), + api_version=config.get("api_version", "2024-08-01-preview"), + azure_ad_token_provider=config.get("token_provider"), + default_headers=config.get("custom_headers") + ) + self.gpt_deployment_name = config.get("gpt_deployment_name") + self.prompt_config = config.get('prompt_config', {}) + self.temperature = config.get('temperature', 0.0) + super().__init__(config, logger_proxy) + + @retry( + retry_on=(Exception,), + max_retries=5, + backoff_factor=2, + retry_logger=log + ) + def extract_data(self, file) -> ExtractionResponse: + """ + Process an input file using the Azure OpenAI model and save the output. + + Args: + input_path (str): Path to an input image file. + output_folder (str): Path to the output folder. + + Returns: + Optional[InvoiceData]: Parsed response from the model. + """ + messages = self.create_prompt(file) + + completion = self.client.beta.chat.completions.parse( + model=self.gpt_deployment_name, + messages=messages, + response_format=ExtractionResponse, + temperature=self.temperature + ) + + if completion and completion.choices and completion.choices[0].message.parsed: + event = completion.choices[0].message.parsed + log.debug(completion.model_dump_json(indent=2)) + self.logger_proxy.log_metric("completion_tokens", completion.usage.completion_tokens) + self.logger_proxy.log_metric("prompt_tokens", completion.usage.prompt_tokens) + + event.metadata = { + "completion_tokens": completion.usage.completion_tokens, + "prompt_tokens": completion.usage.prompt_tokens + } + return event + else: + log.error("No completion returned or no choices in completion.") + return None + + def create_prompt(self, base64_image): + """ + Create a prompt for the Azure OpenAI model. + + Args: + base64_image (str): Base64 encoded image string. + + Returns: + List[dict]: List of messages for the prompt. + """ + structure = ExtractionResponse( + invoice=Invoice( + totalClaimAmount=0.0, + provider=Provider( + name="" + ), + serviceFor=ServiceFor( + name="" + ), + lineItems=[ + LineItem( + amount=0.0, + text="", + transactionType="", + serviceStartDate="", + serviceEndDate="" + ) + ] + ) + ).model_dump() + + user_prompt = PromptManager.get_prompt( + self.prompt_config["prompt_name"], + line_item_instructions=self.prompt_config["line_item_instructions"], + structure={json.dumps(structure)} + ) + + user_prompt_formatted_with_image = [ + { + "type": "text", + "text": ( + user_prompt + ) + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"} + } + ] + + messages = [ + { + "role": "system", + "content": + "You are an AI assistant that analyzes the text provided " + "and supplemented images and returns them as structured JSON objects. " + "Do not return as a code block." + }, + {"role": "user", "content": user_prompt_formatted_with_image} + ] + + return messages diff --git a/src/invoice_processing/predict_component/predict/data_extraction/models/__init__.py b/src/invoice_processing/predict_component/predict/data_extraction/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py b/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py new file mode 100644 index 00000000..13f9007e --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel +from typing import Optional, List + + +class LineItem(BaseModel): + """Represents a line item in a structured format.""" + amount: float + text: str + transactionType: str # noqa: N815 + serviceStartDate: str # noqa: N815 + serviceEndDate: str # noqa: N815 + miles: Optional[int] = None + + +class Provider(BaseModel): + """Represents a provider in a structured format.""" + name: str + + +class ServiceFor(BaseModel): + """Represents the person the service was provided for in a structured format.""" + name: str + + +class Invoice(BaseModel): + """Represents an invoice in a structured format.""" + totalClaimAmount: float # noqa: N815 + provider: Provider + serviceFor: ServiceFor # noqa: N815 + lineItems: List[LineItem] # noqa: N815 + + +class ExtractionResponse(BaseModel): + """Represents extracted data in a structured format.""" + invoice: Invoice + metadata: Optional[dict] = None diff --git a/src/invoice_processing/predict_component/predict/data_extraction/prompts/__init__.py b/src/invoice_processing/predict_component/predict/data_extraction/prompts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py new file mode 100644 index 00000000..cdd0bb1b --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py @@ -0,0 +1,49 @@ +from pathlib import Path +import frontmatter +from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateError, meta, select_autoescape + + +class PromptManager: + _env = None + + @classmethod + def _get_env(cls, templates_dir="prompts/templates"): + templates_dir = Path(__file__).parent.parent / templates_dir + if cls._env is None: + cls._env = Environment( + loader=FileSystemLoader(templates_dir), + undefined=StrictUndefined, + autoescape=select_autoescape(['html', 'xml']) + ) + return cls._env + + @staticmethod + def get_prompt(template, **kwargs): + env = PromptManager._get_env() + template_path = f"{template}.j2" + with open(env.loader.get_source(env, template_path)[1]) as file: + post = frontmatter.load(file) + + template = env.from_string(post.content) + try: + return template.render(**kwargs) + except TemplateError as e: + raise ValueError(f"Error rendering template: {str(e)}") + + @staticmethod + def get_template_info(template): + env = PromptManager._get_env() + template_path = f"{template}.j2" + with open(env.loader.get_source(env, template_path)[1]) as file: + post = frontmatter.load(file) + + ast = env.parse(post.content) + variables = meta.find_undeclared_variables(ast) + + return { + "name": template, + "description": post.metadata.get("description", "No description provided"), + "author": post.metadata.get("author", "Unknown"), + "variables": list(variables), + "frontmatter": post.metadata, + } diff --git a/src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/medical_claim_reimbursement.j2 b/src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/medical_claim_reimbursement.j2 new file mode 100644 index 00000000..f6c43844 --- /dev/null +++ b/src/invoice_processing/predict_component/predict/data_extraction/prompts/templates/medical_claim_reimbursement.j2 @@ -0,0 +1,14 @@ +You are an intelligent document processing system. Your task is to extract the following information from a receipt, which may be a handwritten note or a structured document. +Your output should be a JSON object containing the following fields: +totalClaimAmount: The total claim or invoice amount from the receipt. +provider.name: The name of the organization or individual providing the service. +serviceFor.name: The name of the person receiving the service. +lineItems: A list of individual charges or service lines. +{% if line_item_instructions == 'complex' %} +lineItems: A list of individual charges or service lines, each with the following fields: +- amount: The monetary value associated with the line item. +- text: Description or label for the service or charge. +- transactionType: Type of service (e.g., "mileage", "consultation", "supplies", etc.). +- serviceStartDate: Date when the service started. +- serviceEndDate: Date when the service ended. +{% else %} \ No newline at end of file diff --git a/src/invoice_processing/predict_component/predict/helpers.py b/src/invoice_processing/predict_component/predict/helpers.py new file mode 100644 index 00000000..44dc85b2 --- /dev/null +++ b/src/invoice_processing/predict_component/predict/helpers.py @@ -0,0 +1,33 @@ +import base64 +import json +import logging + +log = logging.getLogger(__name__) + + +def save_output_as_json(output, output_file_path): + """ + Save response output as a JSON file. + + Args: + output (dict): Output data to save. + output_file_path (str): Path to the output file. + """ + with open(output_file_path, 'w', encoding='utf-8') as json_file: + json.dump(output, json_file, ensure_ascii=False, indent=4) + log.info(f"Saved output to {output_file_path}") + + +def convert_image_to_base64(image_path: str) -> str: + """ + Convert an image path to a base64 encoded string. + + Args: + image_path (str): Path to the image file. + + Returns: + str: Base64 encoded image string. + """ + with open(image_path, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode('utf-8') + return base64_image diff --git a/src/invoice_processing/predict_component/predict/mlflow_logger.py b/src/invoice_processing/predict_component/predict/mlflow_logger.py new file mode 100644 index 00000000..4f7f288d --- /dev/null +++ b/src/invoice_processing/predict_component/predict/mlflow_logger.py @@ -0,0 +1,9 @@ +import mlflow +from .data_extraction.extractors.base_extractor import ( + LoggerProxy +) + + +class MLFlowLogger(LoggerProxy): + def log_metric(self, key: str, value: float) -> None: + mlflow.log_metric(key, value) diff --git a/src/invoice_processing/predict_component/predict/predict.py b/src/invoice_processing/predict_component/predict/predict.py new file mode 100644 index 00000000..a72e6789 --- /dev/null +++ b/src/invoice_processing/predict_component/predict/predict.py @@ -0,0 +1,246 @@ +import ast +import os +import argparse +import time +from glob import glob +import logging +import mlflow +import pandas as pd + + +from .data_extraction.data_extractor_factory import ( + DataExtractorFactory +) +from .data_extraction.extractors.base_extractor import ( + Extractor +) +from .data_extraction.models.extraction_response import ( + ExtractionResponse +) +from .mlflow_logger import MLFlowLogger +from .helpers import convert_image_to_base64, save_output_as_json + +log = logging.getLogger(__name__) + + +def predict( + strategy, + temperature, + gpt_deployment_name, + azure_openai_endpoint, + azure_openai_api_key, + prompt_config, + test_data, + prediction_path, +) -> None: + """ + Perform end-to-end initialization and processing input folder's .png and .jpg files + using the specified model type with Azure services. + + This includes: + - Initializing the Azure OpenAI client. + - Creating prompt messages for the model. + - Processing each file in the input folder. + - Saving the output as a JSON file in the output folder. + Args: + strategy (string): orchestration strategy name + temperature (float): LLM temperature + gpt_deployment_name (string): name of the GPT deployment name to use + prompt_config (string): dictionary loaded as string with prompt configuration + test_data (str): a folder with input data + prediction_path (str): a folder for storing predictions + """ + + config_dict = ast.literal_eval(prompt_config) + params = { + "gpt_deployment_name": gpt_deployment_name, + "temperature": temperature, + "prompt_name": config_dict["prompt_name"], + "line_item_instructions": config_dict["line_item_instructions"] + } + mlflow.log_params(params) + + DataExtractorFactory.load_default_extractors() + extractor = DataExtractorFactory.create("invoice", strategy, { + "azure_openai_endpoint": azure_openai_endpoint, + "azure_openai_api_key": azure_openai_api_key, + "gpt_deployment_name": gpt_deployment_name, + "temperature": temperature, + "prompt_config": config_dict + }, MLFlowLogger()) + + mlflow.log_param("strategy", strategy) + + os.makedirs(prediction_path, exist_ok=True) + + test_data_paths = glob_by_extesion(test_data, ['.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG']) + + log.info(f"Processing files in {test_data} using model/strategy {strategy}, len_imgs: {len(test_data_paths)}") + mlflow.log_metric('images_identified', + len(test_data_paths)) + + performance_df = pd.DataFrame(columns=[ + 'file_path', + 'completion_tokens', + 'prompt_tokens', + 'execution_time' + ]) + for file in test_data_paths: + file_path = os.path.join(test_data, file) + try: + extraction_response, execution_time = process(extractor, file_path, prediction_path) + mlflow.log_metric("execution_time", execution_time) + log.info(f"Execution time for {gpt_deployment_name}: {execution_time} seconds") + + performance_df = pd.concat([pd.DataFrame([[ + file_path, + extraction_response.metadata.get("completion_tokens"), + extraction_response.metadata.get("prompt_tokens"), + execution_time + ]], columns=performance_df.columns), performance_df], ignore_index=True) + except Exception as e: + log.error(f"Error processing file {file}: {e}") + + performance_df_path = "performance_results.csv" + performance_df.to_csv(f"{performance_df_path}", index=False) + mlflow.log_artifact(f"{performance_df_path}") + + mlflow.log_metric('successfully_processed_images', len(performance_df.index)) + + for column in ['completion_tokens', 'prompt_tokens', 'execution_time']: + mean = performance_df.loc[:, column].mean() + median = performance_df.loc[:, column].median() + total = performance_df.loc[:, column].sum() + mlflow.log_metric(f"mean_{column}", mean) + mlflow.log_metric(f"median_{column}", median) + mlflow.log_metric(f"total_{column}", total) + + total_input_price, total_output_price = estimate_cost(gpt_deployment_name, performance_df) + mlflow.log_metric("estimated_input_price", total_input_price) + mlflow.log_metric("estimated_output_price", total_output_price) + + +def estimate_cost(gpt_deployment_name, performance_df): + total_input_tokens = performance_df.loc[:, 'prompt_tokens'].sum() + total_output_tokens = performance_df.loc[:, 'completion_tokens'].sum() + if gpt_deployment_name == 'gpt-4o': + input_price_in_usd = 2.5 + output_price_in_usd = 10 + elif gpt_deployment_name == 'gpt-4o-mini': + input_price_in_usd = 0.15 + output_price_in_usd = 0.6 + else: + input_price_in_usd = 0 + output_price_in_usd = 0 + log.error(f"{gpt_deployment_name} not included in estimate_cost function, " + + "please add price logic to estimate cost") + per_one_mln_tokens = 1000000 + total_input_price = total_input_tokens * (input_price_in_usd / per_one_mln_tokens) + total_output_price = total_output_tokens * (output_price_in_usd / per_one_mln_tokens) + return total_input_price, total_output_price + + +def glob_by_extesion(test_data, types): + all_images = [] + for type in types: + arr = glob(f'{test_data}/*{type}') + all_images += arr + return all_images + + +def process(extractor: Extractor, input_path, output_folder) -> tuple[ExtractionResponse, float]: + base64_image = convert_image_to_base64(input_path) + + start_time = time.time() + extraction_response = extractor.extract_data(base64_image) + end_time = time.time() + execution_time = end_time - start_time + + json_base_name = os.path.splitext(os.path.basename(input_path))[0] + output_file_path = os.path.join(output_folder, f"{json_base_name}_result.json") + save_output_as_json(extraction_response.model_dump(), output_file_path) + return extraction_response, execution_time + + +def main( + strategy, + temperature, + gpt_deployment_name, + azure_openai_endpoint, + azure_openai_api_key, + prompt_config, + test_data, + prediction_path, +): + """Load test data, call predict function. + + Args: + strategy (string): orchestration strategy name + temperature (float): LLM temperature + gpt_deployment_name (string): name of the GPT deployment name to use + prompt_config (string): dictionary with prompt configuration + test_data (string): path to test data + prediction_path (string): path to which to write predictions + """ + lines = [ + f"Orchestration strategy: {strategy}", + f"Temperature: {temperature}", + f"GPT deployment name: {gpt_deployment_name}", + f"Predict configuration: {prompt_config}", + f"Test data path: {test_data}", + f"Predictions path: {prediction_path}", + ] + + for line in lines: + log.info(line) + + predict( + strategy, + temperature, + gpt_deployment_name, + azure_openai_endpoint, + azure_openai_api_key, + prompt_config, + test_data, + prediction_path, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("predict") + parser.add_argument("--strategy", type=str, help="Orchestration strategy") + parser.add_argument("--temperature", type=float, help="LLM temperature") + + parser.add_argument("--gpt_deployment_name", type=str, help="GPT deployment name") + parser.add_argument( + "--azure_openai_endpoint", type=str, help="Azure OpenaAI endpoint" + ) + parser.add_argument( + "--azure_openai_api_key", type=str, help="Azure OpenaAI API key" + ) + parser.add_argument("--prompt_config", type=str, help="Config dictionary") + parser.add_argument("--test_data", type=str, help="Path to test data") + parser.add_argument("--predictions", type=str, help="Path of predictions") + + args = parser.parse_args() + + log.debug("Predict started... arguments parsed successfully.") + + strategy = args.strategy + temperature = args.temperature + gpt_deployment_name = args.gpt_deployment_name + prompt_config = args.prompt_config + test_data = args.test_data + prediction_path = args.predictions + azure_openai_endpoint = args.azure_openai_endpoint + azure_openai_api_key = args.azure_openai_api_key + main( + strategy, + temperature, + gpt_deployment_name, + azure_openai_endpoint, + azure_openai_api_key, + prompt_config, + test_data, + prediction_path, + ) diff --git a/src/invoice_processing/predict_component/predict/pyproject.toml b/src/invoice_processing/predict_component/predict/pyproject.toml new file mode 100644 index 00000000..bedb37cd --- /dev/null +++ b/src/invoice_processing/predict_component/predict/pyproject.toml @@ -0,0 +1,25 @@ +[project] +version = "0.0.10" +description = "A shared package of extraction strategies for content digitization." +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "openai", + "pydantic", + "jinja2", + "python-frontmatter", + "python-retry" +] + + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + + +[tool.setuptools.packages.find] +where = ["."] +include = ["data_extraction*"] + +[tool.setuptools.package-data] +"*" = ["*.j2"] \ No newline at end of file diff --git a/src/invoice_processing/prep_component/prep/__init__.py b/src/invoice_processing/prep_component/prep/__init__.py new file mode 100644 index 00000000..ca0e42ea --- /dev/null +++ b/src/invoice_processing/prep_component/prep/__init__.py @@ -0,0 +1,3 @@ +import logging + +logging.basicConfig(level=logging.INFO) diff --git a/src/invoice_processing/prep_component/prep/prep.py b/src/invoice_processing/prep_component/prep/prep.py new file mode 100644 index 00000000..5311f19e --- /dev/null +++ b/src/invoice_processing/prep_component/prep/prep.py @@ -0,0 +1,93 @@ +import argparse +import os +import random +import shutil +import logging +import mlflow + +log = logging.getLogger(__name__) + + +def sample_data(data_paths, samples_amount, sampling_seed): + """ + Samples randomly number of data paths based on input amount and seed. + + Parameters: + data_paths (str): paths to files + samples_amount (int): amount of samples to randomly use from the data set, 0 means all + sampling_seed (int): seed for random sampling of dataset, -1 means no seed + """ + sampled_data_paths = data_paths + if samples_amount > 0: + data_paths_len = len(data_paths) + + if sampling_seed != -1: + random.seed(sampling_seed) + + sampled_data_paths = random.sample(data_paths, samples_amount) + print(f"filtered samples array from {data_paths_len} to {len(data_paths)}") + print(sampled_data_paths) + + return sampled_data_paths + + +def main(raw_data, prep_data, samples_amount, sampling_seed): + """ + Read existing jpg and png files and invoke preprocessing step. + + Parameters: + raw_data (str): a folder to read csv files + prep_data (str): a folder for preprocessed data + samples_amount (int): amount of samples to randomly use from the data set, 0 means all + sampling_seed (int): seed for random sampling of dataset, -1 means no seed + """ + + mlflow.log_param('number_of_samples', samples_amount) + + lines = [ + f"Raw data path: {raw_data}", + f"Data output path: {prep_data}", + ] + + for line in lines: + log.info(line) + + data_paths = os.listdir(raw_data) + log.debug(f"mounted_path files: {data_paths}") + + data_paths = sample_data(data_paths, samples_amount, sampling_seed) + + os.makedirs(prep_data, exist_ok=True) + for filename in data_paths: + log.info("reading file: %s ..." % filename) + destination = os.path.join(prep_data, filename) + source = os.path.join(raw_data, filename) + shutil.copy(source, destination) + log.info("saving file: %s ..." % destination) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--raw_data", + type=str, + default="../data/raw_data", + help="Path to raw data", + ) + parser.add_argument( + "--prep_data", type=str, default="../data/prep_data", help="Path to prep data" + ) + parser.add_argument( + "--samples_amount", required=False, type=int, default=0, + help="Amount of samples to randomly use from the data set, 0 means all," + ) + parser.add_argument( + "--sampling_seed", required=False, type=int, default=-1, + help="Seed for random sampling of dataset, -1 means no seed," + ) + + args = parser.parse_args() + + log.debug("Prep started... arguments parsed successfully.") + + main(args.raw_data, args.prep_data, args.samples_amount, args.sampling_seed) diff --git a/src/invoice_processing/score_component/score/__init__.py b/src/invoice_processing/score_component/score/__init__.py new file mode 100644 index 00000000..ca0e42ea --- /dev/null +++ b/src/invoice_processing/score_component/score/__init__.py @@ -0,0 +1,3 @@ +import logging + +logging.basicConfig(level=logging.INFO) diff --git a/src/invoice_processing/score_component/score/extraction_evaluator.py b/src/invoice_processing/score_component/score/extraction_evaluator.py new file mode 100644 index 00000000..b6386df3 --- /dev/null +++ b/src/invoice_processing/score_component/score/extraction_evaluator.py @@ -0,0 +1,293 @@ +""" +This module contains the ExtractionEvaluator class, which provides evaluation methods for data extraction from images. +The class includes functionalities for calculating various metrics and generating reports. + +Classes: ExtractionEvaluation: a class for evaluation of data extraction from images +""" + +from typing import Dict, List +import pandas as pd +import logging +from .matchers.date_exact_matcher import DateExactMatcher +from .matchers.amount_exact_matcher import AmountExactMatcher +from .matchers.levenshtein_matcher import LevenshteinMatcher +from .matchers.text_exact_matcher import TextExactMatcher + +log = logging.getLogger(__name__) + + +class ExtractionEvaluator: + """ + A comprehensive evaluator for comparing invoice details between + ground truth and predictions. + + This class currently supports single file evaluation + with flexible comparison strategies. + """ + + def __init__( + self, + fuzzy_match_config: Dict, + exact_match_fields: List[str], + matchers_dict: Dict, + find_best_matches_strategy: str, + ): + """ + Initialize the base matcher with ground truth data. + + Args: + fuzzy_match_config: fuzzy match configuration + exact_match_fields: fields that are compared by exact match + matchers_dict: field to matcher type mapping + find_best_matches_strategy: The startegy to use when trying + to find the most similar line items + """ + # Configuration for line item matching + self.fuzzy_match_config = fuzzy_match_config + + # exact match fields list + self.exact_match_fields = exact_match_fields + + # fields to matcher type + self.matchers_dict = matchers_dict + + # find best matches strategy + self.find_best_matches_strategy = find_best_matches_strategy + + def get_matcher(self, matcher_class_name: str): + """ + Create an instance of the requested matcher. + Args: matcher_class_name: Name of the matcher per + field as defined in the experiment config file. + Returns: + Instance of the requested matcher class + """ + if matcher_class_name == "date_exact_match": + return DateExactMatcher() + elif matcher_class_name == "amount_exact_match": + return AmountExactMatcher() + elif matcher_class_name == "description_levenshtein": + return LevenshteinMatcher() + elif matcher_class_name == "text_exact_match": + return TextExactMatcher() + else: + print("matcher undefined!") + return None + + def get_matcher_for_best_matches_strategy(self, best_matches_strategy: str): + """ + Get find best matches strategy. + """ + if best_matches_strategy == "levenshtein": + return LevenshteinMatcher() + elif best_matches_strategy == "text_exact_match": + return TextExactMatcher() + else: + print("best matches strategy is not defined!") + return None + + def get_match_method(self, matcher_name: str): + """ + Get match method from matcher name. + Args: matcher_name: name of the matcher + Returns: match method (currently exact_match or levenshtein) + """ + match_method = "" + matcher_name_split = matcher_name.split("_") + if len(matcher_name_split) > 1: + match_method = f"{matcher_name_split[1]}_{matcher_name_split[2]}" + else: + match_method = matcher_name + return match_method + + def compare_line_item_values_per_invoice( + self, + ground_truth_df: pd.DataFrame, + predictions_df: pd.DataFrame, + ): + """ + Compare the line items in the ground trith data with the line items in the prediction data. + Find exact matches if exist and calculate fuzzy match metrics for relevent fields. + Args: + ground_truth_df: A dataframe in which each column is a different + extracted field (startDate, endDate, amount, description) + and each row represents a different line item in the invoice + predictions_df: Same as ground_truth_df only for model data extraction + Returns: + Dataframe with exact and fuzzy match metrics for all line items (all vs. all) + """ + field_names = [] + ground_truth_df["gt_index"] = range(ground_truth_df.shape[0]) + predictions_df["pred_index"] = range(predictions_df.shape[0]) + comparison_df = pd.merge( + ground_truth_df, predictions_df, how="cross", suffixes=["_gt", "_pred"] + ) + for field_name in self.matchers_dict: + curr_matcher = self.get_matcher(self.matchers_dict.get(field_name)) + matcher_name = curr_matcher.get_matcher_name() + match_method = self.get_match_method(matcher_name) + comparison_df[f"{field_name}_{match_method}"] = curr_matcher.get_match( + comparison_df, field_name + ) + field_names.append(f"{field_name}_{match_method}") + + exact_match_fields = [x for x in field_names if "exact" in x] + comparison_df["exact_match_sum"] = comparison_df[exact_match_fields].sum(axis=1) + comparison_df["similarity_score"] = comparison_df[field_names].sum(axis=1) + return comparison_df + + def get_match_results( + self, + comparison_df: pd.DataFrame(), + ): + """ + Report the line items match results and additional datasets for error analysis. + Args: + comparison_df (pd.DataFrame): dataframe will all possible combinations + of line items from the ground truth and the predictions. + best_matches_dict (Dict): Dictionary with lists per fuzzy match method, + of pairs of matched ground truth and prediction line items. + Returns: + """ + unmatched_gt = pd.DataFrame() + gt_cols = [] + unmatched_pred = pd.DataFrame() + pred_cols = [] + match_results_df = pd.DataFrame() + best_matches_matcher = self.get_matcher_for_best_matches_strategy( + self.find_best_matches_strategy + ) + description_matcher_name = best_matches_matcher.get_matcher_name() + best_matches_dict = best_matches_matcher.find_best_matches( + comparison_df, self.fuzzy_match_config + ) + best_matches_pairs_list = best_matches_dict[description_matcher_name] + + best_matches_pairs_df = pd.DataFrame(best_matches_pairs_list) + best_matches_df = pd.merge( + best_matches_pairs_df, + comparison_df, + on=["gt_index", "pred_index"], + how="left", + ) + best_matches_gt_index = best_matches_pairs_df["gt_index"].unique().tolist() + best_matches_pred_index = best_matches_pairs_df["pred_index"].unique().tolist() + unmatched_gt = comparison_df[ + ~comparison_df["gt_index"].isin(best_matches_gt_index) + ].drop_duplicates(subset=["gt_index"]) + unmatched_pred = comparison_df[ + ~comparison_df["pred_index"].isin(best_matches_pred_index) + ].drop_duplicates(subset=["pred_index"]) + gt_cols = [x for x in comparison_df.columns.tolist() if "_gt" in x] + pred_cols = [x for x in comparison_df.columns.tolist() if "_pred" in x] + match_results_df = pd.concat( + [best_matches_df, unmatched_gt[gt_cols], unmatched_pred[pred_cols]] + ).fillna(0) + return ( + match_results_df, + unmatched_gt[gt_cols], + unmatched_pred[pred_cols], + best_matches_df, + ) + + def calculate_evaluation_metrics_per_field_in_invoice( + self, match_results_df: pd.DataFrame + ): + """ + Calculate the evaluation metric per invoice per field. Currently supports accuracy. + Args: + match_results_df (pd.DataFrame): A dataframe with all line items from + the ground truth and the predictions: line items of the ground truth and + their match from the predictions, line items from the ground truth and the + predictions that were not matched. + Returns: + results_df (DataFrame): A dataframe with the accuracy results + per line item for a single invoice. + """ + field_col_names = [ + f"{x}_{self.get_match_method(self.get_matcher(self.matchers_dict.get(x)).get_matcher_name())}" + for x in self.matchers_dict.keys() + ] + matches_eval_fields = match_results_df[field_col_names] + results_df = self.calculate_mean_accuracy_per_invoice(matches_eval_fields) + results_df = results_df.reset_index().rename(columns={"index": "field_name"}) + results_df = results_df[["field_name", "accuracy"]].sort_values(by="field_name") + return results_df + + def calculate_mean_accuracy_per_invoice(self, matches_eval_fields: pd.DataFrame): + """ + Calcualte the mean accuracy per field in a single invoice. + Args: + matches_eval_fields (pd.DataFrame): A dataframe with the ressulting matches + per line item in the ground truth data which includes only the fields we would + like to include in the evaluation. + Returns: + mean_accuracy_df (pd.DataFrame): A dataframe with the + accuracy results. + """ + mean_accuracy_df = matches_eval_fields.mean() + accuracy_df = pd.DataFrame(mean_accuracy_df.T).rename(columns={0: "accuracy"}) + accuracy_df = accuracy_df.round({"accuracy": 3}) + return accuracy_df + + def calculate_mean_accuracy_per_batch(self, all_invoices_results: pd.DataFrame): + """ + Calcualte the mean accuracy per field in a batch of invoices. + Args: + all_invoices_results (pd.DataFrame): A dataframe with the mean accuracy + results of all invoices in the experiment. + Returns: + final_results_df (pd.DataFrame): A dataframe with the mean + accuracy results across all invoices. + overall_accuracy (float): Mean accuracy across all fields and invoices. + """ + final_results_df = ( + all_invoices_results[["field_name", "accuracy"]] + .groupby(by="field_name") + .mean() + .reset_index() + ) + overall_accuracy = round(final_results_df["accuracy"].mean(), 3) + final_results_df = final_results_df[["field_name", "accuracy"]] + final_results_df = final_results_df.round({"accuracy": 3}) + return overall_accuracy, final_results_df + + def calculate_precision_per_record( + self, unmatched_pred: pd.DataFrame(), best_matches_df: pd.DataFrame() + ): + """ + This function calculates the precision per invoice (record). + Args: + unmatched_pred: Dataframe of line items in the extracted data that + were not matched to any ground truth line item (defined as FPs). + best_matches_df: Dataframe with the matched line items from + the ground truth data and the extractions. + Returns: + precision (float): The precision metric + """ + tp = 0 + fp = 0 + fp = unmatched_pred.shape[0] + tp = best_matches_df.shape[0] + precision = tp / (tp + fp) + return precision + + def calculate_recall_per_record( + self, unmatched_gt: pd.DataFrame(), best_matches_df: pd.DataFrame() + ): + """ + This function calculates the precision per invoice (record). + Args: + unmatched_gt: Dataframe of line items in the ground truth data that + were not matched to any extracted line item (defined as FNs). + best_matches_df: Dataframe with the matched line items from + the ground truth data and the extractions. + Returns: + recall (float): The recall metric + """ + tp = 0 + fn = 0 + fn = unmatched_gt.shape[0] + tp = best_matches_df.shape[0] + recall = tp / (tp + fn) + return recall diff --git a/src/invoice_processing/score_component/score/matchers/__init__.py b/src/invoice_processing/score_component/score/matchers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py new file mode 100644 index 00000000..23b3ef96 --- /dev/null +++ b/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py @@ -0,0 +1,51 @@ +"""Class that performs exact matches for amounts""" + +import logging + +from .base_matcher import BaseMatcher +from ..utils import preprocess_amount + +log = logging.getLogger(__name__) + + +class AmountExactMatcher(BaseMatcher): + """ + Calculate amount exact match. + """ + + def amount_exact_match(self, amount1, amount2): + """ + Find out whether the amounts in ground truth and prediction are equal + Args: + amount_str1: First amount value + amount_str2: Second amount value + Returns: + match: whether the compared values are equal (exact match) + """ + match = False + parsed_amount1 = preprocess_amount(amount1) + parsed_amount2 = preprocess_amount(amount2) + diff = abs(float(parsed_amount1) - float(parsed_amount2)) + if diff == 0: + match = True + else: + match = False + return match + + def get_matcher_name(self): + """ + return matcher name. + """ + return "amount_exact_match" + + def get_match(self, comparison_df, field_name): + """ + Get match result per line item. + """ + match_df = comparison_df.apply( + lambda x: self.amount_exact_match( + x[f"{field_name}_gt"], x[f"{field_name}_pred"] + ), + axis=1, + ) + return match_df diff --git a/src/invoice_processing/score_component/score/matchers/base_matcher.py b/src/invoice_processing/score_component/score/matchers/base_matcher.py new file mode 100644 index 00000000..8d2b8d34 --- /dev/null +++ b/src/invoice_processing/score_component/score/matchers/base_matcher.py @@ -0,0 +1,17 @@ +"""This class is an interface """ + +from abc import ABC, abstractmethod + + +class BaseMatcher(ABC): + """ + Abstract class to define matcher base functions. + """ + + @abstractmethod + def get_match(self): + pass + + @abstractmethod + def get_matcher_name(self): + pass diff --git a/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py new file mode 100644 index 00000000..de18fe73 --- /dev/null +++ b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py @@ -0,0 +1,57 @@ +"""Class that performs exact matches for dates""" + +import logging + +from .base_matcher import BaseMatcher +from ..utils import preprocess_date + +log = logging.getLogger(__name__) + + +class DateExactMatcher(BaseMatcher): + + def dates_exact_match(self, date_str1: str, date_str2: str): + """ + Find out whether the dates are identical. + Args: + date_str1: First date value + date_str2: Second date value + Returns: + match: whether the compared values are equal (exact match) + """ + match = False + try: + date1 = preprocess_date(date_str1) + date2 = preprocess_date(date_str2) + if date1 == date2: + match = True + else: + match = False + except ValueError: + log.debug( + """One or more of the date strings could not be parsed into date type, + performing string comparison instead""" + ) + if date_str1 == date_str2: + match = True + else: + match = False + return match + + def get_matcher_name(self): + """ + return matcher name. + """ + return "date_exact_match" + + def get_match(self, comparison_df, field_name): + """ + Get match result per line item. + """ + match_df = comparison_df.apply( + lambda x: self.dates_exact_match( + x[f"{field_name}_gt"], x[f"{field_name}_pred"] + ), + axis=1, + ) + return match_df diff --git a/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py new file mode 100644 index 00000000..d4fdb972 --- /dev/null +++ b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py @@ -0,0 +1,89 @@ +"""Class that performs exact matches for dates""" + +import logging +from typing import Dict +import pandas as pd +import Levenshtein +from ..utils import normalize_string +from .base_matcher import BaseMatcher + +log = logging.getLogger(__name__) + + +class LevenshteinMatcher(BaseMatcher): + + def get_matcher_name(self): + return "levenshtein" + + def calculate_levenshtein_ratio(self, string1: str, string2: str): + """ + Calculates the Levenshtein ratio between two strings. + + The Levenshtein ratio is a measure of the similarity between two strings, + defined as the ratio of the Levenshtein distance to the length of the longer string. + It range from 0 to 1, where 1 indicates identical strings and 0 indicates + completely different strings. + + Args: + string1 (str): The first string + string2 (str): The second string + + Returns: + levenshtein_ratio (float): The calculated Lenshtein ratio. + """ + normalized_gt = normalize_string(str(string1)) + normalized_gen = normalize_string(str(string2)) + levenshtein_ratio = Levenshtein.ratio(normalized_gt, normalized_gen) + rounded_levenshtein_ratio = round(levenshtein_ratio, 3) + return rounded_levenshtein_ratio + + def get_match(self, comparison_df, field_name): + """ + Get match result per line item. Calculates Levenshtein ratio. + """ + match_df = comparison_df.apply( + lambda x: self.calculate_levenshtein_ratio( + x[f"{field_name}_gt"], x[f"{field_name}_pred"] + ), + axis=1, + ) + return match_df + + def find_best_matches(self, comparison_df: pd.DataFrame, fuzzy_match_config: Dict): + """ + For every line item in the ground truth data, find the most similar + line item in the predictions + Args: + comparison_df: a dataframe which is the cartesian product of the + line items inthe ground truth and the predictions datasets + Returns: + A dictionary of the best matches: {"fuzzy_match_method_name": best_matches_df} + """ + + levenshtein_ratio_thr = fuzzy_match_config["field_match_threshold"] + remaining_comparisons = comparison_df.copy() + best_matches_list_levenshtein = [] + best_matches_dict = {} + for i in range(comparison_df["gt_index"].nunique()): + max_exact_matches = remaining_comparisons["exact_match_sum"].max() + similarity_thr = max_exact_matches + levenshtein_ratio_thr + curr_max_similarity = remaining_comparisons["similarity_score"].max() + if curr_max_similarity >= similarity_thr: + max_similarity_index = remaining_comparisons[ + "similarity_score" + ].argmax() + best_match = remaining_comparisons.iloc[max_similarity_index] + best_match_gt = best_match["gt_index"] + best_match_pred = best_match["pred_index"] + best_matches_list_levenshtein.append( + {"gt_index": best_match_gt, "pred_index": best_match_pred} + ) + index_to_drop = remaining_comparisons[ + (remaining_comparisons["gt_index"] == best_match_gt) + | (remaining_comparisons["pred_index"] == best_match_pred) + ].index + remaining_comparisons.drop(index_to_drop, inplace=True) + else: + continue + best_matches_dict["levenshtein"] = best_matches_list_levenshtein + return best_matches_dict diff --git a/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py new file mode 100644 index 00000000..612c24d0 --- /dev/null +++ b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py @@ -0,0 +1,82 @@ +"""Class that performs exact matches for text""" + +import logging + +from .base_matcher import BaseMatcher +from ..utils import normalize_string + +log = logging.getLogger(__name__) + + +class TextExactMatcher(BaseMatcher): + + def get_matcher_name(self): + """ + return matcher name. + """ + return "text_exact_match" + + def get_match(self, comparison_df, field_name): + """ + Get match result per line item. + """ + match_df = comparison_df.apply( + lambda x: self.text_exact_match( + x[f"{field_name}_gt"], x[f"{field_name}_pred"] + ), + axis=1, + ) + return match_df + + def text_exact_match(self, str1: str, str2: str): + """ + Find out whether the dates are identical. + Args: + str1: First string value + date_str2: Second string value + Returns: + match: whether the compared values are equal (exact match) + """ + normalized_str1 = normalize_string(str(str1)) + normalized_str2 = normalize_string(str(str2)) + match = False + if normalized_str1 == normalized_str2: + match = True + else: + match = False + return match + + def find_best_matches(self, comparison_df): + """ + For every line item in the ground truth data, find the most similar + line item in the predictions + Args: + comparison_df: a dataframe which is the cartesian product of the + line items inthe ground truth and the predictions datasets + Returns: + A dictionary of the best matches: {"match_method_name": best_matches_df} + """ + remaining_comparisons = comparison_df.copy() + best_matches_list = [] + best_matches_dict = {} + for i in range(comparison_df["gt_index"].nunique()): + curr_max_similarity = remaining_comparisons["similarity_score"].max() + if curr_max_similarity > 0: + max_similarity_index = remaining_comparisons[ + "similarity_score" + ].argmax() + best_match = remaining_comparisons.iloc[max_similarity_index] + best_match_gt = best_match["gt_index"] + best_match_pred = best_match["pred_index"] + best_matches_list.append( + {"gt_index": best_match_gt, "pred_index": best_match_pred} + ) + index_to_drop = remaining_comparisons[ + (remaining_comparisons["gt_index"] == best_match_gt) + | (remaining_comparisons["pred_index"] == best_match_pred) + ].index + remaining_comparisons.drop(index_to_drop, inplace=True) + else: + continue + best_matches_dict["exact_match"] = best_matches_list + return best_matches_dict diff --git a/src/invoice_processing/score_component/score/score.py b/src/invoice_processing/score_component/score/score.py new file mode 100644 index 00000000..66516958 --- /dev/null +++ b/src/invoice_processing/score_component/score/score.py @@ -0,0 +1,439 @@ +""" +This module runs the evaluation step of the experimentation framework. +First, the ground truth data and the predictions data are read. +Next, each line item in the ground truth data is compared against +line items from the prediction data to find the best match for +each line item in the ground truth. Finally, the accuracy is +calculated per field and a general accuracy is calculated +for all fields combined. +The score results are logged into AML. +""" + +import os +import argparse + +import mlflow +import ast +import pandas as pd +import numpy as np +from tqdm import tqdm +from typing import Dict +import logging + +from .extraction_evaluator import ExtractionEvaluator +from .utils import load_json_file + +log = logging.getLogger(__name__) + + +def get_score_config(score_config_str): + """ + Load score config from dict loaded as str. + Args: + components_config: Dictionary loaded as string with configuration + Returns: + score_config_dict: Dict with score configuration + """ + log.info(f"score_config from get_score_config in score.py: {score_config_str}") + score_config = ast.literal_eval(score_config_str) + # Parse the line_items dict into a list of line item fields to compare (same for the rest) + fuzzy_match_config = score_config["fuzzy_match_config"] + exact_match_fields = [k for k, v in score_config["exact_match_fields"].items() if v] + matchers_dict = score_config["matchers_dict"] + find_best_matches_strategy = score_config["find_best_matches_strategy"] + score_config_dict = { + "fuzzy_match_config": fuzzy_match_config, + "exact_match_fields": exact_match_fields, + "matchers_dict": matchers_dict, + "find_best_matches_strategy": find_best_matches_strategy, + } + return score_config_dict + + +def create_extraction_evaluator(components_config): + """ + Initialize evaluator object + Args: + components_config: Dictionary loaded as string with configuration + Returns: + evaluator: Performs data evaluation + """ + score_config_dict = get_score_config(components_config) + log.info(f"score config dict from: {score_config_dict}") + fuzzy_match_config = score_config_dict.get("fuzzy_match_config") + exact_match_fields = score_config_dict.get("exact_match_fields") + matchers_dict = score_config_dict.get("matchers_dict") + find_best_matches_strategy = score_config_dict.get("find_best_matches_strategy") + evaluator = ExtractionEvaluator( + fuzzy_match_config=fuzzy_match_config, + exact_match_fields=exact_match_fields, + matchers_dict=matchers_dict, + find_best_matches_strategy=find_best_matches_strategy, + ) + return evaluator + + +def get_gt_and_pred_data_for_evaluation(ground_truth, predictions): + """ + Parse current JSON input to DataFrames + Args: + ground_truth: Ground truth JSON object + predictions: Predictions JSON object + Returns: + gt_data: DataFrame of the line items of the ground truth data + pred_data: DataFrame of the line items of the predictions data + """ + # normalize ground truth and predictions structure + ground_truth_invoice = ground_truth["lineItems"] + + predicted_invoice = list(predictions.values())[0] + gt_data = pd.DataFrame.from_records(ground_truth_invoice) + pred_data = pd.DataFrame.from_records(predicted_invoice["lineItems"]).rename( + columns={"text": "description", "transactionType": "TransactionType"} + ) + # If one of the dataframes is empty, create a dataframe with empty strings instead + if gt_data.shape[0] == 0: + gt_data = pd.DataFrame( + { + "serviceStartDate": "", + "serviceEndDate": "", + "amount": "", + "description": "", + }, + index=[0], + ) + if pred_data.shape[0] == 0: + pred_data = pd.DataFrame( + { + "serviceStartDate": "", + "serviceEndDate": "", + "amount": "", + "description": "", + "miles": "", + }, + index=[0], + ) + pred_data.drop("miles", axis=1, inplace=True) + pred_data.replace(to_replace=["NA", "N/A"], value="", inplace=True) + return gt_data, pred_data + + +def get_corresponding_prediction_path( + gt_path: str, pred_path: str, all_pred_data: Dict +): + """ + Get the file path of the predictions that correspond to a given ground truth file. + Args: + gt_path (str): File path to the currently evaluated ground truth data. + pred_path (str): path to the predictions directory or file path + all_pred_data (Dict): The predictions parsed data. + key: value -> prediction_file_path: prediction_parsed_data + Returns: + corresponding_pred_path (str): The path of the corresponding predictions to + provided the ground truth file. + """ + corresponding_pred_path = "" + if os.path.isdir(pred_path): + temp_file_name = gt_path.split("/")[-1].split(".")[0] + file_name = f"{temp_file_name}_gpt-4o_result.json" + corresponding_pred_path = f"{pred_path}/{file_name}" + else: + corresponding_pred_path = pred_path + return corresponding_pred_path + + +def add_ref_ids_to_result_dfs( + best_matches_df: pd.DataFrame(), + unmatched_gt: pd.DataFrame(), + unmatched_pred: pd.DataFrame(), + curr_gt_ref_id: str, + pred_path: str, +): + """ + Add reference ids or predicted data path to the reported results dataframes. + Args: + best_matches_df: Dataframe with the line items that were matched. + unmatched_gt: Dataframe with line items from the ground truth that were not matched. + unmatched_pred: Dataframe with line items from the prediction that were not matched. + curr_gt_ref_id: string of the image reference id. + pred_path: string of the path of the current prediction + Returns: + Dataframes with reference ids or paths to predicted data. + """ + best_matches_df["gt_ref"] = curr_gt_ref_id + best_matches_df["matched_to"] = pred_path + if unmatched_gt.shape[0] > 0: + unmatched_gt["gt_ref"] = curr_gt_ref_id + unmatched_gt["matched_to"] = pred_path + else: + unmatched_gt["gt_ref"] = [] + unmatched_gt["matched_to"] = [] + if unmatched_pred.shape[0] > 0: + unmatched_pred["pred_path"] = pred_path + unmatched_pred["matched_to"] = curr_gt_ref_id + else: + unmatched_pred["pred_path"] = [] + unmatched_pred["matched_to"] = [] + return best_matches_df, unmatched_gt, unmatched_pred + + +def evaluate(all_invoices_pred, all_invoices_gt, components_config): + """ + Evaluates the quality of data extraction from images by comparing + the extracted data to ground truth. This function calculates the + accuracy, precision and recall to assess the correctness of the extraction. + Args: + predictions_file_path (str): Path of the predictions + file (the extracted data) + ground_truth_path (str): Path of the ground truth file + (The true field values based on the invoice image) + components_config: Dictionary loaded as string with configuration + """ + # Create evaluator + evaluator = create_extraction_evaluator(components_config) + # we want to know which ground truth files did not have predictions + missing_predictions_paths = [] + results_list = [] + comparison_dfs_list = [] + best_matches_dfs_list = [] + all_matches_dfs_list = [] + unmatched_gt_list = [] + unmatched_pred_list = [] + precisions_list = [] + recalls_list = [] + for raw_gt_invoice in tqdm(all_invoices_gt, desc="Invoices evaluated"): + gt_file_name = raw_gt_invoice["reference_id"] + log.debug(f"gt_file_name: {gt_file_name}") + curr_gt_ref_id = gt_file_name.split(".")[0] + pred_path_list = [ + x for x in list(all_invoices_pred.keys()) if curr_gt_ref_id in x + ] + if len(pred_path_list) == 0: + pred_path = "" + else: + pred_path = pred_path_list[0] + raw_pred_invoice = all_invoices_pred.get(pred_path) + # if there is no prediction for this ground truth invoice + if raw_pred_invoice is None: + missing_predictions_paths.append(gt_file_name) + continue + gt_invoice, pred_invoice = get_gt_and_pred_data_for_evaluation( + raw_gt_invoice, raw_pred_invoice + ) + # In the future, we might want to return the matches and best_matches for further analysis + comparison_df = evaluator.compare_line_item_values_per_invoice( + ground_truth_df=gt_invoice, predictions_df=pred_invoice + ) + comparison_df["gt_ref"] = curr_gt_ref_id + comparison_df["matched_to"] = pred_path + match_results_df, unmatched_gt, unmatched_pred, best_matches_df = ( + evaluator.get_match_results(comparison_df=comparison_df) + ) + results_df = evaluator.calculate_evaluation_metrics_per_field_in_invoice( + match_results_df=match_results_df + ) + best_matches_df, unmatched_gt, unmatched_pred = add_ref_ids_to_result_dfs( + best_matches_df, unmatched_gt, unmatched_pred, curr_gt_ref_id, pred_path + ) + precision_per_invoice = evaluator.calculate_precision_per_record( + unmatched_pred, best_matches_df + ) + recall_per_invoice = evaluator.calculate_recall_per_record( + unmatched_gt, best_matches_df + ) + results_list.append(results_df) + comparison_dfs_list.append(comparison_df) + best_matches_dfs_list.append(best_matches_df) + all_matches_dfs_list.append(match_results_df) + unmatched_gt_list.append(unmatched_gt) + unmatched_pred_list.append(unmatched_pred) + precisions_list.append(precision_per_invoice) + recalls_list.append(recall_per_invoice) + all_invoices_results = pd.concat(results_list) + all_unmatched_gt = pd.concat(unmatched_gt_list) + all_unmatched_pred = pd.concat(unmatched_pred_list) + overall_accuracy, final_results_df = evaluator.calculate_mean_accuracy_per_batch( + all_invoices_results + ) + comparison_df_all = pd.concat(comparison_dfs_list).sort_values( + by="similarity_score" + ) + best_matches_all = pd.concat(best_matches_dfs_list) + all_matches_results_total = pd.concat(all_matches_dfs_list) + gt_invoices_number = len(all_invoices_gt) + pred_invoices_number = len(all_invoices_pred) + overall_precision = round(np.mean(precisions_list), 3) + overall_recall = round(np.mean(recalls_list), 3) + return ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) + + +def log_results( + score_results_output_path: str, + final_results_df: pd.DataFrame, + all_unmatched_gt: pd.DataFrame, + all_unmatched_pred: pd.DataFrame, + overall_accuracy: float, + gt_invoices_number: int, + pred_invoices_number: int, + comparison_df_all: pd.DataFrame, + best_matches_all: pd.DataFrame, + all_matches_results: pd.DataFrame, + overall_precision: float, + overall_recall: float, +): + """ + Log score results to AML + """ + score_results_output_path = "score_results.csv" + all_unmatched_gt_path = "all_unmatched_gt.csv" + all_unmatched_pred_path = "all_unmatched_pred.csv" + comparison_df_path = "comparison_df.csv" + best_matches_path = "best_matches.csv" + all_matches_path = "all_match_results.csv" + final_results_df.to_csv(f"{score_results_output_path}", index=False) + mlflow.log_artifact(f"{score_results_output_path}") + all_unmatched_gt.to_csv(f"{all_unmatched_gt_path}", index=False) + mlflow.log_artifact(f"{all_unmatched_gt_path}") + all_unmatched_pred.to_csv(f"{all_unmatched_pred_path}", index=False) + mlflow.log_artifact(f"{all_unmatched_pred_path}") + comparison_df_all.to_csv(f"{comparison_df_path}", index=False) + mlflow.log_artifact(f"{comparison_df_path}") + best_matches_all.to_csv(f"{best_matches_path}", index=False) + mlflow.log_artifact(f"{best_matches_path}") + all_matches_results.to_csv(f"{all_matches_path}", index=False) + mlflow.log_artifact(f"{all_matches_path}") + results_dict = { + "overall_accuracy": overall_accuracy, + "number_of_ground_truth_invoices": gt_invoices_number, + "number_of_prediction_invoices": pred_invoices_number, + "number_of_ground_truth_invoices_with_partial_prediction": all_unmatched_gt[ + "gt_ref" + ].nunique(), + "overall_precision": overall_precision, + "overall_recall": overall_recall, + } + mlflow.log_metrics(results_dict) + + +def main( + predictions_path, + ground_truth_path, + score_results_path, + missing_refs_path, + all_unmatched_gt_path, + all_unmatched_pred_path, + components_config, +): + """Load ground truth and predictions data, call score function. + + Args: + predictions_path (string): path to predictions data + ground_truth_path (string): path to ground truth data + score_results_path (string): output path to which to write score results + missing_refs (string): path to which to write ground truth ref id if no + prediction for this ground truth item was found + all_unmatched_gt_path (string): output path to ground truth data without + a match in the predictions + all_unmatched_pred_path (string): output path to predictions which + were not matched to any ground truth + components_config: score config from experiment config + """ + lines = [ + f"predictions_file_path: {predictions_path}", + f"ground_truth_path: {ground_truth_path}", + f"score_results_path: {score_results_path}", + f"all_unmatched_gt_path: {all_unmatched_gt_path}", + f"all_unmatched_pred_path: {all_unmatched_pred_path}", + ] + + for line in lines: + log.info(line) + + all_invoices_gt = load_json_file(ground_truth_path) + all_invoices_pred = load_json_file(predictions_path) + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df, + best_matches, + all_matches_results, + overall_precision, + overall_recall, + ) = evaluate(all_invoices_pred, all_invoices_gt, components_config) + + log_results( + score_results_path, + final_results_df, + all_unmatched_gt, + all_unmatched_pred, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + comparison_df, + best_matches, + all_matches_results, + overall_precision, + overall_recall, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("evaluate") + parser.add_argument("--predictions", type=str, help="Path of predictions") + parser.add_argument("--ground_truth", type=str, help="Path of ground truth") + parser.add_argument( + "--missing_refs", + type=str, + help="Output path of ground truth file names without predictions", + ) + parser.add_argument( + "--score_report", type=str, help="Output path of evaluation results" + ) + parser.add_argument( + "--all_unmatched_gt_path", + type=str, + help="Output path of unmatched ground truth line items for all invoices", + ) + parser.add_argument( + "--all_unmatched_pred_path", + type=str, + help="Output path of unmatched predictions line items for all invoices", + ) + parser.add_argument("--score_config", type=str, help="Config dictionary") + + args = parser.parse_args() + + log.debug("Scoring started... arguments parsed successfully.") + + predictions_file_path = args.predictions + ground_truth_path = args.ground_truth + score_report = args.score_report + missing_refs = args.missing_refs + all_unmatched_gt_path = args.all_unmatched_gt_path + all_unmatched_pred_path = args.all_unmatched_pred_path + score_config = args.score_config + main( + predictions_file_path, + ground_truth_path, + score_report, + missing_refs, + all_unmatched_gt_path, + all_unmatched_pred_path, + score_config, + ) diff --git a/src/invoice_processing/score_component/score/utils.py b/src/invoice_processing/score_component/score/utils.py new file mode 100644 index 00000000..fd760e02 --- /dev/null +++ b/src/invoice_processing/score_component/score/utils.py @@ -0,0 +1,114 @@ +""" +utils.py + +This module contains various utility functions that can be used across +different parts of the project. + +Functions: + read_json_file(file_path): + Reads a JSON file and returns the parsed data. + normalize_string(value): + Normalize string by stripping extra whitespace and converting to lowercase. + load_csv_file(file_path): + Reads a CSV file and returns the parsed data. +""" + +import json +from pathlib import Path +from typing import Union, List +import re +import logging +from dateutil.parser import parse + +log = logging.getLogger(__name__) + + +def load_json_file(path: Union[str, Path]): + """ + Reads a JSON file and returns the parsed data. + Args: + file_path (str): The path to the JSON file to be read. + Returns: + dict: The parsed JSON data as a dictionary. + Raises: + FileNotFoundError + """ + # Load ground truth data + all_data = [] + all_data_dict = {} + data_path = Path(path) + if data_path.is_dir(): + # Multiple files in a directory + for file_path in data_path.glob("*.json"): + log.debug(f"file_path: {file_path}") + try: + with open(file_path, "r", encoding="utf-8") as f: + curr_data = json.load(f) + # For ground truth format + if isinstance(curr_data, List): + all_data = all_data + curr_data + # For predictions data format + else: + all_data_dict[str(file_path)] = curr_data + except FileNotFoundError: + log.error(f"Error: The file at {file_path} was not found") + except json.JSONDecodeError: + log.error(f"Error: The file at {file_path} is not a valid JSON") + else: + # Single file + try: + with open(path, "r", encoding="utf-8") as f: + curr_data = json.load(f) + # For ground truth format + if isinstance(curr_data, List): + all_data = all_data + curr_data + # For predictions data format + else: + all_data_dict[str(file_path)] = curr_data + except FileNotFoundError: + log.error(f"Error: The file at {file_path} was not found") + except json.JSONDecodeError: + log.error(f"Error: The file at {file_path} is not a valid JSON") + if len(all_data) > 1: + return all_data + else: + return all_data_dict + + +def normalize_string(value: str) -> str: + """ + Normalize string by stripping extra whitespace and converting to lowercase. + """ + if not isinstance(value, str): + return str(value) + value = re.sub(r"\s+", " ", value).strip().lower() + value = re.sub(r"\(\s*", "(", value) + value = re.sub(r"\s*\)", ")", value) + value = re.sub(r"day\s*\(s\)", "days", value) + return value + + +def preprocess_amount(amount): + """ + Amount pre-processing - remove parentheses and + white spaces from amount string + """ + parsed_amount = "" + if isinstance(amount, str): + parsed_amount = amount.strip() + parsed_amount = parsed_amount.replace("(", "").replace(")", "") + if len(parsed_amount) == 0: + parsed_amount = "0" + else: + parsed_amount = amount + return parsed_amount + + +def preprocess_date(date_str): + """ + Date preprocessing - remove whitespaces and parse + date string into date object + """ + date_str = date_str.strip() + date = parse(date_str) + return date diff --git a/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json b/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json new file mode 100644 index 00000000..f398fd5c --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json @@ -0,0 +1,10 @@ +{ + "gpt_only": { + "azure_openai_endpoint": "http://example", + "azure_openai_api_key": "example" + }, + "mock_ocr_extractor": { + "encoding": "utf-16", + "logging_enabled": false + } +} diff --git a/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py b/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py new file mode 100644 index 00000000..98fed720 --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py @@ -0,0 +1,29 @@ +from src.invoice_processing.predict_component.predict.data_extraction.extractors.base_extractor import ( + Extractor +) +from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( + ExtractionResponse, + Invoice, + Provider, + ServiceFor +) + + +class MockExtractor(Extractor): + + def extract_data(self, file) -> ExtractionResponse: + invoice = Invoice( + lineItems=[], + provider=Provider( + name="Mock Provider" + ), + serviceFor=ServiceFor( + name="Mock Patient" + ), + totalClaimAmount=1.99 + ) + + return ExtractionResponse( + invoice=invoice, + metadata={} + ) diff --git a/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py b/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py new file mode 100644 index 00000000..4d831bb9 --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py @@ -0,0 +1,202 @@ +import unittest +from unittest.mock import patch +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChoice, ParsedChatCompletionMessage +from openai.types.completion_usage import CompletionUsage + +from src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor import ( + GPTOnlyExtractor +) +from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( + ExtractionResponse, + Invoice, + LineItem, + Provider, + ServiceFor +) + + +class TestGPTOnlyExtractor(unittest.TestCase): + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.GPTOnlyExtractor.create_prompt') + def test_extract_data(self, mock_create_prompt, mock_logger_proxy, mock_azure_open_ai): + mock_create_prompt.return_value = [ + { + "role": "system", + "content": "You are an AI assistant" + }, + {"role": "user", "content": "Hi."} + ] + + extraction_response = ExtractionResponse( + invoice=Invoice( + totalClaimAmount=0.0, + provider=Provider( + name="" + ), + serviceFor=ServiceFor( + name="" + ), + lineItems=[ + LineItem( + amount=0.0, + text="", + transactionType="", + serviceStartDate="", + serviceEndDate="" + ) + ] + ) + ) + mock_completion = ParsedChatCompletion( + id="id", + created=0, + model="model", + object="chat.completion", + choices=[ + ParsedChoice( + message=ParsedChatCompletionMessage( + parsed=extraction_response, + role="assistant" + ), + index=0, + finish_reason="stop" + ) + ], + usage=CompletionUsage( + completion_tokens=100, + prompt_tokens=101, + total_tokens=201 + ) + ) + mock_azure_open_ai_instance = mock_azure_open_ai.return_value + mock_azure_open_ai_instance.beta.chat.completions.parse.return_value = mock_completion + + gpt_only_extractor = GPTOnlyExtractor({ + "azure_openai_endpoint": "https://example.com", + "azure_openai_api_key": "SSSHHH", + "gpt_deployment_name": 'gpt-4o', + "temperature": 0, + "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} + }, mock_logger_proxy) + result = gpt_only_extractor.extract_data("BASE64_STRING") + + self.assertEqual(result, extraction_response) + mock_azure_open_ai_instance.beta.chat.completions.parse.assert_called_with( + model='gpt-4o', + temperature=0, + messages=mock_create_prompt.return_value, + response_format=ExtractionResponse + ) + mock_logger_proxy.log_metric.assert_any_call("completion_tokens", 100) + mock_logger_proxy.log_metric.assert_any_call("prompt_tokens", 101) + + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.PromptManager.get_prompt') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') + def test_create_prompt(self, mock_logger_proxy, mock_azure_open_ai, mock_get_prompt): + mock_get_prompt.return_value = "Extract data from this invoice" + + gpt_only_extractor = GPTOnlyExtractor({ + "azure_openai_endpoint": "https://example.com", + "azure_openai_api_key": "SSSHHH", + "gpt_deployment_name": 'gpt-4o', + "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} + }, mock_logger_proxy) + + base64_image = "base64_image_string" + messages = gpt_only_extractor.create_prompt(base64_image) + + self.assertEqual(messages, [ + { + "role": "system", + "content": + "You are an AI assistant that analyzes the text provided " + "and supplemented images and returns them as structured JSON objects. " + "Do not return as a code block." + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": mock_get_prompt.return_value + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"} + } + ] + } + ]) + + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') + @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.GPTOnlyExtractor.create_prompt') + def test_extract_data_with_retry(self, mock_create_prompt, mock_logger_proxy, mock_azure_open_ai): + mock_create_prompt.return_value = [ + { + "role": "system", + "content": "You are an AI assistant" + }, + {"role": "user", "content": "Hi."} + ] + + mock_azure_open_ai_instance = mock_azure_open_ai.return_value + mock_azure_open_ai_instance.beta.chat.completions.parse.side_effect = [Exception("Error"), Exception("Error"), ParsedChatCompletion( + id="id", + created=0, + model="model", + object="chat.completion", + choices=[ + ParsedChoice( + message=ParsedChatCompletionMessage( + parsed=ExtractionResponse( + invoice=Invoice( + totalClaimAmount=0.0, + provider=Provider( + name="" + ), + serviceFor=ServiceFor( + name="" + ), + lineItems=[ + LineItem( + amount=0.0, + text="", + transactionType="", + serviceStartDate="", + serviceEndDate="" + ) + ] + ) + ), + role="assistant" + ), + index=0, + finish_reason="stop" + ) + ], + usage=CompletionUsage( + completion_tokens=100, + prompt_tokens=101, + total_tokens=201 + ) + )] + + gpt_only_extractor = GPTOnlyExtractor({ + "azure_openai_endpoint": "https://example.com", + "azure_openai_api_key": "SSSHHH", + "gpt_deployment_name": 'gpt-4o', + "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} + }, mock_logger_proxy) + result = gpt_only_extractor.extract_data("BASE64_STRING") + + self.assertIsNotNone(result) + self.assertIsInstance(result, ExtractionResponse) + self.assertEqual(mock_azure_open_ai_instance.beta.chat.completions.parse.call_count, 3) + mock_logger_proxy.log_metric.assert_any_call("completion_tokens", 100) + mock_logger_proxy.log_metric.assert_any_call("prompt_tokens", 101) + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py b/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py new file mode 100644 index 00000000..8e4d2b5b --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py @@ -0,0 +1,54 @@ +import unittest +from pydantic import ValidationError + +from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( + ExtractionResponse +) + + +class TestExtractionResponse(unittest.TestCase): + def setUp(self): + self.valid_line_item = { + "amount": 100.0, + "text": "Consultation", + "transactionType": "Service", + "serviceStartDate": "2023-01-01", + "serviceEndDate": "2023-01-02" + } + self.valid_invoice = { + "totalClaimAmount": 100.0, + "provider": { + "name": "Provider A" + }, + "serviceFor": { + "name": "Patient A" + }, + "lineItems": [self.valid_line_item] + } + self.valid_invoice_data = { + "invoice": self.valid_invoice + } + + def test_valid_invoice_data(self): + invoice_data = ExtractionResponse(**self.valid_invoice_data) + self.assertEqual(invoice_data.invoice.totalClaimAmount, 100.0) + self.assertEqual(invoice_data.invoice.provider.name, "Provider A") + self.assertEqual(invoice_data.invoice.serviceFor.name, "Patient A") + self.assertEqual(len(invoice_data.invoice.lineItems), 1) + self.assertEqual(invoice_data.invoice.lineItems[0].amount, 100.0) + + def test_invalid_invoice_data(self): + invalid_invoice_data = self.valid_invoice_data.copy() + invalid_invoice_data["invoice"]["totalClaimAmount"] = "invalid_amount" + with self.assertRaises(ValidationError): + ExtractionResponse(**invalid_invoice_data) + + def test_missing_required_field(self): + invalid_invoice_data = self.valid_invoice_data.copy() + del invalid_invoice_data["invoice"]["provider"]["name"] + with self.assertRaises(ValidationError): + ExtractionResponse(**invalid_invoice_data) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py b/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py new file mode 100644 index 00000000..bd0414ea --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py @@ -0,0 +1,45 @@ +import unittest +from unittest.mock import patch, mock_open + +from src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager import PromptManager + + +class TestPromptManager(unittest.TestCase): + @patch('builtins.open', new_callable=mock_open, + read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') + @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') + def test_get_prompt(self, mock_get_source, mock_file): + mock_get_source.return_value = ('template content', 'template/path', lambda: True) + result = PromptManager.get_prompt('test_template', name='World') + self.assertEqual(result, 'Hello, World!') + + @patch('builtins.open', new_callable=mock_open, + read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') + @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') + def test_get_prompt_template_error(self, mock_get_source, mock_file): + mock_get_source.return_value = ('template content', 'template/path', lambda: True) + with self.assertRaises(ValueError) as context: + PromptManager.get_prompt('test_template') + self.assertIn('Error rendering template', str(context.exception)) + + @patch('builtins.open', new_callable=mock_open, + read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') + @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') + def test_get_template_info(self, mock_get_source, mock_file): + mock_get_source.return_value = ('template content', 'template/path', lambda: True) + result = PromptManager.get_template_info('test_template') + expected_result = { + 'name': 'test_template', + 'description': 'Test template', + 'author': 'Test Author', + 'variables': ['name'], + 'frontmatter': { + 'description': 'Test template', + 'author': 'Test Author' + } + } + self.assertEqual(result, expected_result) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py b/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py new file mode 100644 index 00000000..fa6b24ae --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import patch, mock_open + +from src.invoice_processing.predict_component.predict.data_extraction.config.configuration_container import ( + ConfigurationContainer +) + + +class TestConfigurationContainer(unittest.TestCase): + + def setUp(self): + # Clear the config registry before each test + ConfigurationContainer._config_registry = {} + + def test_register_and_get_config(self): + config = {"key": "value"} + ConfigurationContainer.register_config("extractor1", config) + retrieved_config = ConfigurationContainer.get_config("extractor1") + self.assertEqual(retrieved_config, config) + + def test_get_config_not_registered(self): + retrieved_config = ConfigurationContainer.get_config("non_existent_extractor") + self.assertEqual(retrieved_config, {}) + + @patch("builtins.open", new_callable=mock_open, read_data='{"extractor2": {"key": "value"}}') + @patch("json.load", return_value={"extractor2": {"key": "value"}}) + def test_load_configs_from_file(self, mock_json_load, mock_open): + ConfigurationContainer.load_configs_from_file("dummy_path") + self.assertEqual(ConfigurationContainer._config_registry["extractor2"], {"key": "value"}) diff --git a/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py b/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py new file mode 100644 index 00000000..6d1056e7 --- /dev/null +++ b/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py @@ -0,0 +1,77 @@ +import os +import unittest +from unittest.mock import MagicMock + +from src.invoice_processing.predict_component.predict.data_extraction.config.configuration_container import ( + ConfigurationContainer +) +from src.invoice_processing.predict_component.predict.data_extraction.data_extractor_factory import ( + DataExtractorFactory +) +from src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor import ( + GPTOnlyExtractor +) +from test.invoice_processing.predict_component.predict.data_extraction.assets.mock_extractor import MockExtractor + + +class TestDataExtractorFactory(unittest.TestCase): + def setUp(self): + self.assets_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") + config_path = os.path.join(self.assets_path, "config.json") + ConfigurationContainer.load_configs_from_file(config_path) + self.additional_config = { + "prompt_config": { + "prompt_name": "medical_claim_reimbursement", + "line_item_instructions": "complex" + } + } + self.logger = MagicMock() + DataExtractorFactory.register("mockextractor", "mock", MockExtractor) + + def test_load_default_extractors(self): + DataExtractorFactory.load_default_extractors() + extractor = DataExtractorFactory.create("invoice", + "gpt_only", + self.additional_config, + self.logger) + + self.assertIsInstance(extractor, GPTOnlyExtractor) + + def test_create_extractor(self): + extractor = DataExtractorFactory.create("mock", + "mockextractor", + self.additional_config, + self.logger) + + self.assertIsInstance(extractor, MockExtractor) + + resp = extractor.extract_data("") + self.assertEqual(resp.invoice.provider.name, "Mock Provider") + self.assertEqual(resp.invoice.serviceFor.name, "Mock Patient") + + def test_create_extractor_invalid_category(self): + with self.assertRaises(ValueError): + DataExtractorFactory.create("invalid", "mock_ocr_extractor", + self.additional_config, + self.logger) + + def test_create_extractor_invalid_name(self): + with self.assertRaises(ValueError): + DataExtractorFactory.create("mock", "invalid_extractor", + self.additional_config, + self.logger) + + def test_list_categories(self): + categories = DataExtractorFactory.list_categories() + self.assertIn("mock", categories) + + def test_list_extractors(self): + extractors = DataExtractorFactory.list_extractors("mock") + self.assertIn("mockextractor", extractors) + + def test_register_invalid_extractor(self): + with self.assertRaises(ValueError): + DataExtractorFactory.register("invalid_extractor", "mock", object) + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/predict_component/predict/test_helpers.py b/test/invoice_processing/predict_component/predict/test_helpers.py new file mode 100644 index 00000000..3e2571d1 --- /dev/null +++ b/test/invoice_processing/predict_component/predict/test_helpers.py @@ -0,0 +1,33 @@ +import base64 +import unittest +from unittest.mock import patch, mock_open +import json + +from src.invoice_processing.predict_component.predict.helpers import save_output_as_json, convert_image_to_base64 + + +class TestPredictOrchestratorHelpers(unittest.TestCase): + @patch("builtins.open", new_callable=mock_open) + def test_save_output_as_json(self, mock_file): + output = {"key": "value"} + output_file_path = "test_output.json" + + save_output_as_json(output, output_file_path) + + mock_file.assert_called_once_with(output_file_path, 'w', encoding='utf-8') + handle = mock_file() + written_content = ''.join(call.args[0] for call in handle.write.call_args_list) + self.assertEqual(written_content, json.dumps(output, ensure_ascii=False, indent=4)) + + @patch("builtins.open", new_callable=mock_open, read_data=b"fake_image_data") + def test_convert_image_to_base64(self, mock_file): + image_path = "test_image.png" + + result = convert_image_to_base64(image_path) + + mock_file.assert_called_once_with(image_path, "rb") + self.assertEqual(result, base64.b64encode(b"fake_image_data").decode('utf-8')) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/predict_component/predict/test_predict.py b/test/invoice_processing/predict_component/predict/test_predict.py new file mode 100644 index 00000000..9298ead5 --- /dev/null +++ b/test/invoice_processing/predict_component/predict/test_predict.py @@ -0,0 +1,107 @@ +import os +import unittest +from unittest.mock import ANY, patch, MagicMock +from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( + ExtractionResponse, + Invoice, + LineItem, + Provider, + ServiceFor +) + +from src.invoice_processing.predict_component.predict.predict import predict, main, process + + +class TestPredictFunctions(unittest.TestCase): + @patch('src.invoice_processing.predict_component.predict.data_extraction.data_extractor_factory.DataExtractorFactory.create') + @patch('os.makedirs') + @patch('src.invoice_processing.predict_component.predict.predict.glob_by_extesion') + @patch('src.invoice_processing.predict_component.predict.predict.MLFlowLogger') + @patch('src.invoice_processing.predict_component.predict.predict.mlflow') + @patch('src.invoice_processing.predict_component.predict.predict.process') + @patch('pandas.DataFrame.to_csv') + def test_predict(self, mock_to_csv, mock_process, mock_mlflow, mock_logger, mock_glob, mock_makedirs, mock_factory_create): + mock_extractor = MagicMock() + mock_factory_create.return_value = mock_extractor + mock_process.return_value = ExtractionResponse( + invoice=Invoice( + totalClaimAmount=0.0, + provider=Provider( + name="" + ), + serviceFor=ServiceFor( + name="" + ), + lineItems=[ + LineItem( + amount=0.0, + text="", + transactionType="", + serviceStartDate="", + serviceEndDate="" + ) + ] + ) + ) + mock_glob.return_value = ['file1.png', 'file2.jpg'] + azure_openai_endpoint = "https://example.com" + azure_openai_api_key = "test_api_key" + + predict('gpt_only', 0, 'gpt-4o', azure_openai_endpoint, azure_openai_api_key, + "{'prompt_name':'medical_claim_reimbursement','line_item_instructions':'complex'}", + 'test_data', 'prediction_path') + + mock_mlflow.log_params.assert_any_call({ + "gpt_deployment_name": "gpt-4o", + "temperature": 0, + "prompt_name": "medical_claim_reimbursement", + "line_item_instructions": "complex" + }) + + mock_factory_create.assert_called_once_with('invoice', 'gpt_only', { + "azure_openai_endpoint": azure_openai_endpoint, + "azure_openai_api_key": azure_openai_api_key, + "gpt_deployment_name": 'gpt-4o', + "temperature": 0, + "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} + }, ANY) + mock_makedirs.assert_called_once_with('prediction_path', exist_ok=True) + self.assertEqual(mock_process.call_count, 2) + + @patch('src.invoice_processing.predict_component.predict.predict.convert_image_to_base64') + @patch('src.invoice_processing.predict_component.predict.predict.save_output_as_json') + @patch('src.invoice_processing.predict_component.predict.predict.Extractor') + def test_process(self, mock_extractor, mock_save_output_as_json, mock_convert_image_to_base64): + mock_convert_image_to_base64.return_value = "IMAGINE_I_AM_BASE64" + mock_extractor.extract_data.return_value = ExtractionResponse( + invoice=Invoice( + provider=Provider( + name="Bob" + ), + serviceFor=ServiceFor( + name="Greg" + ), + lineItems=[], + totalClaimAmount=0.99 + ) + ) + input_path = 'file1.png' + output_path = 'output_path' + process(mock_extractor, input_path, output_path) + mock_extractor.extract_data.assert_called_once_with("IMAGINE_I_AM_BASE64") + output_file_path = os.path.join(output_path, "file1_result.json") + mock_save_output_as_json.assert_called_once_with(mock_extractor.extract_data.return_value.model_dump(), output_file_path) + + @patch('src.invoice_processing.predict_component.predict.predict.predict') + def test_main(self, mock_predict): + main('gpt_only',0 , 'gpt-4o', "https://example.com", "test_api_key", + "{'prompt_name':'claim_reimbursement','line_item_instructions':'complex'}", + 'test_data', 'prediction_path') + + mock_predict.assert_called_once_with('gpt_only', 0,'gpt-4o', "https://example.com", "test_api_key", + "{'prompt_name':'claim_reimbursement','line_item_instructions':'complex'}", + 'test_data', 'prediction_path') + + +if __name__ == '__main__': + unittest.main() diff --git a/test/invoice_processing/score_component/test_experiment_config.yaml b/test/invoice_processing/score_component/test_experiment_config.yaml new file mode 100644 index 00000000..c6f28d10 --- /dev/null +++ b/test/invoice_processing/score_component/test_experiment_config.yaml @@ -0,0 +1,15 @@ +score_config: + fuzzy_match_config: + field_match_threshold: 0.0 + fuzzy_compare_methods: + levenshtein: true + exact_match_fields: + start_date_match: true + end_date_match: true + amount_match: true + find_best_matches_strategy: levenshtein + matchers_dict: + serviceStartDate: date_exact_match + serviceEndDate: date_exact_match + amount: amount_exact_match + description: description_levenshtein diff --git a/test/invoice_processing/score_component/test_extraction_evaluator.py b/test/invoice_processing/score_component/test_extraction_evaluator.py new file mode 100644 index 00000000..6afbf9b8 --- /dev/null +++ b/test/invoice_processing/score_component/test_extraction_evaluator.py @@ -0,0 +1,119 @@ +""" +Unit tests for functions of the ExtractionEvaluator class in the experimentation framework +""" + +import unittest +import yaml +import pandas as pd +from src.invoice_processing.score_component.score.score import ( + create_extraction_evaluator, + get_score_config, +) +from src.invoice_processing.score_component.score.matchers.levenshtein_matcher import ( + LevenshteinMatcher, +) +from src.invoice_processing.score_component.score.matchers.text_exact_matcher import ( + TextExactMatcher, +) + + +class TestExtractionEvaluator(unittest.TestCase): + """ + Test extraction_evaluator.py + """ + + def __init__(self, methodName="runTest"): + super().__init__(methodName) + self.score_config = str( + yaml.safe_load( + open( + "test/invoice_processing/score_component/test_experiment_config.yaml" + ) + )["score_config"] + ) + + def setup_datasets(self): + """ + Setup ground truth data and a corresponding predictions data, + whose fields are a perfect match. + Returns: + ground_truth_df: ground truth dataframe + pred_df: predictions dataframe + """ + + ground_truth_df = pd.DataFrame( + { + "gt_index": [0, 1, 2, 3], + "serviceStartDate": ["1/3/24", "12/29/23", "1/4/24", "1/10/24"], + "serviceEndDate": ["1/9/24", "12/30/23", "1/4/24", "1/12/24"], + "amount": [134, 324, 78, 200], + "description": [ + "Child care service", + "After school program", + "Learning center", + "Swimming lessons", + ], + } + ) + pred_df = pd.DataFrame( + { + "pred_index": [0, 1, 2, 3], + "serviceStartDate": ["12/29/23", "1/1/24", "1/4/24", ""], + "serviceEndDate": ["12/30/23", "1/9/24", "1/4/24", "3/1/24"], + "amount": [324, 134, 76, 100], + "description": [ + "After school program", + "Child care", + "Learning center", + "Summer camp", + ], + "miles": [None, None, None, None], + } + ) + return ground_truth_df, pred_df + + def test_find_best_matches_levenshtein(self): + """ + Test find_best_matches function that is meant to find the best + matches from the predictions data to the ground truth data. + """ + score_config_dict = get_score_config(self.score_config) + fuzzy_match_config = score_config_dict["fuzzy_match_config"] + ground_truth_df, pred_df = self.setup_datasets() + evaluator = create_extraction_evaluator(self.score_config) + comparison_df = evaluator.compare_line_item_values_per_invoice( + ground_truth_df, pred_df + ) + best_matches_dict = LevenshteinMatcher().find_best_matches( + comparison_df, fuzzy_match_config + ) + best_matches_df = pd.DataFrame(best_matches_dict["levenshtein"]) + matches_indices = list( + zip(best_matches_df["gt_index"], best_matches_df["pred_index"]) + ) + self.assertTrue((0.0, 1.0) in matches_indices) + self.assertTrue((1.0, 0.0) in matches_indices) + self.assertTrue((2.0, 2.0) in matches_indices) + + def test_find_best_matches_base_exact_matcher(self): + """ + Test find_best_matches function that is meant to find the best + matches from the predictions data to the ground truth data. + """ + ground_truth_df, pred_df = self.setup_datasets() + evaluator = create_extraction_evaluator(self.score_config) + comparison_df = evaluator.compare_line_item_values_per_invoice( + ground_truth_df, pred_df + ) + best_matches_dict = TextExactMatcher().find_best_matches(comparison_df) + best_matches_df = pd.DataFrame(best_matches_dict["exact_match"]) + matches_indices = list( + zip(best_matches_df["gt_index"], best_matches_df["pred_index"]) + ) + self.assertTrue((0.0, 1.0) in matches_indices) + self.assertTrue((1.0, 0.0) in matches_indices) + self.assertTrue((2.0, 2.0) in matches_indices) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/invoice_processing/score_component/test_score.py b/test/invoice_processing/score_component/test_score.py new file mode 100644 index 00000000..19290751 --- /dev/null +++ b/test/invoice_processing/score_component/test_score.py @@ -0,0 +1,457 @@ +""" +Unit tests for the evaluation step in the experimentation framework +""" + +import unittest +import yaml +import pandas as pd +from src.invoice_processing.score_component.score.score import ( + evaluate, + get_gt_and_pred_data_for_evaluation, +) + + +class TestScore(unittest.TestCase): + """ + Test evaluate output for perfect match + """ + + def __init__(self, methodName="runTest"): + super().__init__(methodName) + self.score_config = str( + yaml.safe_load( + open( + "test/invoice_processing/score_component/test_experiment_config.yaml" + ) + )["score_config"] + ) + + def test_evaluate_perfect_match(self): + + ground_truth = [ + { + "reference_id": "12345", + "lineItems": [ + { + "serviceStartDate": "12/29/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "After school program", + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + }, + ], + } + ] + + pred = { + "12345.jpg": { + "invoice": { + "lineItems": [ + { + "serviceStartDate": "12/29/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "text": "Child care", + "miles": None, + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "text": "After school program", + "miles": None, + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "text": "Learning center", + "miles": None, + }, + ], + } + } + } + + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) = evaluate(pred, ground_truth, self.score_config) + self.assertEqual(overall_accuracy, 1.0) + self.assertEqual(overall_precision, 1.0) + self.assertEqual(overall_recall, 1.0) + + def test_evaluate_partial_match(self): + """ + Test evaluate output for partial match + """ + ground_truth = [ + { + "reference_id": "12345", + "lineItems": [ + { + "serviceStartDate": "12/26/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "Tuition", + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + }, + ], + } + ] + + pred = { + "12345.jpg": { + "invoice": { + "lineItems": [ + { + "serviceStartDate": "12/29/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "text": "Child care", + "miles": None, + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "text": "After school program", + "miles": None, + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "text": "Learning center", + "miles": None, + }, + { + "serviceStartDate": "12/30/23", + "serviceEndDate": "12/31/23", + "amount": 267, + "text": "Emergency room", + "miles": None, + }, + ], + } + } + } + + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) = evaluate(pred, ground_truth, self.score_config) + self.assertEqual(round(overall_accuracy, 3), 0.634) + + def test_evaluate_no_match(self): + """ + Test evaluate output for no match + """ + ground_truth = [ + { + "reference_id": "12345", + "lineItems": [ + { + "serviceStartDate": "12/26/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "Tuition", + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + }, + ], + } + ] + + pred = { + "12345.jpg": { + "invoice": { + "lineItems": [ + { + "serviceStartDate": "", + "serviceEndDate": "12/31/23", + "amount": 267, + "text": "Emergency room", + "miles": None, + }, + { + "serviceStartDate": "2/1/24", + "serviceEndDate": "", + "amount": 152, + "text": "After school program", + "miles": None, + }, + { + "serviceStartDate": "", + "serviceEndDate": "5/1/24", + "amount": 74, + "text": "Medicines", + "miles": None, + }, + ] + } + } + } + + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) = evaluate(pred, ground_truth, self.score_config) + self.assertEqual(round(overall_accuracy, 3), 0.07) + self.assertEqual(overall_precision, 1.0) + self.assertEqual(overall_recall, 1.0) + + def test_evaluate_partial_match_for_recall(self): + """ + Test evaluate output for partial match, for recall. + """ + ground_truth = [ + { + "reference_id": "12345", + "lineItems": [ + { + "serviceStartDate": "12/26/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "Tuition", + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + }, + { + "serviceStartDate": "4/5/24", + "serviceEndDate": "4/7/24", + "amount": 94, + "description": "Lunch fee", + }, + ], + } + ] + + pred = { + "12345.jpg": { + "InvoiceDetails": { + "lineItems": [ + { + "serviceStartDate": "12/29/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + "miles": None, + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "After school program", + "miles": None, + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + "miles": None, + }, + ], + } + } + } + + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) = evaluate(pred, ground_truth, self.score_config) + self.assertEqual(round(overall_recall, 3), 0.75) + + def test_evaluate_partial_match_for_precision(self): + """ + Test evaluate output for partial match, for precision. + """ + ground_truth = [ + { + "reference_id": "12345", + "lineItems": [ + { + "serviceStartDate": "12/26/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "Tuition", + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + }, + ], + } + ] + + pred = { + "12345.jpg": { + "InvoiceDetails": { + "lineItems": [ + { + "serviceStartDate": "12/29/23", + "serviceEndDate": "12/30/23", + "amount": 324, + "description": "Child care", + "miles": None, + }, + { + "serviceStartDate": "1/1/24", + "serviceEndDate": "3/1/24", + "amount": 134, + "description": "After school program", + "miles": None, + }, + { + "serviceStartDate": "4/1/24", + "serviceEndDate": "4/1/24", + "amount": 76, + "description": "Learning center", + "miles": None, + }, + { + "serviceStartDate": "9/3/24", + "serviceEndDate": "9/3/24", + "amount": 100, + "description": "Registration fee", + "miles": None, + }, + ], + } + } + } + + ( + final_results_df, + overall_accuracy, + gt_invoices_number, + pred_invoices_number, + all_unmatched_gt, + all_unmatched_pred, + comparison_df_all, + best_matches_all, + all_matches_results_total, + overall_precision, + overall_recall, + ) = evaluate(pred, ground_truth, self.score_config) + self.assertEqual(round(overall_precision, 3), 0.75) + + def test_get_gt_and_pred_data_for_evaluation(self): + ground_truth = { + "reference_id": "12345.jpg", + "lineItems": [ + { + "description": "Dependent care", + "amount": 120, + "serviceStartDate": "07/07/2021", + "serviceEndDate": "07/23/2021", + }, + { + "description": "Lunch fee", + "amount": 64, + "serviceStartDate": "07/07/2021", + "serviceEndDate": "07/23/2021", + }, + ], + } + + predictions = { + "invoice": { + "lineItems": [], + } + } + gt_data, pred_data = get_gt_and_pred_data_for_evaluation( + ground_truth, predictions + ) + self.assertTrue("miles" not in pred_data.columns.tolist()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/invoice_processing/score_component/test_utils.py b/test/invoice_processing/score_component/test_utils.py new file mode 100644 index 00000000..ed5f1f64 --- /dev/null +++ b/test/invoice_processing/score_component/test_utils.py @@ -0,0 +1,22 @@ +""" +Unit tests for utils.py of the evaluation step of the experimentation framework. +""" + +import unittest +import pandas as pd +from src.invoice_processing.score_component.score.utils import normalize_string + + +class TestScore(unittest.TestCase): + + def test_normalize_str(self): + """ + Test normalize_str. + """ + value = " ( vaLue ) " + normalized_str = normalize_string(value) + self.assertEqual(normalized_str, "(value)") + + +if __name__ == "__main__": + unittest.main() From e85975a15722dc7b7e29b22b7d8f45f6accb4510 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 12:47:06 +0100 Subject: [PATCH 02/21] add ci pipeline --- .../invoice_processing_ci_pipeline.yml | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .github/workflows/invoice_processing_ci_pipeline.yml diff --git a/.github/workflows/invoice_processing_ci_pipeline.yml b/.github/workflows/invoice_processing_ci_pipeline.yml new file mode 100644 index 00000000..85e88439 --- /dev/null +++ b/.github/workflows/invoice_processing_ci_pipeline.yml @@ -0,0 +1,34 @@ +name: Invoice Processing CI Workflow + +on: + pull_request: + branches: + - main + - develop + paths-ignore: + - 'docs/**' + - '**.md' + + workflow_dispatch: + inputs: + exec_environment: + type: string + description: "The environment to run the workflow in" + required: true + default: "pr" + model_type: + type: string + description: "The type of model to run the workflow for" + required: true + default: "invoice_processing" +permissions: + id-token: write + contents: read + +jobs: + run-ci-workflow: + uses: ./.github/workflows/platform_ci_workflow.yml + with: + exec_environment: ${{ inputs.exec_environment || 'pr' }} + model_type: ${{ inputs.model_type || 'invoice_processing' }} + secrets: inherit From 6e23d4a093d0c36e25d51e44b12f8082c0953b98 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 13:38:37 +0100 Subject: [PATCH 03/21] fix linter --- .../invoice_processing/src/mlops_pipeline.py | 1 - .../config/configuration_container.py | 1 + .../data_extraction/data_extractor_factory.py | 2 +- .../extractors/base_extractor.py | 6 +- .../extractors/gpt_only_extractor.py | 5 +- .../models/extraction_response.py | 6 ++ .../data_extraction/prompts/prompt_manager.py | 6 ++ .../predict_component/predict/helpers.py | 1 + .../predict/mlflow_logger.py | 4 + .../predict_component/predict/predict.py | 9 +- .../prep_component/prep/prep.py | 4 +- .../score/extraction_evaluator.py | 25 +++-- .../score/matchers/amount_exact_matcher.py | 18 ++-- .../score/matchers/base_matcher.py | 9 +- .../score/matchers/date_exact_matcher.py | 12 +-- .../score/matchers/levenshtein_matcher.py | 13 ++- .../score/matchers/text_exact_matcher.py | 16 ++-- .../score_component/score/score.py | 39 ++++---- .../score_component/score/utils.py | 91 +++++++++---------- 19 files changed, 143 insertions(+), 125 deletions(-) diff --git a/mlops/invoice_processing/src/mlops_pipeline.py b/mlops/invoice_processing/src/mlops_pipeline.py index 6d7ab4ef..a8c17980 100644 --- a/mlops/invoice_processing/src/mlops_pipeline.py +++ b/mlops/invoice_processing/src/mlops_pipeline.py @@ -154,7 +154,6 @@ def construct_pipeline(self, ml_client): Returns: pipeline_job: The constructed pipeline job components. """ - registered_data_asset = ml_client.data.get( name=self.dataset_name, label="latest" ) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py index 174871eb..f4e3804b 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py @@ -1,5 +1,6 @@ class ConfigurationContainer: """A simple service container to store configurations for extractors.""" + _config_registry = {} @classmethod diff --git a/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py index b46880ca..eda583aa 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py @@ -9,6 +9,7 @@ class DataExtractorFactory: @classmethod def load_default_extractors(cls) -> None: + """Load default extractors into the factory registry.""" from .extractors.gpt_only_extractor import GPTOnlyExtractor cls.register("gpt_only", "invoice", GPTOnlyExtractor) @@ -39,7 +40,6 @@ def list_extractors(cls, category: str) -> list[str]: def create(cls, category: str, name: str, additional_config: dict, logger_proxy: LoggerProxy) -> Extractor: """Create an instance of extractor by category and name.""" - if (category not in cls._registry or name not in cls._registry[category]): raise ValueError( f"Extractor {name} in category {category} is not registered" diff --git a/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py b/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py index 699059c2..692e4c3c 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/extractors/base_extractor.py @@ -1,4 +1,4 @@ -"""This class is an interface """ +"""This class is an interface.""" from abc import ABC, abstractmethod from ..models.extraction_response import ( @@ -7,8 +7,11 @@ class LoggerProxy(ABC): + """Abstract class to define logging functionalities.""" + @abstractmethod def log_metric(self, key: str, value: float) -> None: + """Log a metric with a key and value.""" pass @@ -16,6 +19,7 @@ class Extractor(ABC): """Abstract class to define extractor functionalities and data.""" def __init__(self, config: dict, logger_proxy: LoggerProxy): + """Initialize the Extractor with the provided configuration and logger.""" self.config = config self.logger_proxy = logger_proxy diff --git a/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py b/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py index a98f7f5e..6815b2c8 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/extractors/gpt_only_extractor.py @@ -1,3 +1,4 @@ +"""This module implements the GPTOnlyExtractor class for extracting data using Azure OpenAI.""" import json import logging from python_retry import retry @@ -21,7 +22,8 @@ class GPTOnlyExtractor(Extractor): """ - Extraction implementation use Azure OpenAI Model + Extraction implementation use Azure OpenAI Model. + Args: config (dict): The configuration dictionary. the following values are expected: - azure_openai_endpoint (str): Azure OpenAI endpoint. @@ -37,6 +39,7 @@ class GPTOnlyExtractor(Extractor): """ def __init__(self, config: dict, logger_proxy: LoggerProxy): + """Initialize the GPTOnlyExtractor with the provided configuration and logger.""" self.client = AzureOpenAI( azure_endpoint=config.get("azure_openai_endpoint"), api_key=config.get("azure_openai_api_key"), diff --git a/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py b/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py index 13f9007e..64634bad 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/models/extraction_response.py @@ -1,9 +1,11 @@ +"""Defines the data models for the extraction response in a structured format.""" from pydantic import BaseModel from typing import Optional, List class LineItem(BaseModel): """Represents a line item in a structured format.""" + amount: float text: str transactionType: str # noqa: N815 @@ -14,16 +16,19 @@ class LineItem(BaseModel): class Provider(BaseModel): """Represents a provider in a structured format.""" + name: str class ServiceFor(BaseModel): """Represents the person the service was provided for in a structured format.""" + name: str class Invoice(BaseModel): """Represents an invoice in a structured format.""" + totalClaimAmount: float # noqa: N815 provider: Provider serviceFor: ServiceFor # noqa: N815 @@ -32,5 +37,6 @@ class Invoice(BaseModel): class ExtractionResponse(BaseModel): """Represents extracted data in a structured format.""" + invoice: Invoice metadata: Optional[dict] = None diff --git a/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py index cdd0bb1b..c59647cc 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py @@ -1,13 +1,17 @@ +"""This module manages the prompts used in the data extraction process.""" from pathlib import Path import frontmatter from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateError, meta, select_autoescape class PromptManager: + """Manages the prompts used in the data extraction process.""" + _env = None @classmethod def _get_env(cls, templates_dir="prompts/templates"): + """Get the Jinja2 environment for rendering templates.""" templates_dir = Path(__file__).parent.parent / templates_dir if cls._env is None: cls._env = Environment( @@ -19,6 +23,7 @@ def _get_env(cls, templates_dir="prompts/templates"): @staticmethod def get_prompt(template, **kwargs): + """Render a prompt template with the provided context.""" env = PromptManager._get_env() template_path = f"{template}.j2" with open(env.loader.get_source(env, template_path)[1]) as file: @@ -32,6 +37,7 @@ def get_prompt(template, **kwargs): @staticmethod def get_template_info(template): + """Get information about a prompt template.""" env = PromptManager._get_env() template_path = f"{template}.j2" with open(env.loader.get_source(env, template_path)[1]) as file: diff --git a/src/invoice_processing/predict_component/predict/helpers.py b/src/invoice_processing/predict_component/predict/helpers.py index 44dc85b2..55a2c10a 100644 --- a/src/invoice_processing/predict_component/predict/helpers.py +++ b/src/invoice_processing/predict_component/predict/helpers.py @@ -1,3 +1,4 @@ +"""This module contains helper functions for the predict component of the invoice processing application.""" import base64 import json import logging diff --git a/src/invoice_processing/predict_component/predict/mlflow_logger.py b/src/invoice_processing/predict_component/predict/mlflow_logger.py index 4f7f288d..391a98f4 100644 --- a/src/invoice_processing/predict_component/predict/mlflow_logger.py +++ b/src/invoice_processing/predict_component/predict/mlflow_logger.py @@ -1,3 +1,4 @@ +"""This module provides a logger that integrates with MLflow for logging metrics.""" import mlflow from .data_extraction.extractors.base_extractor import ( LoggerProxy @@ -5,5 +6,8 @@ class MLFlowLogger(LoggerProxy): + """Logger that integrates with MLflow for logging metrics.""" + def log_metric(self, key: str, value: float) -> None: + """Log a metric with a key and value using MLflow.""" mlflow.log_metric(key, value) diff --git a/src/invoice_processing/predict_component/predict/predict.py b/src/invoice_processing/predict_component/predict/predict.py index a72e6789..bbe3b3c0 100644 --- a/src/invoice_processing/predict_component/predict/predict.py +++ b/src/invoice_processing/predict_component/predict/predict.py @@ -1,3 +1,4 @@ +"""This module contains the predict function for the invoice processing component.""" import ast import os import argparse @@ -34,14 +35,14 @@ def predict( prediction_path, ) -> None: """ - Perform end-to-end initialization and processing input folder's .png and .jpg files - using the specified model type with Azure services. + Perform data extraction using the specified orchestration strategy and Azure OpenAI model. This includes: - Initializing the Azure OpenAI client. - Creating prompt messages for the model. - Processing each file in the input folder. - Saving the output as a JSON file in the output folder. + Args: strategy (string): orchestration strategy name temperature (float): LLM temperature @@ -50,7 +51,6 @@ def predict( test_data (str): a folder with input data prediction_path (str): a folder for storing predictions """ - config_dict = ast.literal_eval(prompt_config) params = { "gpt_deployment_name": gpt_deployment_name, @@ -121,6 +121,7 @@ def predict( def estimate_cost(gpt_deployment_name, performance_df): + """Estimate the cost of processing based on the number of tokens used.""" total_input_tokens = performance_df.loc[:, 'prompt_tokens'].sum() total_output_tokens = performance_df.loc[:, 'completion_tokens'].sum() if gpt_deployment_name == 'gpt-4o': @@ -141,6 +142,7 @@ def estimate_cost(gpt_deployment_name, performance_df): def glob_by_extesion(test_data, types): + """Glob files by extension in the specified test data directory.""" all_images = [] for type in types: arr = glob(f'{test_data}/*{type}') @@ -149,6 +151,7 @@ def glob_by_extesion(test_data, types): def process(extractor: Extractor, input_path, output_folder) -> tuple[ExtractionResponse, float]: + """Process a single input file and return the extraction response and execution time.""" base64_image = convert_image_to_base64(input_path) start_time = time.time() diff --git a/src/invoice_processing/prep_component/prep/prep.py b/src/invoice_processing/prep_component/prep/prep.py index 5311f19e..734d7233 100644 --- a/src/invoice_processing/prep_component/prep/prep.py +++ b/src/invoice_processing/prep_component/prep/prep.py @@ -1,3 +1,4 @@ +"""This module contains the preprocessing step for invoice processing.""" import argparse import os import random @@ -10,7 +11,7 @@ def sample_data(data_paths, samples_amount, sampling_seed): """ - Samples randomly number of data paths based on input amount and seed. + Take a sample of data paths based on the specified amount and seed. Parameters: data_paths (str): paths to files @@ -41,7 +42,6 @@ def main(raw_data, prep_data, samples_amount, sampling_seed): samples_amount (int): amount of samples to randomly use from the data set, 0 means all sampling_seed (int): seed for random sampling of dataset, -1 means no seed """ - mlflow.log_param('number_of_samples', samples_amount) lines = [ diff --git a/src/invoice_processing/score_component/score/extraction_evaluator.py b/src/invoice_processing/score_component/score/extraction_evaluator.py index b6386df3..afbb9e98 100644 --- a/src/invoice_processing/score_component/score/extraction_evaluator.py +++ b/src/invoice_processing/score_component/score/extraction_evaluator.py @@ -1,5 +1,6 @@ """ This module contains the ExtractionEvaluator class, which provides evaluation methods for data extraction from images. + The class includes functionalities for calculating various metrics and generating reports. Classes: ExtractionEvaluation: a class for evaluation of data extraction from images @@ -18,8 +19,7 @@ class ExtractionEvaluator: """ - A comprehensive evaluator for comparing invoice details between - ground truth and predictions. + An evaluator for comparing invoice details between ground truth and predictions. This class currently supports single file evaluation with flexible comparison strategies. @@ -57,6 +57,7 @@ def __init__( def get_matcher(self, matcher_class_name: str): """ Create an instance of the requested matcher. + Args: matcher_class_name: Name of the matcher per field as defined in the experiment config file. Returns: @@ -75,9 +76,7 @@ def get_matcher(self, matcher_class_name: str): return None def get_matcher_for_best_matches_strategy(self, best_matches_strategy: str): - """ - Get find best matches strategy. - """ + """Get find best matches strategy.""" if best_matches_strategy == "levenshtein": return LevenshteinMatcher() elif best_matches_strategy == "text_exact_match": @@ -89,6 +88,7 @@ def get_matcher_for_best_matches_strategy(self, best_matches_strategy: str): def get_match_method(self, matcher_name: str): """ Get match method from matcher name. + Args: matcher_name: name of the matcher Returns: match method (currently exact_match or levenshtein) """ @@ -107,7 +107,9 @@ def compare_line_item_values_per_invoice( ): """ Compare the line items in the ground trith data with the line items in the prediction data. + Find exact matches if exist and calculate fuzzy match metrics for relevent fields. + Args: ground_truth_df: A dataframe in which each column is a different extracted field (startDate, endDate, amount, description) @@ -142,12 +144,12 @@ def get_match_results( ): """ Report the line items match results and additional datasets for error analysis. + Args: comparison_df (pd.DataFrame): dataframe will all possible combinations of line items from the ground truth and the predictions. best_matches_dict (Dict): Dictionary with lists per fuzzy match method, of pairs of matched ground truth and prediction line items. - Returns: """ unmatched_gt = pd.DataFrame() gt_cols = [] @@ -195,6 +197,7 @@ def calculate_evaluation_metrics_per_field_in_invoice( ): """ Calculate the evaluation metric per invoice per field. Currently supports accuracy. + Args: match_results_df (pd.DataFrame): A dataframe with all line items from the ground truth and the predictions: line items of the ground truth and @@ -217,6 +220,7 @@ def calculate_evaluation_metrics_per_field_in_invoice( def calculate_mean_accuracy_per_invoice(self, matches_eval_fields: pd.DataFrame): """ Calcualte the mean accuracy per field in a single invoice. + Args: matches_eval_fields (pd.DataFrame): A dataframe with the ressulting matches per line item in the ground truth data which includes only the fields we would @@ -232,7 +236,8 @@ def calculate_mean_accuracy_per_invoice(self, matches_eval_fields: pd.DataFrame) def calculate_mean_accuracy_per_batch(self, all_invoices_results: pd.DataFrame): """ - Calcualte the mean accuracy per field in a batch of invoices. + Calculate the mean accuracy per field in a batch of invoices. + Args: all_invoices_results (pd.DataFrame): A dataframe with the mean accuracy results of all invoices in the experiment. @@ -256,7 +261,8 @@ def calculate_precision_per_record( self, unmatched_pred: pd.DataFrame(), best_matches_df: pd.DataFrame() ): """ - This function calculates the precision per invoice (record). + Calculate the precision per invoice (record). + Args: unmatched_pred: Dataframe of line items in the extracted data that were not matched to any ground truth line item (defined as FPs). @@ -276,7 +282,8 @@ def calculate_recall_per_record( self, unmatched_gt: pd.DataFrame(), best_matches_df: pd.DataFrame() ): """ - This function calculates the precision per invoice (record). + Calculate the recall per invoice (record). + Args: unmatched_gt: Dataframe of line items in the ground truth data that were not matched to any extracted line item (defined as FNs). diff --git a/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py index 23b3ef96..ebb1c406 100644 --- a/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/amount_exact_matcher.py @@ -1,5 +1,4 @@ -"""Class that performs exact matches for amounts""" - +"""Class that performs exact matches for amounts.""" import logging from .base_matcher import BaseMatcher @@ -9,13 +8,12 @@ class AmountExactMatcher(BaseMatcher): - """ - Calculate amount exact match. - """ + """Calculate amount exact match.""" def amount_exact_match(self, amount1, amount2): """ - Find out whether the amounts in ground truth and prediction are equal + Find out whether the amounts in ground truth and prediction are equal. + Args: amount_str1: First amount value amount_str2: Second amount value @@ -33,15 +31,11 @@ def amount_exact_match(self, amount1, amount2): return match def get_matcher_name(self): - """ - return matcher name. - """ + """Return matcher name.""" return "amount_exact_match" def get_match(self, comparison_df, field_name): - """ - Get match result per line item. - """ + """Get match result per line item.""" match_df = comparison_df.apply( lambda x: self.amount_exact_match( x[f"{field_name}_gt"], x[f"{field_name}_pred"] diff --git a/src/invoice_processing/score_component/score/matchers/base_matcher.py b/src/invoice_processing/score_component/score/matchers/base_matcher.py index 8d2b8d34..372a774a 100644 --- a/src/invoice_processing/score_component/score/matchers/base_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/base_matcher.py @@ -1,17 +1,16 @@ -"""This class is an interface """ - +"""This class is an interface.""" from abc import ABC, abstractmethod class BaseMatcher(ABC): - """ - Abstract class to define matcher base functions. - """ + """Abstract class to define matcher base functions.""" @abstractmethod def get_match(self): + """Get match result per line item.""" pass @abstractmethod def get_matcher_name(self): + """Return matcher name.""" pass diff --git a/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py index de18fe73..4c1d2490 100644 --- a/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py @@ -1,5 +1,4 @@ -"""Class that performs exact matches for dates""" - +"""Class that performs exact matches for dates.""" import logging from .base_matcher import BaseMatcher @@ -13,6 +12,7 @@ class DateExactMatcher(BaseMatcher): def dates_exact_match(self, date_str1: str, date_str2: str): """ Find out whether the dates are identical. + Args: date_str1: First date value date_str2: Second date value @@ -39,15 +39,11 @@ def dates_exact_match(self, date_str1: str, date_str2: str): return match def get_matcher_name(self): - """ - return matcher name. - """ + """Return matcher name.""" return "date_exact_match" def get_match(self, comparison_df, field_name): - """ - Get match result per line item. - """ + """Get match result per line item.""" match_df = comparison_df.apply( lambda x: self.dates_exact_match( x[f"{field_name}_gt"], x[f"{field_name}_pred"] diff --git a/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py index d4fdb972..c10b3a62 100644 --- a/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py @@ -1,5 +1,4 @@ -"""Class that performs exact matches for dates""" - +"""Class that performs exact matches for dates.""" import logging from typing import Dict import pandas as pd @@ -11,8 +10,10 @@ class LevenshteinMatcher(BaseMatcher): + """Class that performs fuzzy matches using Levenshtein distance.""" def get_matcher_name(self): + """Return matcher name.""" return "levenshtein" def calculate_levenshtein_ratio(self, string1: str, string2: str): @@ -38,9 +39,7 @@ def calculate_levenshtein_ratio(self, string1: str, string2: str): return rounded_levenshtein_ratio def get_match(self, comparison_df, field_name): - """ - Get match result per line item. Calculates Levenshtein ratio. - """ + """Get match result per line item. Calculates Levenshtein ratio.""" match_df = comparison_df.apply( lambda x: self.calculate_levenshtein_ratio( x[f"{field_name}_gt"], x[f"{field_name}_pred"] @@ -51,8 +50,8 @@ def get_match(self, comparison_df, field_name): def find_best_matches(self, comparison_df: pd.DataFrame, fuzzy_match_config: Dict): """ - For every line item in the ground truth data, find the most similar - line item in the predictions + For each line item in the ground truth, find the most similar in predictions. + Args: comparison_df: a dataframe which is the cartesian product of the line items inthe ground truth and the predictions datasets diff --git a/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py index 612c24d0..3534fe8d 100644 --- a/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py @@ -1,5 +1,4 @@ -"""Class that performs exact matches for text""" - +"""Class that performs exact matches for text.""" import logging from .base_matcher import BaseMatcher @@ -11,15 +10,11 @@ class TextExactMatcher(BaseMatcher): def get_matcher_name(self): - """ - return matcher name. - """ + """Return matcher name.""" return "text_exact_match" def get_match(self, comparison_df, field_name): - """ - Get match result per line item. - """ + """Get match result per line item.""" match_df = comparison_df.apply( lambda x: self.text_exact_match( x[f"{field_name}_gt"], x[f"{field_name}_pred"] @@ -31,6 +26,7 @@ def get_match(self, comparison_df, field_name): def text_exact_match(self, str1: str, str2: str): """ Find out whether the dates are identical. + Args: str1: First string value date_str2: Second string value @@ -48,8 +44,8 @@ def text_exact_match(self, str1: str, str2: str): def find_best_matches(self, comparison_df): """ - For every line item in the ground truth data, find the most similar - line item in the predictions + For each line item in the ground truth, find the most similar in predictions. + Args: comparison_df: a dataframe which is the cartesian product of the line items inthe ground truth and the predictions datasets diff --git a/src/invoice_processing/score_component/score/score.py b/src/invoice_processing/score_component/score/score.py index 66516958..fd8b9e53 100644 --- a/src/invoice_processing/score_component/score/score.py +++ b/src/invoice_processing/score_component/score/score.py @@ -1,5 +1,6 @@ """ This module runs the evaluation step of the experimentation framework. + First, the ground truth data and the predictions data are read. Next, each line item in the ground truth data is compared against line items from the prediction data to find the best match for @@ -8,7 +9,6 @@ for all fields combined. The score results are logged into AML. """ - import os import argparse @@ -29,6 +29,7 @@ def get_score_config(score_config_str): """ Load score config from dict loaded as str. + Args: components_config: Dictionary loaded as string with configuration Returns: @@ -52,7 +53,8 @@ def get_score_config(score_config_str): def create_extraction_evaluator(components_config): """ - Initialize evaluator object + Initialize evaluator object. + Args: components_config: Dictionary loaded as string with configuration Returns: @@ -75,13 +77,14 @@ def create_extraction_evaluator(components_config): def get_gt_and_pred_data_for_evaluation(ground_truth, predictions): """ - Parse current JSON input to DataFrames - Args: - ground_truth: Ground truth JSON object - predictions: Predictions JSON object - Returns: - gt_data: DataFrame of the line items of the ground truth data - pred_data: DataFrame of the line items of the predictions data + Parse current JSON input to DataFrames. + + Args: + ground_truth: Ground truth JSON object + predictions: Predictions JSON object + Returns: + gt_data: DataFrame of the line items of the ground truth data + pred_data: DataFrame of the line items of the predictions data """ # normalize ground truth and predictions structure ground_truth_invoice = ground_truth["lineItems"] @@ -123,6 +126,7 @@ def get_corresponding_prediction_path( ): """ Get the file path of the predictions that correspond to a given ground truth file. + Args: gt_path (str): File path to the currently evaluated ground truth data. pred_path (str): path to the predictions directory or file path @@ -151,6 +155,7 @@ def add_ref_ids_to_result_dfs( ): """ Add reference ids or predicted data path to the reported results dataframes. + Args: best_matches_df: Dataframe with the line items that were matched. unmatched_gt: Dataframe with line items from the ground truth that were not matched. @@ -179,9 +184,12 @@ def add_ref_ids_to_result_dfs( def evaluate(all_invoices_pred, all_invoices_gt, components_config): """ - Evaluates the quality of data extraction from images by comparing - the extracted data to ground truth. This function calculates the - accuracy, precision and recall to assess the correctness of the extraction. + Evaluate the quality of data extraction from images. + + It does it by comparing the extracted data to ground truth. + This function calculates the accuracy, precision and recall + to assess the correctness of the extraction. + Args: predictions_file_path (str): Path of the predictions file (the extracted data) @@ -293,9 +301,7 @@ def log_results( overall_precision: float, overall_recall: float, ): - """ - Log score results to AML - """ + """Log score results to AML.""" score_results_output_path = "score_results.csv" all_unmatched_gt_path = "all_unmatched_gt.csv" all_unmatched_pred_path = "all_unmatched_pred.csv" @@ -336,7 +342,8 @@ def main( all_unmatched_pred_path, components_config, ): - """Load ground truth and predictions data, call score function. + """ + Load ground truth and predictions data, call score function. Args: predictions_path (string): path to predictions data diff --git a/src/invoice_processing/score_component/score/utils.py b/src/invoice_processing/score_component/score/utils.py index fd760e02..3ce727f2 100644 --- a/src/invoice_processing/score_component/score/utils.py +++ b/src/invoice_processing/score_component/score/utils.py @@ -1,8 +1,5 @@ """ -utils.py - -This module contains various utility functions that can be used across -different parts of the project. +Utility functions for invoice processing scoring component. Functions: read_json_file(file_path): @@ -12,7 +9,6 @@ load_csv_file(file_path): Reads a CSV file and returns the parsed data. """ - import json from pathlib import Path from typing import Union, List @@ -23,9 +19,38 @@ log = logging.getLogger(__name__) +def _load_json_from_file(file_path): + """Helper to load JSON from a single file and return data.""" + try: + with open(file_path, "r", encoding="utf-8") as f: + curr_data = json.load(f) + return curr_data + except FileNotFoundError: + log.error(f"Error: The file at {file_path} was not found") + except json.JSONDecodeError: + log.error(f"Error: The file at {file_path} is not a valid JSON") + return None + + +def _load_json_from_directory(directory_path): + """Helper to load JSON from all files in a directory.""" + all_data = [] + all_data_dict = {} + for file_path in Path(directory_path).glob("*.json"): + log.debug(f"file_path: {file_path}") + curr_data = _load_json_from_file(file_path) + if curr_data is not None: + if isinstance(curr_data, list): + all_data += curr_data + else: + all_data_dict[str(file_path)] = curr_data + return all_data, all_data_dict + + def load_json_file(path: Union[str, Path]): """ Reads a JSON file and returns the parsed data. + Args: file_path (str): The path to the JSON file to be read. Returns: @@ -33,42 +58,18 @@ def load_json_file(path: Union[str, Path]): Raises: FileNotFoundError """ - # Load ground truth data - all_data = [] - all_data_dict = {} data_path = Path(path) if data_path.is_dir(): - # Multiple files in a directory - for file_path in data_path.glob("*.json"): - log.debug(f"file_path: {file_path}") - try: - with open(file_path, "r", encoding="utf-8") as f: - curr_data = json.load(f) - # For ground truth format - if isinstance(curr_data, List): - all_data = all_data + curr_data - # For predictions data format - else: - all_data_dict[str(file_path)] = curr_data - except FileNotFoundError: - log.error(f"Error: The file at {file_path} was not found") - except json.JSONDecodeError: - log.error(f"Error: The file at {file_path} is not a valid JSON") + all_data, all_data_dict = _load_json_from_directory(data_path) else: - # Single file - try: - with open(path, "r", encoding="utf-8") as f: - curr_data = json.load(f) - # For ground truth format - if isinstance(curr_data, List): - all_data = all_data + curr_data - # For predictions data format - else: - all_data_dict[str(file_path)] = curr_data - except FileNotFoundError: - log.error(f"Error: The file at {file_path} was not found") - except json.JSONDecodeError: - log.error(f"Error: The file at {file_path} is not a valid JSON") + all_data = [] + all_data_dict = {} + curr_data = _load_json_from_file(data_path) + if curr_data is not None: + if isinstance(curr_data, list): + all_data += curr_data + else: + all_data_dict[str(data_path)] = curr_data if len(all_data) > 1: return all_data else: @@ -76,9 +77,7 @@ def load_json_file(path: Union[str, Path]): def normalize_string(value: str) -> str: - """ - Normalize string by stripping extra whitespace and converting to lowercase. - """ + """Normalize string by stripping extra whitespace and converting to lowercase.""" if not isinstance(value, str): return str(value) value = re.sub(r"\s+", " ", value).strip().lower() @@ -89,10 +88,7 @@ def normalize_string(value: str) -> str: def preprocess_amount(amount): - """ - Amount pre-processing - remove parentheses and - white spaces from amount string - """ + """Remove parentheses and white spaces from amount string.""" parsed_amount = "" if isinstance(amount, str): parsed_amount = amount.strip() @@ -105,10 +101,7 @@ def preprocess_amount(amount): def preprocess_date(date_str): - """ - Date preprocessing - remove whitespaces and parse - date string into date object - """ + """Remove whitespaces and parse date string into date object.""" date_str = date_str.strip() date = parse(date_str) return date From 3701e3d44aa44dfadc6a3fb73fe3232ea33b208f Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 13:43:13 +0100 Subject: [PATCH 04/21] fix linter --- .../data_extraction/config/configuration_container.py | 1 + .../predict/data_extraction/data_extractor_factory.py | 1 + .../predict/data_extraction/prompts/prompt_manager.py | 2 +- .../score_component/score/matchers/date_exact_matcher.py | 1 + .../score_component/score/matchers/levenshtein_matcher.py | 3 +-- .../score_component/score/matchers/text_exact_matcher.py | 1 + src/invoice_processing/score_component/score/utils.py | 8 ++++---- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py index f4e3804b..80e248f1 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py @@ -1,3 +1,4 @@ +"""Container for storing configurations of extractors.""" class ConfigurationContainer: """A simple service container to store configurations for extractors.""" diff --git a/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py index eda583aa..62a14f0c 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/data_extractor_factory.py @@ -1,3 +1,4 @@ +"""Factory for creating data extractors based on categories.""" from .config.configuration_container import ConfigurationContainer from .extractors.base_extractor import Extractor, LoggerProxy diff --git a/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py index c59647cc..9347283b 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/prompts/prompt_manager.py @@ -5,7 +5,7 @@ class PromptManager: - """Manages the prompts used in the data extraction process.""" + """Manages the prompts used in the data extraction process.""" _env = None diff --git a/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py index 4c1d2490..e8299f85 100644 --- a/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/date_exact_matcher.py @@ -8,6 +8,7 @@ class DateExactMatcher(BaseMatcher): + """Class that performs exact matches for dates.""" def dates_exact_match(self, date_str1: str, date_str2: str): """ diff --git a/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py index c10b3a62..a0d07fcd 100644 --- a/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/levenshtein_matcher.py @@ -18,7 +18,7 @@ def get_matcher_name(self): def calculate_levenshtein_ratio(self, string1: str, string2: str): """ - Calculates the Levenshtein ratio between two strings. + Calculate the Levenshtein ratio between two strings. The Levenshtein ratio is a measure of the similarity between two strings, defined as the ratio of the Levenshtein distance to the length of the longer string. @@ -58,7 +58,6 @@ def find_best_matches(self, comparison_df: pd.DataFrame, fuzzy_match_config: Dic Returns: A dictionary of the best matches: {"fuzzy_match_method_name": best_matches_df} """ - levenshtein_ratio_thr = fuzzy_match_config["field_match_threshold"] remaining_comparisons = comparison_df.copy() best_matches_list_levenshtein = [] diff --git a/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py index 3534fe8d..f41d07f8 100644 --- a/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py +++ b/src/invoice_processing/score_component/score/matchers/text_exact_matcher.py @@ -8,6 +8,7 @@ class TextExactMatcher(BaseMatcher): + """Class that performs exact matches for text.""" def get_matcher_name(self): """Return matcher name.""" diff --git a/src/invoice_processing/score_component/score/utils.py b/src/invoice_processing/score_component/score/utils.py index 3ce727f2..903c53c2 100644 --- a/src/invoice_processing/score_component/score/utils.py +++ b/src/invoice_processing/score_component/score/utils.py @@ -11,7 +11,7 @@ """ import json from pathlib import Path -from typing import Union, List +from typing import Union import re import logging from dateutil.parser import parse @@ -20,7 +20,7 @@ def _load_json_from_file(file_path): - """Helper to load JSON from a single file and return data.""" + """Load JSON data from a file.""" try: with open(file_path, "r", encoding="utf-8") as f: curr_data = json.load(f) @@ -33,7 +33,7 @@ def _load_json_from_file(file_path): def _load_json_from_directory(directory_path): - """Helper to load JSON from all files in a directory.""" + """Load JSON data from all files in a directory.""" all_data = [] all_data_dict = {} for file_path in Path(directory_path).glob("*.json"): @@ -49,7 +49,7 @@ def _load_json_from_directory(directory_path): def load_json_file(path: Union[str, Path]): """ - Reads a JSON file and returns the parsed data. + Read a JSON file and returns the parsed data. Args: file_path (str): The path to the JSON file to be read. From 1168d0cce8b21a9ccff1b2d9b14f5791dbde9e0d Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 13:45:47 +0100 Subject: [PATCH 05/21] fix linter last time --- .../predict/data_extraction/config/configuration_container.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py index 80e248f1..63cf98cb 100644 --- a/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py +++ b/src/invoice_processing/predict_component/predict/data_extraction/config/configuration_container.py @@ -1,4 +1,6 @@ """Container for storing configurations of extractors.""" + + class ConfigurationContainer: """A simple service container to store configurations for extractors.""" From 61605e7ba1301b0af7389562a08ef5dc474f8184 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 13:53:05 +0100 Subject: [PATCH 06/21] add python path src --- .github/workflows/build_validation_workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_validation_workflow.yml b/.github/workflows/build_validation_workflow.yml index e4489194..4aabd81f 100644 --- a/.github/workflows/build_validation_workflow.yml +++ b/.github/workflows/build_validation_workflow.yml @@ -86,7 +86,7 @@ jobs: - name: Run Unit Tests shell: bash run: | - pytest --junitxml=junit/test-results.xml --cov=. --cov-report=xml + PYTHONPATH=$PYTHONPATH:$(pwd)/src pytest --junitxml=junit/test-results.xml --cov=. --cov-report=xml - name: Publish Test Results uses: actions/upload-artifact@v4 with: From d598dacde3e4ace981fcb0abd4a1e478e46a4e7a Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 13:56:20 +0100 Subject: [PATCH 07/21] update requirememnts files --- .../requirements/build_validation_requirements.txt | 11 ++++++++--- .github/requirements/execute_job_requirements.txt | 9 +++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.github/requirements/build_validation_requirements.txt b/.github/requirements/build_validation_requirements.txt index 9eacf419..da4a29e3 100644 --- a/.github/requirements/build_validation_requirements.txt +++ b/.github/requirements/build_validation_requirements.txt @@ -5,9 +5,14 @@ pytest-cov==3.0.0 pytest-azurepipelines==1.0.3 pytest-mock==3.7.0 pytest==7.1.2 -mlflow==2.11.3 +mlflow==2.16.0 mldesigner==0.1.0b4 -azure-ai-ml==1.8.0 +azure-ai-ml==1.23.1 azure-identity==1.16.1 +azureml-fsspec==1.3.1 python-dotenv>=0.10.3 -azureml-mlflow>=1.51 +azureml-mlflow>=1.59 +openai==1.59.3 +python-frontmatter +Levenshtein +python-retry \ No newline at end of file diff --git a/.github/requirements/execute_job_requirements.txt b/.github/requirements/execute_job_requirements.txt index 19340555..3b3dbfc4 100644 --- a/.github/requirements/execute_job_requirements.txt +++ b/.github/requirements/execute_job_requirements.txt @@ -1,8 +1,9 @@ azure-cli==2.64.0 -azure-ai-ml==1.12.1 +azure-ai-ml==1.23.1 azure-identity==1.16.1 -mlflow==2.11.3 +mlflow==2.16.0 python-dotenv>=0.10.3 -azureml-mlflow>=1.51 +azureml-mlflow>=1.59 azureml-core -azureml-mlflow>=1.51 +azureml-mlflow>=1.59 +azureml-fsspec==1.3.1 \ No newline at end of file From ab18bf529f253f6f6c37043899c998e56ee1e614 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 14:04:22 +0100 Subject: [PATCH 08/21] random change for test --- src/london_src/predict/predict.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/london_src/predict/predict.py b/src/london_src/predict/predict.py index 7cbf28dd..335ec708 100644 --- a/src/london_src/predict/predict.py +++ b/src/london_src/predict/predict.py @@ -34,6 +34,7 @@ def main(model_input, test_data, prediction_path): for line in lines: print(line) + print("Loading test data...") test_x, testy = load_test_data(test_data) predict(test_x, testy, model_input, prediction_path) From dca1eb9b53f3f0a80f433f5f797c6627e8b011ac Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 14:11:38 +0100 Subject: [PATCH 09/21] add clear azure cache --- .github/actions/configure_azureml_agent/action.yml | 4 ++++ .github/actions/execute_shell_code/action.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/actions/configure_azureml_agent/action.yml b/.github/actions/configure_azureml_agent/action.yml index 1fa964fe..9f13737f 100644 --- a/.github/actions/configure_azureml_agent/action.yml +++ b/.github/actions/configure_azureml_agent/action.yml @@ -19,6 +19,10 @@ inputs: runs: using: composite steps: + - name: Clear Azure CLI token cache + run: | + rm -rf ~/.azure + - name: Azure login uses: azure/login@v2 with: diff --git a/.github/actions/execute_shell_code/action.yml b/.github/actions/execute_shell_code/action.yml index 569b14e1..87a83e01 100644 --- a/.github/actions/execute_shell_code/action.yml +++ b/.github/actions/execute_shell_code/action.yml @@ -20,6 +20,10 @@ inputs: runs: using: composite steps: + - name: Clear Azure CLI token cache + run: | + rm -rf ~/.azure + - name: Azure login uses: azure/login@v2 with: From aee5494e44e7bf53439d857d46e8d68f5d31ec0d Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 14:16:29 +0100 Subject: [PATCH 10/21] fix add shell to action --- .github/actions/configure_azureml_agent/action.yml | 1 + .github/actions/execute_shell_code/action.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/actions/configure_azureml_agent/action.yml b/.github/actions/configure_azureml_agent/action.yml index 9f13737f..c7f4b0fd 100644 --- a/.github/actions/configure_azureml_agent/action.yml +++ b/.github/actions/configure_azureml_agent/action.yml @@ -20,6 +20,7 @@ runs: using: composite steps: - name: Clear Azure CLI token cache + shell: bash run: | rm -rf ~/.azure diff --git a/.github/actions/execute_shell_code/action.yml b/.github/actions/execute_shell_code/action.yml index 87a83e01..be21311c 100644 --- a/.github/actions/execute_shell_code/action.yml +++ b/.github/actions/execute_shell_code/action.yml @@ -21,6 +21,7 @@ runs: using: composite steps: - name: Clear Azure CLI token cache + shell: bash run: | rm -rf ~/.azure From ee7f415e61ccc1c7ad99d5cd80490ffbf158c466 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 14:20:44 +0100 Subject: [PATCH 11/21] revert configure change --- .github/actions/configure_azureml_agent/action.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/actions/configure_azureml_agent/action.yml b/.github/actions/configure_azureml_agent/action.yml index c7f4b0fd..1fa964fe 100644 --- a/.github/actions/configure_azureml_agent/action.yml +++ b/.github/actions/configure_azureml_agent/action.yml @@ -19,11 +19,6 @@ inputs: runs: using: composite steps: - - name: Clear Azure CLI token cache - shell: bash - run: | - rm -rf ~/.azure - - name: Azure login uses: azure/login@v2 with: From 1d21951792601bc0dfc7db12a3a8f02b4582557e Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 14:36:32 +0100 Subject: [PATCH 12/21] try fix --- .github/requirements/execute_job_requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/requirements/execute_job_requirements.txt b/.github/requirements/execute_job_requirements.txt index 3b3dbfc4..146f8386 100644 --- a/.github/requirements/execute_job_requirements.txt +++ b/.github/requirements/execute_job_requirements.txt @@ -6,4 +6,5 @@ python-dotenv>=0.10.3 azureml-mlflow>=1.59 azureml-core azureml-mlflow>=1.59 -azureml-fsspec==1.3.1 \ No newline at end of file +azureml-fsspec==1.3.1 +marshmallow>=3.18.0,<4.0.0 \ No newline at end of file From 4029f680ffa1fa5374fa2443598284457e0f2b28 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 20:35:07 +0100 Subject: [PATCH 13/21] change to strings --- config/experiment_config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/experiment_config.yaml b/config/experiment_config.yaml index 96129f5c..51add5d5 100644 --- a/config/experiment_config.yaml +++ b/config/experiment_config.yaml @@ -1,7 +1,7 @@ experiment_description: - user_name: - title: - hypothesis: + user_name: user + title: title + hypothesis: hypothesis prep_config: samples_amount: 4 From 076f9a3ff7baaa29a8ab464a4f3c4384aeef14ea Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 20:46:27 +0100 Subject: [PATCH 14/21] fix --- config/experiment_config.yaml | 6 +++--- mlops/common/config_utils.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/config/experiment_config.yaml b/config/experiment_config.yaml index 51add5d5..96129f5c 100644 --- a/config/experiment_config.yaml +++ b/config/experiment_config.yaml @@ -1,7 +1,7 @@ experiment_description: - user_name: user - title: title - hypothesis: hypothesis + user_name: + title: + hypothesis: prep_config: samples_amount: 4 diff --git a/mlops/common/config_utils.py b/mlops/common/config_utils.py index 9d531fcb..7778128b 100644 --- a/mlops/common/config_utils.py +++ b/mlops/common/config_utils.py @@ -1,4 +1,5 @@ """Configuration utils to load config from yaml/json.""" + import os from typing import Dict, Any from pathlib import Path @@ -12,19 +13,34 @@ class MLOpsConfig: _raw_config: Any def __init__( - self, environment: str = "pr", config_path: Path = "config/config.yaml" + self, + environment: str = "pr", + config_path: Path = "config/config.yaml", + exp_config_path: Path = "config/experiment_config.yaml" ): """Intialize MLConfig with yaml config data.""" self.config_path = config_path + self.exp_config_path = exp_config_path self._environment = environment load_dotenv() with open(config_path, "r", encoding="utf-8") as stream: self._raw_config = yaml.safe_load(os.path.expandvars(stream.read())) + with open(exp_config_path, "r", encoding="utf-8") as stream: + self._raw_desc_config = yaml.safe_load(os.path.expandvars(stream.read()))["experiment_description"] + def __getattr__(self, __name: str) -> Any: """Get values for top level keys in configuration.""" return self._raw_config[__name] + def get_experiment_description(self) -> str: + """Get the experiment description from the configuration.""" + name = self._raw_desc_config["user_name"] + title = self._raw_desc_config["title"] + hypothesis = self._raw_desc_config["hypothesis"] + + return f"User Name: {name} \n\n Title: {title} \n\n Hypothesis: {hypothesis}" + def get_pipeline_config(self, pipeline_name: str) -> Dict: """Get the pipeline configuration for given pipeline name and environment.""" pipelineconfig_name = f"{pipeline_name}_{self._environment}" From 48fb444cb9ac10e89e2109181458ac2244f337c9 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 20:55:12 +0100 Subject: [PATCH 15/21] add kwargs --- mlops/common/pipeline_job_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlops/common/pipeline_job_config.py b/mlops/common/pipeline_job_config.py index e2ddb256..b8a6a6e7 100644 --- a/mlops/common/pipeline_job_config.py +++ b/mlops/common/pipeline_job_config.py @@ -14,6 +14,7 @@ def __init__( wait_for_completion: str, output_file: str, model_name: str, + **kwargs ): """ Initialize the pipeline job components. @@ -36,3 +37,5 @@ def __init__( self.wait_for_completion = wait_for_completion self.output_file = output_file self.model_name = model_name + for key, value in kwargs.items(): + setattr(self, key, value) From 1f91dcf27520e1e8b71905f2ddfcf2be0edbb38f Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 21:06:23 +0100 Subject: [PATCH 16/21] remove exp desc --- config/experiment_config.yaml | 5 ----- docs/how-to/ConfigureExperiments.md | 13 +------------ mlops/common/config_utils.py | 11 ----------- 3 files changed, 1 insertion(+), 28 deletions(-) diff --git a/config/experiment_config.yaml b/config/experiment_config.yaml index 96129f5c..76ab72a5 100644 --- a/config/experiment_config.yaml +++ b/config/experiment_config.yaml @@ -1,8 +1,3 @@ -experiment_description: - user_name: - title: - hypothesis: - prep_config: samples_amount: 4 sampling_seed: 42 diff --git a/docs/how-to/ConfigureExperiments.md b/docs/how-to/ConfigureExperiments.md index 50e09b94..d15fe98c 100644 --- a/docs/how-to/ConfigureExperiments.md +++ b/docs/how-to/ConfigureExperiments.md @@ -54,18 +54,7 @@ This file is only used by the invoice_processing pipeline and is used to configu The file has several sections, each section configures a different component of the experiment. -The `experiment_description` section enables the users to add more information about the experiment they are about to run. The user can provide their user name, give a title to the experiment and explain what hypothesis is being tested in this experiment. This information will be logged into AML's job description. - -``` yaml -experiment_description: - user_name: - title: - hypothesis: -``` - -Providing this information will make it easier to differentiate between the different AML runs. - -The next section in the config file configures the data preparation step of the pipeline. +The first section in the config file configures the data preparation step of the pipeline. ```yaml prep_config: diff --git a/mlops/common/config_utils.py b/mlops/common/config_utils.py index 7778128b..7e72c335 100644 --- a/mlops/common/config_utils.py +++ b/mlops/common/config_utils.py @@ -26,21 +26,10 @@ def __init__( with open(config_path, "r", encoding="utf-8") as stream: self._raw_config = yaml.safe_load(os.path.expandvars(stream.read())) - with open(exp_config_path, "r", encoding="utf-8") as stream: - self._raw_desc_config = yaml.safe_load(os.path.expandvars(stream.read()))["experiment_description"] - def __getattr__(self, __name: str) -> Any: """Get values for top level keys in configuration.""" return self._raw_config[__name] - def get_experiment_description(self) -> str: - """Get the experiment description from the configuration.""" - name = self._raw_desc_config["user_name"] - title = self._raw_desc_config["title"] - hypothesis = self._raw_desc_config["hypothesis"] - - return f"User Name: {name} \n\n Title: {title} \n\n Hypothesis: {hypothesis}" - def get_pipeline_config(self, pipeline_name: str) -> Dict: """Get the pipeline configuration for given pipeline name and environment.""" pipelineconfig_name = f"{pipeline_name}_{self._environment}" From 9d3af4ab5c8c0cd59e4f19ca45d45935d3014346 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 21:11:22 +0100 Subject: [PATCH 17/21] no desc --- mlops/invoice_processing/src/mlops_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlops/invoice_processing/src/mlops_pipeline.py b/mlops/invoice_processing/src/mlops_pipeline.py index a8c17980..41d78c5f 100644 --- a/mlops/invoice_processing/src/mlops_pipeline.py +++ b/mlops/invoice_processing/src/mlops_pipeline.py @@ -226,7 +226,6 @@ def prepare_and_execute( pipeline_config = config.get_pipeline_config(model_name) published_model_name = generate_model_name(model_name) - experiment_description = config.get_experiment_description() pipeline_job_config = InvoiceProcessing( environment_name=None, # will be set in prepare_and_execute_pipeline From 2edbf21bd91594c1753bc59af1bbf9c1311689e9 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 6 Jun 2025 21:11:50 +0100 Subject: [PATCH 18/21] missed arg --- mlops/invoice_processing/src/mlops_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlops/invoice_processing/src/mlops_pipeline.py b/mlops/invoice_processing/src/mlops_pipeline.py index 41d78c5f..2c52f1f2 100644 --- a/mlops/invoice_processing/src/mlops_pipeline.py +++ b/mlops/invoice_processing/src/mlops_pipeline.py @@ -240,7 +240,7 @@ def prepare_and_execute( predictions=predictions, ) - prepare_and_execute_pipeline(pipeline_job_config, experiment_description) + prepare_and_execute_pipeline(pipeline_job_config) def main(): From 81b2b29a882320caae6011b1011892591820d8c8 Mon Sep 17 00:00:00 2001 From: Martyna Marcinkowska Date: Fri, 27 Jun 2025 16:15:18 +0100 Subject: [PATCH 19/21] remove tests --- .../data_extraction/assets/config.json | 10 - .../data_extraction/assets/mock_extractor.py | 29 -- .../extractors/test_gpt_only_extractor.py | 202 -------- .../models/test_extraction_response.py | 54 --- .../prompts/test_prompt_manager.py | 45 -- .../test_configuration_container.py | 29 -- .../data_extraction/test_extractor_factory.py | 77 --- .../predict_component/predict/test_helpers.py | 33 -- .../predict_component/predict/test_predict.py | 107 ---- .../test_experiment_config.yaml | 15 - .../test_extraction_evaluator.py | 119 ----- .../score_component/test_score.py | 457 ------------------ .../score_component/test_utils.py | 22 - .../test_invoice_processing_to_delete.py | 6 + 14 files changed, 6 insertions(+), 1199 deletions(-) delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/assets/config.json delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py delete mode 100644 test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py delete mode 100644 test/invoice_processing/predict_component/predict/test_helpers.py delete mode 100644 test/invoice_processing/predict_component/predict/test_predict.py delete mode 100644 test/invoice_processing/score_component/test_experiment_config.yaml delete mode 100644 test/invoice_processing/score_component/test_extraction_evaluator.py delete mode 100644 test/invoice_processing/score_component/test_score.py delete mode 100644 test/invoice_processing/score_component/test_utils.py create mode 100644 test/invoice_processing/test_invoice_processing_to_delete.py diff --git a/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json b/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json deleted file mode 100644 index f398fd5c..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/assets/config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "gpt_only": { - "azure_openai_endpoint": "http://example", - "azure_openai_api_key": "example" - }, - "mock_ocr_extractor": { - "encoding": "utf-16", - "logging_enabled": false - } -} diff --git a/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py b/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py deleted file mode 100644 index 98fed720..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/assets/mock_extractor.py +++ /dev/null @@ -1,29 +0,0 @@ -from src.invoice_processing.predict_component.predict.data_extraction.extractors.base_extractor import ( - Extractor -) -from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( - ExtractionResponse, - Invoice, - Provider, - ServiceFor -) - - -class MockExtractor(Extractor): - - def extract_data(self, file) -> ExtractionResponse: - invoice = Invoice( - lineItems=[], - provider=Provider( - name="Mock Provider" - ), - serviceFor=ServiceFor( - name="Mock Patient" - ), - totalClaimAmount=1.99 - ) - - return ExtractionResponse( - invoice=invoice, - metadata={} - ) diff --git a/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py b/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py deleted file mode 100644 index 4d831bb9..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/extractors/test_gpt_only_extractor.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest -from unittest.mock import patch -from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChoice, ParsedChatCompletionMessage -from openai.types.completion_usage import CompletionUsage - -from src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor import ( - GPTOnlyExtractor -) -from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( - ExtractionResponse, - Invoice, - LineItem, - Provider, - ServiceFor -) - - -class TestGPTOnlyExtractor(unittest.TestCase): - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.GPTOnlyExtractor.create_prompt') - def test_extract_data(self, mock_create_prompt, mock_logger_proxy, mock_azure_open_ai): - mock_create_prompt.return_value = [ - { - "role": "system", - "content": "You are an AI assistant" - }, - {"role": "user", "content": "Hi."} - ] - - extraction_response = ExtractionResponse( - invoice=Invoice( - totalClaimAmount=0.0, - provider=Provider( - name="" - ), - serviceFor=ServiceFor( - name="" - ), - lineItems=[ - LineItem( - amount=0.0, - text="", - transactionType="", - serviceStartDate="", - serviceEndDate="" - ) - ] - ) - ) - mock_completion = ParsedChatCompletion( - id="id", - created=0, - model="model", - object="chat.completion", - choices=[ - ParsedChoice( - message=ParsedChatCompletionMessage( - parsed=extraction_response, - role="assistant" - ), - index=0, - finish_reason="stop" - ) - ], - usage=CompletionUsage( - completion_tokens=100, - prompt_tokens=101, - total_tokens=201 - ) - ) - mock_azure_open_ai_instance = mock_azure_open_ai.return_value - mock_azure_open_ai_instance.beta.chat.completions.parse.return_value = mock_completion - - gpt_only_extractor = GPTOnlyExtractor({ - "azure_openai_endpoint": "https://example.com", - "azure_openai_api_key": "SSSHHH", - "gpt_deployment_name": 'gpt-4o', - "temperature": 0, - "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} - }, mock_logger_proxy) - result = gpt_only_extractor.extract_data("BASE64_STRING") - - self.assertEqual(result, extraction_response) - mock_azure_open_ai_instance.beta.chat.completions.parse.assert_called_with( - model='gpt-4o', - temperature=0, - messages=mock_create_prompt.return_value, - response_format=ExtractionResponse - ) - mock_logger_proxy.log_metric.assert_any_call("completion_tokens", 100) - mock_logger_proxy.log_metric.assert_any_call("prompt_tokens", 101) - - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.PromptManager.get_prompt') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') - def test_create_prompt(self, mock_logger_proxy, mock_azure_open_ai, mock_get_prompt): - mock_get_prompt.return_value = "Extract data from this invoice" - - gpt_only_extractor = GPTOnlyExtractor({ - "azure_openai_endpoint": "https://example.com", - "azure_openai_api_key": "SSSHHH", - "gpt_deployment_name": 'gpt-4o', - "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} - }, mock_logger_proxy) - - base64_image = "base64_image_string" - messages = gpt_only_extractor.create_prompt(base64_image) - - self.assertEqual(messages, [ - { - "role": "system", - "content": - "You are an AI assistant that analyzes the text provided " - "and supplemented images and returns them as structured JSON objects. " - "Do not return as a code block." - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": mock_get_prompt.return_value - }, - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{base64_image}"} - } - ] - } - ]) - - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.AzureOpenAI') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.LoggerProxy') - @patch('src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor.GPTOnlyExtractor.create_prompt') - def test_extract_data_with_retry(self, mock_create_prompt, mock_logger_proxy, mock_azure_open_ai): - mock_create_prompt.return_value = [ - { - "role": "system", - "content": "You are an AI assistant" - }, - {"role": "user", "content": "Hi."} - ] - - mock_azure_open_ai_instance = mock_azure_open_ai.return_value - mock_azure_open_ai_instance.beta.chat.completions.parse.side_effect = [Exception("Error"), Exception("Error"), ParsedChatCompletion( - id="id", - created=0, - model="model", - object="chat.completion", - choices=[ - ParsedChoice( - message=ParsedChatCompletionMessage( - parsed=ExtractionResponse( - invoice=Invoice( - totalClaimAmount=0.0, - provider=Provider( - name="" - ), - serviceFor=ServiceFor( - name="" - ), - lineItems=[ - LineItem( - amount=0.0, - text="", - transactionType="", - serviceStartDate="", - serviceEndDate="" - ) - ] - ) - ), - role="assistant" - ), - index=0, - finish_reason="stop" - ) - ], - usage=CompletionUsage( - completion_tokens=100, - prompt_tokens=101, - total_tokens=201 - ) - )] - - gpt_only_extractor = GPTOnlyExtractor({ - "azure_openai_endpoint": "https://example.com", - "azure_openai_api_key": "SSSHHH", - "gpt_deployment_name": 'gpt-4o', - "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} - }, mock_logger_proxy) - result = gpt_only_extractor.extract_data("BASE64_STRING") - - self.assertIsNotNone(result) - self.assertIsInstance(result, ExtractionResponse) - self.assertEqual(mock_azure_open_ai_instance.beta.chat.completions.parse.call_count, 3) - mock_logger_proxy.log_metric.assert_any_call("completion_tokens", 100) - mock_logger_proxy.log_metric.assert_any_call("prompt_tokens", 101) - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py b/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py deleted file mode 100644 index 8e4d2b5b..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/models/test_extraction_response.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest -from pydantic import ValidationError - -from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( - ExtractionResponse -) - - -class TestExtractionResponse(unittest.TestCase): - def setUp(self): - self.valid_line_item = { - "amount": 100.0, - "text": "Consultation", - "transactionType": "Service", - "serviceStartDate": "2023-01-01", - "serviceEndDate": "2023-01-02" - } - self.valid_invoice = { - "totalClaimAmount": 100.0, - "provider": { - "name": "Provider A" - }, - "serviceFor": { - "name": "Patient A" - }, - "lineItems": [self.valid_line_item] - } - self.valid_invoice_data = { - "invoice": self.valid_invoice - } - - def test_valid_invoice_data(self): - invoice_data = ExtractionResponse(**self.valid_invoice_data) - self.assertEqual(invoice_data.invoice.totalClaimAmount, 100.0) - self.assertEqual(invoice_data.invoice.provider.name, "Provider A") - self.assertEqual(invoice_data.invoice.serviceFor.name, "Patient A") - self.assertEqual(len(invoice_data.invoice.lineItems), 1) - self.assertEqual(invoice_data.invoice.lineItems[0].amount, 100.0) - - def test_invalid_invoice_data(self): - invalid_invoice_data = self.valid_invoice_data.copy() - invalid_invoice_data["invoice"]["totalClaimAmount"] = "invalid_amount" - with self.assertRaises(ValidationError): - ExtractionResponse(**invalid_invoice_data) - - def test_missing_required_field(self): - invalid_invoice_data = self.valid_invoice_data.copy() - del invalid_invoice_data["invoice"]["provider"]["name"] - with self.assertRaises(ValidationError): - ExtractionResponse(**invalid_invoice_data) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py b/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py deleted file mode 100644 index bd0414ea..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/prompts/test_prompt_manager.py +++ /dev/null @@ -1,45 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open - -from src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager import PromptManager - - -class TestPromptManager(unittest.TestCase): - @patch('builtins.open', new_callable=mock_open, - read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') - @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') - def test_get_prompt(self, mock_get_source, mock_file): - mock_get_source.return_value = ('template content', 'template/path', lambda: True) - result = PromptManager.get_prompt('test_template', name='World') - self.assertEqual(result, 'Hello, World!') - - @patch('builtins.open', new_callable=mock_open, - read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') - @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') - def test_get_prompt_template_error(self, mock_get_source, mock_file): - mock_get_source.return_value = ('template content', 'template/path', lambda: True) - with self.assertRaises(ValueError) as context: - PromptManager.get_prompt('test_template') - self.assertIn('Error rendering template', str(context.exception)) - - @patch('builtins.open', new_callable=mock_open, - read_data='---\ndescription: Test template\nauthor: Test Author\n---\nHello, {{ name }}!') - @patch('src.invoice_processing.predict_component.predict.data_extraction.prompts.prompt_manager.FileSystemLoader.get_source') - def test_get_template_info(self, mock_get_source, mock_file): - mock_get_source.return_value = ('template content', 'template/path', lambda: True) - result = PromptManager.get_template_info('test_template') - expected_result = { - 'name': 'test_template', - 'description': 'Test template', - 'author': 'Test Author', - 'variables': ['name'], - 'frontmatter': { - 'description': 'Test template', - 'author': 'Test Author' - } - } - self.assertEqual(result, expected_result) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py b/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py deleted file mode 100644 index fa6b24ae..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/test_configuration_container.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open - -from src.invoice_processing.predict_component.predict.data_extraction.config.configuration_container import ( - ConfigurationContainer -) - - -class TestConfigurationContainer(unittest.TestCase): - - def setUp(self): - # Clear the config registry before each test - ConfigurationContainer._config_registry = {} - - def test_register_and_get_config(self): - config = {"key": "value"} - ConfigurationContainer.register_config("extractor1", config) - retrieved_config = ConfigurationContainer.get_config("extractor1") - self.assertEqual(retrieved_config, config) - - def test_get_config_not_registered(self): - retrieved_config = ConfigurationContainer.get_config("non_existent_extractor") - self.assertEqual(retrieved_config, {}) - - @patch("builtins.open", new_callable=mock_open, read_data='{"extractor2": {"key": "value"}}') - @patch("json.load", return_value={"extractor2": {"key": "value"}}) - def test_load_configs_from_file(self, mock_json_load, mock_open): - ConfigurationContainer.load_configs_from_file("dummy_path") - self.assertEqual(ConfigurationContainer._config_registry["extractor2"], {"key": "value"}) diff --git a/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py b/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py deleted file mode 100644 index 6d1056e7..00000000 --- a/test/invoice_processing/predict_component/predict/data_extraction/test_extractor_factory.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -import unittest -from unittest.mock import MagicMock - -from src.invoice_processing.predict_component.predict.data_extraction.config.configuration_container import ( - ConfigurationContainer -) -from src.invoice_processing.predict_component.predict.data_extraction.data_extractor_factory import ( - DataExtractorFactory -) -from src.invoice_processing.predict_component.predict.data_extraction.extractors.gpt_only_extractor import ( - GPTOnlyExtractor -) -from test.invoice_processing.predict_component.predict.data_extraction.assets.mock_extractor import MockExtractor - - -class TestDataExtractorFactory(unittest.TestCase): - def setUp(self): - self.assets_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") - config_path = os.path.join(self.assets_path, "config.json") - ConfigurationContainer.load_configs_from_file(config_path) - self.additional_config = { - "prompt_config": { - "prompt_name": "medical_claim_reimbursement", - "line_item_instructions": "complex" - } - } - self.logger = MagicMock() - DataExtractorFactory.register("mockextractor", "mock", MockExtractor) - - def test_load_default_extractors(self): - DataExtractorFactory.load_default_extractors() - extractor = DataExtractorFactory.create("invoice", - "gpt_only", - self.additional_config, - self.logger) - - self.assertIsInstance(extractor, GPTOnlyExtractor) - - def test_create_extractor(self): - extractor = DataExtractorFactory.create("mock", - "mockextractor", - self.additional_config, - self.logger) - - self.assertIsInstance(extractor, MockExtractor) - - resp = extractor.extract_data("") - self.assertEqual(resp.invoice.provider.name, "Mock Provider") - self.assertEqual(resp.invoice.serviceFor.name, "Mock Patient") - - def test_create_extractor_invalid_category(self): - with self.assertRaises(ValueError): - DataExtractorFactory.create("invalid", "mock_ocr_extractor", - self.additional_config, - self.logger) - - def test_create_extractor_invalid_name(self): - with self.assertRaises(ValueError): - DataExtractorFactory.create("mock", "invalid_extractor", - self.additional_config, - self.logger) - - def test_list_categories(self): - categories = DataExtractorFactory.list_categories() - self.assertIn("mock", categories) - - def test_list_extractors(self): - extractors = DataExtractorFactory.list_extractors("mock") - self.assertIn("mockextractor", extractors) - - def test_register_invalid_extractor(self): - with self.assertRaises(ValueError): - DataExtractorFactory.register("invalid_extractor", "mock", object) - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/predict_component/predict/test_helpers.py b/test/invoice_processing/predict_component/predict/test_helpers.py deleted file mode 100644 index 3e2571d1..00000000 --- a/test/invoice_processing/predict_component/predict/test_helpers.py +++ /dev/null @@ -1,33 +0,0 @@ -import base64 -import unittest -from unittest.mock import patch, mock_open -import json - -from src.invoice_processing.predict_component.predict.helpers import save_output_as_json, convert_image_to_base64 - - -class TestPredictOrchestratorHelpers(unittest.TestCase): - @patch("builtins.open", new_callable=mock_open) - def test_save_output_as_json(self, mock_file): - output = {"key": "value"} - output_file_path = "test_output.json" - - save_output_as_json(output, output_file_path) - - mock_file.assert_called_once_with(output_file_path, 'w', encoding='utf-8') - handle = mock_file() - written_content = ''.join(call.args[0] for call in handle.write.call_args_list) - self.assertEqual(written_content, json.dumps(output, ensure_ascii=False, indent=4)) - - @patch("builtins.open", new_callable=mock_open, read_data=b"fake_image_data") - def test_convert_image_to_base64(self, mock_file): - image_path = "test_image.png" - - result = convert_image_to_base64(image_path) - - mock_file.assert_called_once_with(image_path, "rb") - self.assertEqual(result, base64.b64encode(b"fake_image_data").decode('utf-8')) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/predict_component/predict/test_predict.py b/test/invoice_processing/predict_component/predict/test_predict.py deleted file mode 100644 index 9298ead5..00000000 --- a/test/invoice_processing/predict_component/predict/test_predict.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -import unittest -from unittest.mock import ANY, patch, MagicMock -from src.invoice_processing.predict_component.predict.data_extraction.models.extraction_response import ( - ExtractionResponse, - Invoice, - LineItem, - Provider, - ServiceFor -) - -from src.invoice_processing.predict_component.predict.predict import predict, main, process - - -class TestPredictFunctions(unittest.TestCase): - @patch('src.invoice_processing.predict_component.predict.data_extraction.data_extractor_factory.DataExtractorFactory.create') - @patch('os.makedirs') - @patch('src.invoice_processing.predict_component.predict.predict.glob_by_extesion') - @patch('src.invoice_processing.predict_component.predict.predict.MLFlowLogger') - @patch('src.invoice_processing.predict_component.predict.predict.mlflow') - @patch('src.invoice_processing.predict_component.predict.predict.process') - @patch('pandas.DataFrame.to_csv') - def test_predict(self, mock_to_csv, mock_process, mock_mlflow, mock_logger, mock_glob, mock_makedirs, mock_factory_create): - mock_extractor = MagicMock() - mock_factory_create.return_value = mock_extractor - mock_process.return_value = ExtractionResponse( - invoice=Invoice( - totalClaimAmount=0.0, - provider=Provider( - name="" - ), - serviceFor=ServiceFor( - name="" - ), - lineItems=[ - LineItem( - amount=0.0, - text="", - transactionType="", - serviceStartDate="", - serviceEndDate="" - ) - ] - ) - ) - mock_glob.return_value = ['file1.png', 'file2.jpg'] - azure_openai_endpoint = "https://example.com" - azure_openai_api_key = "test_api_key" - - predict('gpt_only', 0, 'gpt-4o', azure_openai_endpoint, azure_openai_api_key, - "{'prompt_name':'medical_claim_reimbursement','line_item_instructions':'complex'}", - 'test_data', 'prediction_path') - - mock_mlflow.log_params.assert_any_call({ - "gpt_deployment_name": "gpt-4o", - "temperature": 0, - "prompt_name": "medical_claim_reimbursement", - "line_item_instructions": "complex" - }) - - mock_factory_create.assert_called_once_with('invoice', 'gpt_only', { - "azure_openai_endpoint": azure_openai_endpoint, - "azure_openai_api_key": azure_openai_api_key, - "gpt_deployment_name": 'gpt-4o', - "temperature": 0, - "prompt_config": {'prompt_name': 'medical_claim_reimbursement', 'line_item_instructions': 'complex'} - }, ANY) - mock_makedirs.assert_called_once_with('prediction_path', exist_ok=True) - self.assertEqual(mock_process.call_count, 2) - - @patch('src.invoice_processing.predict_component.predict.predict.convert_image_to_base64') - @patch('src.invoice_processing.predict_component.predict.predict.save_output_as_json') - @patch('src.invoice_processing.predict_component.predict.predict.Extractor') - def test_process(self, mock_extractor, mock_save_output_as_json, mock_convert_image_to_base64): - mock_convert_image_to_base64.return_value = "IMAGINE_I_AM_BASE64" - mock_extractor.extract_data.return_value = ExtractionResponse( - invoice=Invoice( - provider=Provider( - name="Bob" - ), - serviceFor=ServiceFor( - name="Greg" - ), - lineItems=[], - totalClaimAmount=0.99 - ) - ) - input_path = 'file1.png' - output_path = 'output_path' - process(mock_extractor, input_path, output_path) - mock_extractor.extract_data.assert_called_once_with("IMAGINE_I_AM_BASE64") - output_file_path = os.path.join(output_path, "file1_result.json") - mock_save_output_as_json.assert_called_once_with(mock_extractor.extract_data.return_value.model_dump(), output_file_path) - - @patch('src.invoice_processing.predict_component.predict.predict.predict') - def test_main(self, mock_predict): - main('gpt_only',0 , 'gpt-4o', "https://example.com", "test_api_key", - "{'prompt_name':'claim_reimbursement','line_item_instructions':'complex'}", - 'test_data', 'prediction_path') - - mock_predict.assert_called_once_with('gpt_only', 0,'gpt-4o', "https://example.com", "test_api_key", - "{'prompt_name':'claim_reimbursement','line_item_instructions':'complex'}", - 'test_data', 'prediction_path') - - -if __name__ == '__main__': - unittest.main() diff --git a/test/invoice_processing/score_component/test_experiment_config.yaml b/test/invoice_processing/score_component/test_experiment_config.yaml deleted file mode 100644 index c6f28d10..00000000 --- a/test/invoice_processing/score_component/test_experiment_config.yaml +++ /dev/null @@ -1,15 +0,0 @@ -score_config: - fuzzy_match_config: - field_match_threshold: 0.0 - fuzzy_compare_methods: - levenshtein: true - exact_match_fields: - start_date_match: true - end_date_match: true - amount_match: true - find_best_matches_strategy: levenshtein - matchers_dict: - serviceStartDate: date_exact_match - serviceEndDate: date_exact_match - amount: amount_exact_match - description: description_levenshtein diff --git a/test/invoice_processing/score_component/test_extraction_evaluator.py b/test/invoice_processing/score_component/test_extraction_evaluator.py deleted file mode 100644 index 6afbf9b8..00000000 --- a/test/invoice_processing/score_component/test_extraction_evaluator.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Unit tests for functions of the ExtractionEvaluator class in the experimentation framework -""" - -import unittest -import yaml -import pandas as pd -from src.invoice_processing.score_component.score.score import ( - create_extraction_evaluator, - get_score_config, -) -from src.invoice_processing.score_component.score.matchers.levenshtein_matcher import ( - LevenshteinMatcher, -) -from src.invoice_processing.score_component.score.matchers.text_exact_matcher import ( - TextExactMatcher, -) - - -class TestExtractionEvaluator(unittest.TestCase): - """ - Test extraction_evaluator.py - """ - - def __init__(self, methodName="runTest"): - super().__init__(methodName) - self.score_config = str( - yaml.safe_load( - open( - "test/invoice_processing/score_component/test_experiment_config.yaml" - ) - )["score_config"] - ) - - def setup_datasets(self): - """ - Setup ground truth data and a corresponding predictions data, - whose fields are a perfect match. - Returns: - ground_truth_df: ground truth dataframe - pred_df: predictions dataframe - """ - - ground_truth_df = pd.DataFrame( - { - "gt_index": [0, 1, 2, 3], - "serviceStartDate": ["1/3/24", "12/29/23", "1/4/24", "1/10/24"], - "serviceEndDate": ["1/9/24", "12/30/23", "1/4/24", "1/12/24"], - "amount": [134, 324, 78, 200], - "description": [ - "Child care service", - "After school program", - "Learning center", - "Swimming lessons", - ], - } - ) - pred_df = pd.DataFrame( - { - "pred_index": [0, 1, 2, 3], - "serviceStartDate": ["12/29/23", "1/1/24", "1/4/24", ""], - "serviceEndDate": ["12/30/23", "1/9/24", "1/4/24", "3/1/24"], - "amount": [324, 134, 76, 100], - "description": [ - "After school program", - "Child care", - "Learning center", - "Summer camp", - ], - "miles": [None, None, None, None], - } - ) - return ground_truth_df, pred_df - - def test_find_best_matches_levenshtein(self): - """ - Test find_best_matches function that is meant to find the best - matches from the predictions data to the ground truth data. - """ - score_config_dict = get_score_config(self.score_config) - fuzzy_match_config = score_config_dict["fuzzy_match_config"] - ground_truth_df, pred_df = self.setup_datasets() - evaluator = create_extraction_evaluator(self.score_config) - comparison_df = evaluator.compare_line_item_values_per_invoice( - ground_truth_df, pred_df - ) - best_matches_dict = LevenshteinMatcher().find_best_matches( - comparison_df, fuzzy_match_config - ) - best_matches_df = pd.DataFrame(best_matches_dict["levenshtein"]) - matches_indices = list( - zip(best_matches_df["gt_index"], best_matches_df["pred_index"]) - ) - self.assertTrue((0.0, 1.0) in matches_indices) - self.assertTrue((1.0, 0.0) in matches_indices) - self.assertTrue((2.0, 2.0) in matches_indices) - - def test_find_best_matches_base_exact_matcher(self): - """ - Test find_best_matches function that is meant to find the best - matches from the predictions data to the ground truth data. - """ - ground_truth_df, pred_df = self.setup_datasets() - evaluator = create_extraction_evaluator(self.score_config) - comparison_df = evaluator.compare_line_item_values_per_invoice( - ground_truth_df, pred_df - ) - best_matches_dict = TextExactMatcher().find_best_matches(comparison_df) - best_matches_df = pd.DataFrame(best_matches_dict["exact_match"]) - matches_indices = list( - zip(best_matches_df["gt_index"], best_matches_df["pred_index"]) - ) - self.assertTrue((0.0, 1.0) in matches_indices) - self.assertTrue((1.0, 0.0) in matches_indices) - self.assertTrue((2.0, 2.0) in matches_indices) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/invoice_processing/score_component/test_score.py b/test/invoice_processing/score_component/test_score.py deleted file mode 100644 index 19290751..00000000 --- a/test/invoice_processing/score_component/test_score.py +++ /dev/null @@ -1,457 +0,0 @@ -""" -Unit tests for the evaluation step in the experimentation framework -""" - -import unittest -import yaml -import pandas as pd -from src.invoice_processing.score_component.score.score import ( - evaluate, - get_gt_and_pred_data_for_evaluation, -) - - -class TestScore(unittest.TestCase): - """ - Test evaluate output for perfect match - """ - - def __init__(self, methodName="runTest"): - super().__init__(methodName) - self.score_config = str( - yaml.safe_load( - open( - "test/invoice_processing/score_component/test_experiment_config.yaml" - ) - )["score_config"] - ) - - def test_evaluate_perfect_match(self): - - ground_truth = [ - { - "reference_id": "12345", - "lineItems": [ - { - "serviceStartDate": "12/29/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "After school program", - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - }, - ], - } - ] - - pred = { - "12345.jpg": { - "invoice": { - "lineItems": [ - { - "serviceStartDate": "12/29/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "text": "Child care", - "miles": None, - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "text": "After school program", - "miles": None, - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "text": "Learning center", - "miles": None, - }, - ], - } - } - } - - ( - final_results_df, - overall_accuracy, - gt_invoices_number, - pred_invoices_number, - all_unmatched_gt, - all_unmatched_pred, - comparison_df_all, - best_matches_all, - all_matches_results_total, - overall_precision, - overall_recall, - ) = evaluate(pred, ground_truth, self.score_config) - self.assertEqual(overall_accuracy, 1.0) - self.assertEqual(overall_precision, 1.0) - self.assertEqual(overall_recall, 1.0) - - def test_evaluate_partial_match(self): - """ - Test evaluate output for partial match - """ - ground_truth = [ - { - "reference_id": "12345", - "lineItems": [ - { - "serviceStartDate": "12/26/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "Tuition", - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - }, - ], - } - ] - - pred = { - "12345.jpg": { - "invoice": { - "lineItems": [ - { - "serviceStartDate": "12/29/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "text": "Child care", - "miles": None, - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "text": "After school program", - "miles": None, - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "text": "Learning center", - "miles": None, - }, - { - "serviceStartDate": "12/30/23", - "serviceEndDate": "12/31/23", - "amount": 267, - "text": "Emergency room", - "miles": None, - }, - ], - } - } - } - - ( - final_results_df, - overall_accuracy, - gt_invoices_number, - pred_invoices_number, - all_unmatched_gt, - all_unmatched_pred, - comparison_df_all, - best_matches_all, - all_matches_results_total, - overall_precision, - overall_recall, - ) = evaluate(pred, ground_truth, self.score_config) - self.assertEqual(round(overall_accuracy, 3), 0.634) - - def test_evaluate_no_match(self): - """ - Test evaluate output for no match - """ - ground_truth = [ - { - "reference_id": "12345", - "lineItems": [ - { - "serviceStartDate": "12/26/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "Tuition", - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - }, - ], - } - ] - - pred = { - "12345.jpg": { - "invoice": { - "lineItems": [ - { - "serviceStartDate": "", - "serviceEndDate": "12/31/23", - "amount": 267, - "text": "Emergency room", - "miles": None, - }, - { - "serviceStartDate": "2/1/24", - "serviceEndDate": "", - "amount": 152, - "text": "After school program", - "miles": None, - }, - { - "serviceStartDate": "", - "serviceEndDate": "5/1/24", - "amount": 74, - "text": "Medicines", - "miles": None, - }, - ] - } - } - } - - ( - final_results_df, - overall_accuracy, - gt_invoices_number, - pred_invoices_number, - all_unmatched_gt, - all_unmatched_pred, - comparison_df_all, - best_matches_all, - all_matches_results_total, - overall_precision, - overall_recall, - ) = evaluate(pred, ground_truth, self.score_config) - self.assertEqual(round(overall_accuracy, 3), 0.07) - self.assertEqual(overall_precision, 1.0) - self.assertEqual(overall_recall, 1.0) - - def test_evaluate_partial_match_for_recall(self): - """ - Test evaluate output for partial match, for recall. - """ - ground_truth = [ - { - "reference_id": "12345", - "lineItems": [ - { - "serviceStartDate": "12/26/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "Tuition", - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - }, - { - "serviceStartDate": "4/5/24", - "serviceEndDate": "4/7/24", - "amount": 94, - "description": "Lunch fee", - }, - ], - } - ] - - pred = { - "12345.jpg": { - "InvoiceDetails": { - "lineItems": [ - { - "serviceStartDate": "12/29/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - "miles": None, - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "After school program", - "miles": None, - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - "miles": None, - }, - ], - } - } - } - - ( - final_results_df, - overall_accuracy, - gt_invoices_number, - pred_invoices_number, - all_unmatched_gt, - all_unmatched_pred, - comparison_df_all, - best_matches_all, - all_matches_results_total, - overall_precision, - overall_recall, - ) = evaluate(pred, ground_truth, self.score_config) - self.assertEqual(round(overall_recall, 3), 0.75) - - def test_evaluate_partial_match_for_precision(self): - """ - Test evaluate output for partial match, for precision. - """ - ground_truth = [ - { - "reference_id": "12345", - "lineItems": [ - { - "serviceStartDate": "12/26/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "Tuition", - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - }, - ], - } - ] - - pred = { - "12345.jpg": { - "InvoiceDetails": { - "lineItems": [ - { - "serviceStartDate": "12/29/23", - "serviceEndDate": "12/30/23", - "amount": 324, - "description": "Child care", - "miles": None, - }, - { - "serviceStartDate": "1/1/24", - "serviceEndDate": "3/1/24", - "amount": 134, - "description": "After school program", - "miles": None, - }, - { - "serviceStartDate": "4/1/24", - "serviceEndDate": "4/1/24", - "amount": 76, - "description": "Learning center", - "miles": None, - }, - { - "serviceStartDate": "9/3/24", - "serviceEndDate": "9/3/24", - "amount": 100, - "description": "Registration fee", - "miles": None, - }, - ], - } - } - } - - ( - final_results_df, - overall_accuracy, - gt_invoices_number, - pred_invoices_number, - all_unmatched_gt, - all_unmatched_pred, - comparison_df_all, - best_matches_all, - all_matches_results_total, - overall_precision, - overall_recall, - ) = evaluate(pred, ground_truth, self.score_config) - self.assertEqual(round(overall_precision, 3), 0.75) - - def test_get_gt_and_pred_data_for_evaluation(self): - ground_truth = { - "reference_id": "12345.jpg", - "lineItems": [ - { - "description": "Dependent care", - "amount": 120, - "serviceStartDate": "07/07/2021", - "serviceEndDate": "07/23/2021", - }, - { - "description": "Lunch fee", - "amount": 64, - "serviceStartDate": "07/07/2021", - "serviceEndDate": "07/23/2021", - }, - ], - } - - predictions = { - "invoice": { - "lineItems": [], - } - } - gt_data, pred_data = get_gt_and_pred_data_for_evaluation( - ground_truth, predictions - ) - self.assertTrue("miles" not in pred_data.columns.tolist()) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/invoice_processing/score_component/test_utils.py b/test/invoice_processing/score_component/test_utils.py deleted file mode 100644 index ed5f1f64..00000000 --- a/test/invoice_processing/score_component/test_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Unit tests for utils.py of the evaluation step of the experimentation framework. -""" - -import unittest -import pandas as pd -from src.invoice_processing.score_component.score.utils import normalize_string - - -class TestScore(unittest.TestCase): - - def test_normalize_str(self): - """ - Test normalize_str. - """ - value = " ( vaLue ) " - normalized_str = normalize_string(value) - self.assertEqual(normalized_str, "(value)") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/invoice_processing/test_invoice_processing_to_delete.py b/test/invoice_processing/test_invoice_processing_to_delete.py new file mode 100644 index 00000000..36a932e3 --- /dev/null +++ b/test/invoice_processing/test_invoice_processing_to_delete.py @@ -0,0 +1,6 @@ +def test_invoice_processing_print(): + try: + print("Hello") is None + except: + print("Test print function failed.") + assert False From 37388baa1a9d5436729d974746f438a020bc77c1 Mon Sep 17 00:00:00 2001 From: lorrinferdinand-hue Date: Wed, 10 Dec 2025 16:54:21 -0700 Subject: [PATCH 20/21] Update TestInitialSetup.md Trigger pr checks with minor change --- docs/how-to/TestInitialSetup.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how-to/TestInitialSetup.md b/docs/how-to/TestInitialSetup.md index fcced39d..0f4d1f0e 100644 --- a/docs/how-to/TestInitialSetup.md +++ b/docs/how-to/TestInitialSetup.md @@ -1,6 +1,6 @@ # Testing the initial setup -**Step 1.** In the main branch, supply an explicit value or accept the defaults in the file, config/config.yaml. The pipelines uses multiple variables and they should be set for both 'pr' and 'dev' plus any additional environments. Also, set the variables for all models (i.e. nyc_taxi, london_taxi). The config.yaml file is split into the following sections, set the values in each section: +**Step 1.** In the main branch, supply an explicit value or accept the defaults in the file, config/config.yaml. The pipelines uses multiple variables and they should be set for both 'pr' and 'dev' plus any additional environments. Also, set the variables for all models (i.e. nyc_taxi, london_taxi, sequence_model). The config.yaml file is split into the following sections, set the values in each section: - aml_config: Stores the configuration of azure resources hosting the Azure Machine Learning workspace. - environment_config: Stores the base image and dynamic properties set at runtime. From 19a6264a0dc494ebf6e3d0403385cb0015448ff1 Mon Sep 17 00:00:00 2001 From: lorrinferdinand-hue Date: Wed, 10 Dec 2025 17:06:20 -0700 Subject: [PATCH 21/21] Update TestInitialSetup.md Trigger another pr check --- docs/how-to/TestInitialSetup.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how-to/TestInitialSetup.md b/docs/how-to/TestInitialSetup.md index 0f4d1f0e..054f04ef 100644 --- a/docs/how-to/TestInitialSetup.md +++ b/docs/how-to/TestInitialSetup.md @@ -1,6 +1,6 @@ # Testing the initial setup -**Step 1.** In the main branch, supply an explicit value or accept the defaults in the file, config/config.yaml. The pipelines uses multiple variables and they should be set for both 'pr' and 'dev' plus any additional environments. Also, set the variables for all models (i.e. nyc_taxi, london_taxi, sequence_model). The config.yaml file is split into the following sections, set the values in each section: +**Step 1.** In the main branch, supply an explicit value or accept the defaults in the file, config/config.yaml. The pipelines uses multiple variables and they should be set for both 'pr' and 'dev' plus any additional environments. Also, set the variables for all models (i.e. nyc_taxi, london_taxi, sequence_model, llm_example). The config.yaml file is split into the following sections, set the values in each section: - aml_config: Stores the configuration of azure resources hosting the Azure Machine Learning workspace. - environment_config: Stores the base image and dynamic properties set at runtime.