Skip to content

Commit aae998a

Browse files
frascuchon, jfcalvo, pre-commit-ci[bot], damianpumar
authored
[ENHANCEMENT] argilla-server: List records endpoint using db (#5170)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR rewrites the current list dataset records endpoint to use the DB instead of the search engine, since no filtering is applied to the endpoint. ~~The PR introduces a new abstraction layer to manage DB internals: repository. With this layer, we have all DB methods related to a resource in a single place, which helps maintainability and reusability.~~ DB query details have been moved to the Record model class, simplifying the context flows. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Damián Pumar <[email protected]> Co-authored-by: José Francisco Calvo <[email protected]>
1 parent 496a8c3 commit aae998a

File tree

5 files changed

+157
-76
lines changed

5 files changed

+157
-76
lines changed

argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, List, Optional, Tuple, Union
15+
from typing import Any, Dict, List, Optional, Union
1616
from uuid import UUID
1717

1818
from fastapi import APIRouter, Depends, Query, Security, status
@@ -43,17 +43,16 @@
4343
SearchSuggestionsOptions,
4444
SuggestionFilterScope,
4545
)
46-
from argilla_server.contexts import datasets, search
46+
from argilla_server.contexts import datasets, search, records
4747
from argilla_server.database import get_async_db
48-
from argilla_server.enums import RecordSortField, ResponseStatusFilter
48+
from argilla_server.enums import RecordSortField
4949
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
5050
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
5151
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
5252
from argilla_server.search_engine import (
5353
AndFilter,
5454
SearchEngine,
5555
SearchResponses,
56-
UserResponseStatusFilter,
5756
get_search_engine,
5857
)
5958
from argilla_server.security import auth
@@ -72,42 +71,13 @@
7271
router = APIRouter()
7372

7473

75-
async def _filter_records_using_search_engine(
76-
db: "AsyncSession",
77-
search_engine: "SearchEngine",
78-
dataset: Dataset,
79-
limit: int,
80-
offset: int,
81-
user: Optional[User] = None,
82-
include: Optional[RecordIncludeParam] = None,
83-
) -> Tuple[List[Record], int]:
84-
search_responses = await _get_search_responses(
85-
db=db,
86-
search_engine=search_engine,
87-
dataset=dataset,
88-
limit=limit,
89-
offset=offset,
90-
user=user,
91-
)
92-
93-
record_ids = [response.record_id for response in search_responses.items]
94-
user_id = user.id if user else None
95-
96-
return (
97-
await datasets.get_records_by_ids(
98-
db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include
99-
),
100-
search_responses.total,
101-
)
102-
103-
10474
def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope:
10575
if isinstance(scope, RecordFilterScope):
10676
return search_engine.RecordFilterScope(property=scope.property)
10777
elif isinstance(scope, MetadataFilterScope):
10878
return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property)
10979
elif isinstance(scope, SuggestionFilterScope):
110-
return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property)
80+
return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property))
11181
elif isinstance(scope, ResponseFilterScope):
11282
return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user)
11383
else:
@@ -223,18 +193,6 @@ async def _get_search_responses(
223193
return await search_engine.search(**search_params)
224194

225195

226-
async def _build_response_status_filter_for_search(
227-
response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None
228-
) -> Optional[UserResponseStatusFilter]:
229-
user_response_status_filter = None
230-
231-
if response_statuses:
232-
# TODO(@frascuchon): user response and status responses should be split into different filter types
233-
user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses)
234-
235-
return user_response_status_filter
236-
237-
238196
async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset):
239197
try:
240198
await search.validate_search_records_query(db, query, dataset)
@@ -246,27 +204,34 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord
246204
async def list_dataset_records(
247205
*,
248206
db: AsyncSession = Depends(get_async_db),
249-
search_engine: SearchEngine = Depends(get_search_engine),
250207
dataset_id: UUID,
251208
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
252209
offset: int = 0,
253210
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
254211
current_user: User = Security(auth.get_current_user),
255212
):
256213
dataset = await Dataset.get_or_raise(db, dataset_id)
257-
258214
await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset))
259215

260-
records, total = await _filter_records_using_search_engine(
261-
db,
262-
search_engine,
263-
dataset=dataset,
264-
limit=limit,
216+
include_args = (
217+
dict(
218+
with_responses=include.with_responses,
219+
with_suggestions=include.with_suggestions,
220+
with_vectors=include.with_all_vectors or include.vectors,
221+
)
222+
if include
223+
else {}
224+
)
225+
226+
dataset_records, total = await records.list_dataset_records(
227+
db=db,
228+
dataset_id=dataset.id,
265229
offset=offset,
266-
include=include,
230+
limit=limit,
231+
**include_args,
267232
)
268233

269-
return Records(items=records, total=total)
234+
return Records(items=dataset_records, total=total)
270235

271236

272237
@router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)

argilla-server/src/argilla_server/contexts/records.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,58 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, Sequence
15+
from typing import Dict, Sequence, Union, List, Tuple, Optional
1616
from uuid import UUID
1717

18-
from sqlalchemy import select
18+
from sqlalchemy import select, and_, func, Select
1919
from sqlalchemy.ext.asyncio import AsyncSession
20-
from sqlalchemy.orm import selectinload
20+
from sqlalchemy.orm import selectinload, contains_eager
2121

22-
from argilla_server.models import Dataset, Record
22+
from argilla_server.database import get_async_db
23+
from argilla_server.models import Dataset, Record, VectorSettings, Vector
24+
25+
26+
async def list_dataset_records(
    db: AsyncSession,
    dataset_id: UUID,
    offset: int,
    limit: int,
    with_responses: bool = False,
    with_suggestions: bool = False,
    with_vectors: Union[bool, List[str]] = False,
) -> Tuple[Sequence[Record], int]:
    """Fetch one page of records of a dataset plus the dataset-wide record count.

    Args:
        db: Async session used to run both queries.
        dataset_id: Id of the dataset whose records are listed.
        offset: Offset forwarded to the page query.
        limit: Limit forwarded to the page query.
        with_responses: Forwarded to the page query builder.
        with_suggestions: Forwarded to the page query builder.
        with_vectors: Forwarded to the page query builder; either a boolean or
            a list of names.

    Returns:
        A ``(records, total)`` tuple. ``total`` counts every record of the
        dataset and does not depend on ``offset`` or ``limit``.
    """
    count_query = select(func.count(Record.id)).filter_by(dataset_id=dataset_id)
    page_query = _record_by_dataset_id_query(
        dataset_id=dataset_id,
        offset=offset,
        limit=limit,
        with_responses=with_responses,
        with_suggestions=with_suggestions,
        with_vectors=with_vectors,
    )

    page_result = await db.scalars(page_query)
    page = page_result.unique().all()
    total = await db.scalar(count_query)

    return page, total
2348

2449

2550
async def list_dataset_records_by_ids(
2651
db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID]
2752
) -> Sequence[Record]:
28-
query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id)
29-
return (await db.execute(query)).unique().scalars().all()
53+
query = select(Record).where(and_(Record.id.in_(record_ids), Record.dataset_id == dataset_id))
54+
return (await db.scalars(query)).unique().all()
3055

3156

3257
async def list_dataset_records_by_external_ids(
3358
db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str]
3459
) -> Sequence[Record]:
3560
query = (
3661
select(Record)
37-
.filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)
62+
.where(and_(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id))
3863
.options(selectinload(Record.dataset))
3964
)
40-
return (await db.execute(query)).unique().scalars().all()
65+
66+
return (await db.scalars(query)).unique().all()
4167

4268

4369
async def fetch_records_by_ids_as_dict(
@@ -52,3 +78,38 @@ async def fetch_records_by_external_ids_as_dict(
5278
) -> Dict[str, Record]:
5379
records_by_external_ids = await list_dataset_records_by_external_ids(db, dataset.id, external_ids)
5480
return {record.external_id: record for record in records_by_external_ids}
81+
82+
83+
def _record_by_dataset_id_query(
    dataset_id,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    with_responses: bool = False,
    with_suggestions: bool = False,
    with_vectors: Union[bool, List[str]] = False,
) -> Select:
    """Build the SELECT that lists the records of a dataset, optionally paginated.

    Args:
        dataset_id: Id of the dataset whose records are selected.
        offset: Number of records to skip; no OFFSET is applied when ``None``.
        limit: Maximum number of records to return; no LIMIT is applied when ``None``.
        with_responses: When true, eagerly load ``Record.responses``.
        with_suggestions: When true, eagerly load ``Record.suggestions``.
        with_vectors: ``True`` eagerly loads every vector of each record. A list
            of vector settings names loads only the vectors whose settings
            belong to this dataset and match one of those names. Any other
            value adds no vector loading option.

    Returns:
        The ``Select`` statement, ordered by ``Record.inserted_at`` and then
        ``Record.id``.
    """
    # Record.id breaks ties between records sharing the same inserted_at value,
    # so consecutive offset/limit pages neither repeat nor skip records.
    query = select(Record).filter_by(dataset_id=dataset_id).order_by(Record.inserted_at, Record.id)

    if with_responses:
        query = query.options(selectinload(Record.responses))

    if with_suggestions:
        query = query.options(selectinload(Record.suggestions))

    if with_vectors is True:
        query = query.options(selectinload(Record.vectors))
    elif isinstance(with_vectors, list):
        vector_settings_ids = select(VectorSettings.id).where(
            and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors))
        )
        # Load the requested vectors in a separate SELECT instead of joining
        # them into the main statement: a join yields one row per
        # (record, vector) pair, so OFFSET/LIMIT would count vectors rather
        # than records and could truncate or split a page.
        query = query.options(
            selectinload(Record.vectors.and_(Vector.vector_settings_id.in_(vector_settings_ids)))
        )

    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    return query

argilla-server/src/argilla_server/models/database.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,27 @@
1717
from typing import Any, List, Optional, Union
1818
from uuid import UUID
1919

20-
from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text
2120
from sqlalchemy import Enum as SAEnum
21+
from sqlalchemy import (
22+
JSON,
23+
ForeignKey,
24+
String,
25+
Text,
26+
UniqueConstraint,
27+
and_,
28+
sql,
29+
)
2230
from sqlalchemy.engine.default import DefaultExecutionContext
2331
from sqlalchemy.ext.asyncio import async_object_session
2432
from sqlalchemy.ext.mutable import MutableDict, MutableList
25-
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property
33+
from sqlalchemy.orm import Mapped, mapped_column, relationship
2634

2735
from argilla_server.api.schemas.v1.questions import QuestionSettings
2836
from argilla_server.enums import (
2937
DatasetStatus,
3038
FieldType,
3139
MetadataPropertyType,
3240
QuestionType,
33-
RecordStatus,
3441
ResponseStatus,
3542
SuggestionType,
3643
UserRole,

0 commit comments

Comments
 (0)