|
1 | 1 | import datetime |
2 | 2 | import os |
| 3 | +from enum import Enum |
3 | 4 | from typing import ( |
4 | 5 | TYPE_CHECKING, |
5 | 6 | Any, |
|
16 | 17 |
|
17 | 18 | from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader |
18 | 19 | from nucleus.async_job import AsyncJob, EmbeddingsExportJob |
19 | | -from nucleus.prediction import Prediction, from_json |
| 20 | +from nucleus.evaluation_match import EvaluationMatch |
| 21 | +from nucleus.prediction import from_json as prediction_from_json |
20 | 22 | from nucleus.track import Track |
21 | 23 | from nucleus.url_utils import sanitize_string_args |
22 | 24 | from nucleus.utils import ( |
|
77 | 79 | construct_model_run_creation_payload, |
78 | 80 | construct_taxonomy_payload, |
79 | 81 | ) |
| 82 | +from .prediction import Prediction |
80 | 83 | from .scene import LidarScene, Scene, VideoScene, check_all_scene_paths_remote |
81 | 84 | from .slice import ( |
82 | 85 | Slice, |
|
98 | 101 | WARN_FOR_LARGE_SCENES_UPLOAD = 5 |
99 | 102 |
|
100 | 103 |
|
| 104 | +class ObjectQueryType(str, Enum): |
| 105 | + IOU = "iou" |
| 106 | + FALSE_POSITIVE = "false_positive" |
| 107 | + FALSE_NEGATIVE = "false_negative" |
| 108 | + PREDICTIONS_ONLY = "predictions_only" |
| 109 | + GROUND_TRUTH_ONLY = "ground_truth_only" |
| 110 | + |
| 111 | + |
101 | 112 | class Dataset: |
102 | 113 | """Datasets are collections of your data that can be associated with models. |
103 | 114 |
|
@@ -1681,7 +1692,7 @@ def upload_predictions( |
1681 | 1692 | :class:`Category<CategoryPrediction>`, and :class:`Category<SceneCategoryPrediction>` predictions. Cuboid predictions |
1682 | 1693 | can only be uploaded to a :class:`pointcloud DatasetItem<LidarScene>`. |
1683 | 1694 |
|
1684 | | - When uploading an prediction, you need to specify which item you are |
| 1695 | + When uploading a prediction, you need to specify which item you are |
1685 | 1696 | annotating via the reference_id you provided when uploading the image |
1686 | 1697 | or pointcloud. |
1687 | 1698 |
|
@@ -1854,7 +1865,7 @@ def prediction_loc(self, model, reference_id, annotation_id): |
1854 | 1865 | :class:`KeypointsPrediction` \ |
1855 | 1866 | ]: Model prediction object with the specified annotation ID. |
1856 | 1867 | """ |
1857 | | - return from_json( |
| 1868 | + return prediction_from_json( |
1858 | 1869 | self._client.make_request( |
1859 | 1870 | payload=None, |
1860 | 1871 | route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}", |
@@ -1999,6 +2010,47 @@ def query_scenes(self, query: str) -> Iterable[Scene]: |
1999 | 2010 | for item_json in json_generator: |
2000 | 2011 | yield Scene.from_json(item_json, None, True) |
2001 | 2012 |
|
| 2013 | + def query_objects( |
| 2014 | + self, |
| 2015 | + query: str, |
| 2016 | + query_type: ObjectQueryType, |
| 2017 | + model_run_id: Optional[str] = None, |
| 2018 | + ) -> Iterable[Union[Annotation, Prediction, EvaluationMatch]]: |
| 2019 | + """ |
| 2020 | + Fetches all objects in the dataset that pertain to a given structured query. |
| 2021 | + The results are either Predictions, Annotations, or Evaluation Matches, based on the objectType input parameter |
| 2022 | +
|
| 2023 | + Args: |
| 2024 | + query: Structured query compatible with the `Nucleus query language <https://nucleus.scale.com/docs/query-language-reference>`_. |
| 2025 | + objectType: Defines the type of the object to query |
| 2026 | +
|
| 2027 | + Returns: |
| 2028 | + An iterable of either Predictions, Annotations, or Evaluation Matches |
| 2029 | + """ |
| 2030 | + json_generator = paginate_generator( |
| 2031 | + client=self._client, |
| 2032 | + endpoint=f"dataset/{self.id}/queryObjectsPage", |
| 2033 | + result_key=ITEMS_KEY, |
| 2034 | + page_size=MAX_ES_PAGE_SIZE, |
| 2035 | + query=query, |
| 2036 | + patch_mode=query_type, |
| 2037 | + model_run_id=model_run_id, |
| 2038 | + ) |
| 2039 | + |
| 2040 | + for item_json in json_generator: |
| 2041 | + if query_type == ObjectQueryType.GROUND_TRUTH_ONLY: |
| 2042 | + yield Annotation.from_json(item_json) |
| 2043 | + elif query_type == ObjectQueryType.PREDICTIONS_ONLY: |
| 2044 | + yield prediction_from_json(item_json) |
| 2045 | + elif query_type in [ |
| 2046 | + ObjectQueryType.IOU, |
| 2047 | + ObjectQueryType.FALSE_POSITIVE, |
| 2048 | + ObjectQueryType.FALSE_NEGATIVE, |
| 2049 | + ]: |
| 2050 | + yield EvaluationMatch.from_json(item_json) |
| 2051 | + else: |
| 2052 | + raise ValueError("Unknown object type", query_type) |
| 2053 | + |
2002 | 2054 | @property |
2003 | 2055 | def tracks(self) -> List[Track]: |
2004 | 2056 | """Tracks unique to this dataset. |
|
0 commit comments