Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 182 additions & 23 deletions src/uipath/_services/documents_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .._utils import Endpoint
from ..models.documents import (
ActionPriority,
ClassificationResponse,
ClassificationResult,
DigitizationResult,
ExtractionResponse,
ExtractionResponseIXP,
Expand All @@ -31,6 +33,10 @@ def _is_provided(arg: Any) -> bool:
return arg is not None


def _is_not_provided(arg: Any) -> bool:
return arg is None


def _must_not_be_provided(**kwargs: Any) -> None:
for name, value in kwargs.items():
if value is not None:
Expand All @@ -49,27 +55,57 @@ def _are_mutually_exclusive(**kwargs: Any) -> None:
raise ValueError(f"`{', '.join(provided)}` are mutually exclusive")


def _validate_extract_params(
project_name: Optional[str] = None,
file: Optional[FileContent] = None,
file_path: Optional[str] = None,
digitization_result: Optional[DigitizationResult] = None,
project_type: Optional[ProjectType] = ProjectType.IXP,
document_type_name: Optional[str] = None,
def _validate_extract_params_and_get_project_type(
project_name: Optional[str],
file: Optional[FileContent],
file_path: Optional[str],
digitization_result: Optional[DigitizationResult],
classification_result: Optional[ClassificationResult],
project_type: Optional[ProjectType],
document_type_name: Optional[str],
) -> ProjectType:
_are_mutually_exclusive(file=file, file_path=file_path)

if _is_provided(project_name):
_must_be_provided(project_type=project_type)
_must_not_be_provided(digitization_result=digitization_result)
else:
_must_not_be_provided(project_type=project_type)
_must_not_be_provided(file=file, file_path=file_path)
if _is_provided(digitization_result):
project_type = digitization_result.project_type
else:
_must_be_provided(classification_result=classification_result)
_must_not_be_provided(document_type_name=document_type_name)
project_type = ProjectType.MODERN

if _is_not_provided(classification_result):
if project_type == ProjectType.MODERN:
_must_be_provided(document_type_name=document_type_name)
else:
_must_not_be_provided(document_type_name=document_type_name)

return project_type
Comment on lines +67 to +88
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Painful to read through this. We really need to split into more granular functionalities per project type. But leaving that aside, isn't it missing some cases that are invalid? e.g. passing digitization with an IXP project and also a classification result.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I refined these checks in a later PR, so I also need to update them here. But as you said, there are lots of conditions and branches, it's even worse trying to describe them in words, like in docstrings



def _validate_classify_params(
project_name: Optional[str],
file: Optional[FileContent],
file_path: Optional[str],
digitization_result: Optional[DigitizationResult],
):
_are_mutually_exclusive(file=file, file_path=file_path)

if _is_provided(project_name):
_must_not_be_provided(digitization_result=digitization_result)
else:
_must_be_provided(digitization_result=digitization_result)
_must_not_be_provided(project_type=project_type, file=file, file_path=file_path)
project_type = digitization_result.project_type
_must_not_be_provided(file=file, file_path=file_path)

if project_type == ProjectType.MODERN:
_must_be_provided(document_type_name=document_type_name)
else:
_must_not_be_provided(document_type_name=document_type_name)
if digitization_result.project_type != ProjectType.MODERN:
raise ValueError(
"Classification is only supported for DU Modern projects. The provided digitization result is from an IXP project."
)


class DocumentsService(FolderContext, BaseService):
Expand Down Expand Up @@ -144,14 +180,18 @@ async def _get_project_tags_async(self, project_id: str) -> Set[str]:

def _get_document_id(
self,
project_id: Optional[str] = None,
file: Optional[FileContent] = None,
file_path: Optional[str] = None,
digitization_result: Optional[DigitizationResult] = None,
project_id: Optional[str],
file: Optional[FileContent],
file_path: Optional[str],
digitization_result: Optional[DigitizationResult],
classification_result: Optional[ClassificationResult],
) -> str:
if digitization_result is not None:
return digitization_result.document_object_model.document_id

if classification_result is not None:
return classification_result.document_id

return self._start_digitization(
project_id=project_id, file=file, file_path=file_path
)
Expand All @@ -162,11 +202,14 @@ def _get_project_id_and_validate_tag(
project_name: Optional[str],
project_type: Optional[ProjectType],
digitization_result: Optional[DigitizationResult],
classification_result: Optional[ClassificationResult],
) -> str:
if digitization_result is None:
if project_name is not None:
project_id = self._get_project_id_by_name(project_name, project_type)
else:
elif digitization_result is not None:
project_id = digitization_result.project_id
else:
project_id = classification_result.project_id

tags = self._get_project_tags(project_id)
if tag not in tags:
Expand Down Expand Up @@ -260,10 +303,14 @@ def _get_document_type_id(
project_id: str,
document_type_name: Optional[str],
project_type: ProjectType,
classification_result: Optional[ClassificationResult],
) -> str:
if project_type == ProjectType.IXP:
return str(UUID(int=0))

if classification_result is not None:
return classification_result.document_type_id

response = self.request(
"GET",
url=Endpoint(f"/du_/api/framework/projects/{project_id}/document-types"),
Expand Down Expand Up @@ -477,6 +524,55 @@ async def result_getter() -> Tuple[str, str, Any]:

return ExtractionResponse.model_validate(extraction_response)

def _start_classification(
self,
project_id: str,
tag: str,
document_id: str,
) -> str:
return self.request(
"POST",
url=Endpoint(
f"/du_/api/framework/projects/{project_id}/{tag}/classification/start"
),
params={"api-version": 1.1},
headers=self._get_common_headers(),
json={"documentId": document_id},
).json()["operationId"]

def _wait_for_classification(
self,
project_id: str,
tag: str,
operation_id: str,
) -> List[ClassificationResult]:
def result_getter() -> Tuple[str, Optional[str], Optional[str]]:
result = self.request(
method="GET",
url=Endpoint(
f"/du_/api/framework/projects/{project_id}/{tag}/classification/result/{operation_id}"
),
params={"api-version": 1.1},
headers=self._get_common_headers(),
).json()
return (
result["status"],
result.get("error", None),
result.get("result", None),
)

classification_response = self._wait_for_operation(
result_getter=result_getter,
wait_statuses=["NotStarted", "Running"],
success_status="Succeeded",
)
for classification_result in classification_response["classificationResults"]:
classification_result["ProjectId"] = project_id

return ClassificationResponse.model_validate(
classification_response
).classification_results

@traced(name="documents_digitize", run_type="uipath")
def digitize(
self,
Expand All @@ -497,6 +593,52 @@ def digitize(
project_id=project_id, document_id=document_id, project_type=project_type
)

@traced(name="documents_classify", run_type="uipath")
def classify(
self,
tag: str,
project_name: Optional[str] = None,
file: Optional[FileContent] = None,
file_path: Optional[str] = None,
digitization_result: Optional[DigitizationResult] = None,
) -> List[ClassificationResult]:
"""Classify a document using a DU Modern project."""

_validate_classify_params(
project_name=project_name,
file=file,
file_path=file_path,
digitization_result=digitization_result,
)

project_id = self._get_project_id_and_validate_tag(
tag=tag,
project_name=project_name,
project_type=ProjectType.MODERN,
digitization_result=digitization_result,
classification_result=None,
)

document_id = self._get_document_id(
project_id=project_id,
file=file,
file_path=file_path,
digitization_result=digitization_result,
classification_result=None,
)

operation_id = self._start_classification(
project_id=project_id,
tag=tag,
document_id=document_id,
)

return self._wait_for_classification(
project_id=project_id,
tag=tag,
operation_id=operation_id,
)

@traced(name="documents_extract", run_type="uipath")
def extract(
self,
Expand All @@ -505,7 +647,8 @@ def extract(
file: Optional[FileContent] = None,
file_path: Optional[str] = None,
digitization_result: Optional[DigitizationResult] = None,
project_type: ProjectType = ProjectType.IXP,
classification_result: Optional[ClassificationResult] = None,
project_type: Optional[ProjectType] = None,
document_type_name: Optional[str] = None,
) -> Union[ExtractionResponse, ExtractionResponseIXP]:
"""Extract predicted data from a document using an IXP project.
Expand Down Expand Up @@ -563,12 +706,27 @@ def extract(
project_type=None,
)
```
Using existing classification result:
```python
with open("alex.pdf", "rb") as file:
classification_results = uipath.documents.classify(
tag="Production",
project_name="MyModernProjectName",
file=file,
)

extraction_result = uipath.documents.extract(
tag="Production",
classification_result=max(results, key=lambda result: result.confidence),
)
```
"""
_validate_extract_params(
project_type = _validate_extract_params_and_get_project_type(
project_name=project_name,
file=file,
file_path=file_path,
digitization_result=digitization_result,
classification_result=classification_result,
project_type=project_type,
document_type_name=document_type_name,
)
Expand All @@ -578,21 +736,22 @@ def extract(
project_name=project_name,
project_type=project_type,
digitization_result=digitization_result,
classification_result=classification_result,
)

project_type = project_type or digitization_result.project_type

document_id = self._get_document_id(
project_id=project_id,
file=file,
file_path=file_path,
digitization_result=digitization_result,
classification_result=classification_result,
)

document_type_id = self._get_document_type_id(
project_id=project_id,
document_type_name=document_type_name,
project_type=project_type,
classification_result=classification_result,
)

operation_id = self._start_extraction(
Expand Down
51 changes: 51 additions & 0 deletions src/uipath/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,54 @@ class DigitizationResult(BaseModel):
document_text: str = Field(alias="documentText")
project_id: str = Field(alias="projectId")
project_type: ProjectType = Field(alias="projectType")


class Reference(BaseModel):
model_config = ConfigDict(
serialize_by_alias=True,
validate_by_alias=True,
)

text_start_index: int = Field(alias="TextStartIndex")
text_length: int = Field(alias="TextLength")
tokens: List[str] = Field(alias="Tokens")


class DocumentBounds(BaseModel):
model_config = ConfigDict(
serialize_by_alias=True,
validate_by_alias=True,
)

start_page: int = Field(alias="StartPage")
page_count: int = Field(alias="PageCount")
text_start_index: int = Field(alias="TextStartIndex")
text_length: int = Field(alias="TextLength")
page_range: int = Field(alias="PageRange")


class ClassificationResult(BaseModel):
model_config = ConfigDict(
serialize_by_alias=True,
validate_by_alias=True,
)

document_id: str = Field(alias="DocumentId")
document_type_id: str = Field(alias="DocumentTypeId")
confidence: float = Field(alias="Confidence")
ocr_confidence: float = Field(alias="OcrConfidence")
reference: Reference = Field(alias="Reference")
document_bounds: DocumentBounds = Field(alias="DocumentBounds")
classifier_name: str = Field(alias="ClassifierName")
project_id: str = Field(alias="ProjectId")


class ClassificationResponse(BaseModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not used. I guess the purpose of this class was to be used as response type for the classify method, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I forgot to delete it. At first, I thought about returning the entire response, but I think it’s more intuitive to just return a list of classification results, to keep it consistent with the other functions

model_config = ConfigDict(
serialize_by_alias=True,
validate_by_alias=True,
)

classification_results: List[ClassificationResult] = Field(
alias="classificationResults"
)
Loading