4 changes: 2 additions & 2 deletions docker-compose.yml
@@ -1,5 +1,3 @@
version: "3"

services:
# runs the tests
test:
@@ -12,6 +10,8 @@ services:
depends_on:
- es
- mongo
volumes:
- '.:/usr/src/app'

# serves the docs locally with realtime auto-reloading
docs:
60 changes: 55 additions & 5 deletions splitgill/manager.py
@@ -4,7 +4,7 @@

from cytoolz.dicttoolz import get_in
from elasticsearch import Elasticsearch
from elasticsearch_dsl import Search
from elasticsearch_dsl import Index, Search
from elasticsearch_dsl.query import Query
from pymongo import ASCENDING, DESCENDING, IndexModel, MongoClient
from pymongo.collection import Collection
@@ -637,7 +637,10 @@ def get_versions(self) -> List[int]:
return sorted(self.get_version_changed_counts().keys())

def get_data_fields(
self, version: Optional[int] = None, query: Optional[Query] = None
self,
version: Optional[int] = None,
query: Optional[Query] = None,
**iter_terms_kwargs,
) -> List[DataField]:
"""
Retrieves the available data fields for this database, optionally at the given
@@ -647,6 +650,8 @@ def get_data_fields(
searched
:param query: the query to filter records with before finding the data fields,
if None, all record data is considered
:param iter_terms_kwargs: kwargs passed directly to iter_terms (e.g. chunk_size,
sample_probability)
:return: a list of DataField objects with the most frequent field first
"""
search = self.search(version if version is not None else SearchVersion.latest)
@@ -656,7 +661,15 @@ def get_data_fields(
fields: Dict[str, DataField] = {}

# create the basic field objects and add type counts
for term in iter_terms(search, DocumentField.DATA_TYPES):
for term in iter_terms(
search,
DocumentField.DATA_TYPES,
**{
k: v
for k, v in iter_terms_kwargs.items()
if k not in ['search', 'field']
},
):
path, raw_types = term.value.rsplit('.', 1)
if path not in fields:
fields[path] = DataField(path)
@@ -688,7 +701,10 @@ def get_data_fields(
return data_fields

def get_parsed_fields(
self, version: Optional[int] = None, query: Optional[Query] = None
self,
version: Optional[int] = None,
query: Optional[Query] = None,
**iter_terms_kwargs,
) -> List[ParsedField]:
"""
Retrieves the available parsed fields for this database, optionally at the given
@@ -698,6 +714,8 @@ def get_parsed_fields(
is searched
:param query: the query to filter records with before finding the parsed fields,
if None, all record data is considered
:param iter_terms_kwargs: kwargs passed directly to iter_terms (e.g. chunk_size,
sample_probability)
:return: a list of ParsedField objects with the most frequent field first
"""
search = self.search(version if version is not None else SearchVersion.latest)
@@ -707,7 +725,15 @@ def get_parsed_fields(
fields: Dict[str, ParsedField] = {}

# create the basic field objects and add type counts
for term in iter_terms(search, DocumentField.PARSED_TYPES):
for term in iter_terms(
search,
DocumentField.PARSED_TYPES,
**{
k: v
for k, v in iter_terms_kwargs.items()
if k not in ['search', 'field']
},
):
path, raw_types = term.value.rsplit('.', 1)
if path not in fields:
fields[path] = ParsedField(path)
@@ -725,3 +751,27 @@ def get_parsed_fields(
# descending frequency (so most frequent fields first)
parsed_fields.sort(key=lambda f: f.count, reverse=True)
return parsed_fields

def get_field_names(self) -> List[str]:
"""
Retrieves a list of field names from the latest index mapping.

Does not take any version or query parameters; simply returns all the "data."
fields available on the index, along with their available types. All relevant
type counts are set to 1 to enable use of e.g. .is_text(). Use get_data_fields
or get_parsed_fields if you need accurate counts, or to filter by version or
query.
"""
latest_index = Index(self.indices.latest, using=self._client.elasticsearch)
mapping = latest_index.get_mapping()
parsed_fields = []
for field_path, field_props in get_in(
[self.indices.latest, 'mappings', 'properties', 'data', 'properties'],
mapping.body,
default={},
).items():
parsed_field = ParsedField(field_path)
for type_name in field_props['properties'].keys():
parsed_field.add(type_name, 1)
parsed_fields.append(parsed_field)
return parsed_fields
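
Usage note for this file's changes: the new `**iter_terms_kwargs` pass-through lets callers of `get_data_fields` and `get_parsed_fields` trade exact counts for speed via sampling, while `get_field_names` is a mapping-only shortcut with placeholder counts. A minimal sketch, assuming an already-configured `SplitgillClient` (as in the test fixtures) and an illustrative database name:

from splitgill.manager import SplitgillClient

def summarise_fields(splitgill: SplitgillClient) -> None:
    # `splitgill` is an already-connected client; 'example' is an illustrative name
    database = splitgill.get_database('example')

    # exact field statistics over every record (can be slow on large databases)
    data_fields = database.get_data_fields()

    # approximate statistics from a ~10% random sample; the extra kwargs are
    # forwarded through to iter_terms
    sampled_data = database.get_data_fields(sample_probability=0.1, seed=42)
    sampled_parsed = database.get_parsed_fields(sample_probability=0.1, seed=42)

    # mapping-only view: every "data." field on the latest index, all type counts 1
    for field in database.get_field_names():
        print(field.path)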
60 changes: 49 additions & 11 deletions splitgill/utils.py
@@ -2,7 +2,7 @@
from datetime import date, datetime, timezone
from itertools import islice
from time import time
from typing import Iterable, Union
from typing import Iterable, Optional, Union

from cytoolz import get_in
from elasticsearch_dsl import A, Search
@@ -84,7 +84,13 @@ class Term:
count: int


def iter_terms(search: Search, field: str, chunk_size: int = 50) -> Iterable[Term]:
def iter_terms(
search: Search,
field: str,
chunk_size: int = 50,
sample_probability: float = 1.0,
seed: Optional[int] = None,
) -> Iterable[Term]:
"""
Yields Term objects, each representing a value and the number of documents which
contain that value in the given field. The Terms are yielded in descending order of
@@ -93,6 +99,10 @@ def iter_terms(search: Search, field: str, chunk_size: int = 50) -> Iterable[Term]:
:param search: a Search instance to use to run the aggregation
:param field: the name of the field to get the terms for
:param chunk_size: the number of buckets to retrieve per request
:param sample_probability: the probability that a given record will be included in a
random sample; set to 1 to use all records (default 1)
:param seed: the seed for the random sampler; if None, defaults to today's date as
    a timestamp divided by 3600, so the sample stays stable within a single day
:return: yields Term objects
"""
after = None
@@ -101,19 +111,47 @@ def iter_terms(search: Search, field: str, chunk_size: int = 50) -> Iterable[Term]:
# when we don't need them, and it ensures we get a fresh copy of the
# search to work with
agg_search = search[:0]
agg_search.aggs.bucket(
'values',
'composite',
size=chunk_size,
sources={'value': A('terms', field=field)},

# this is the core aggregation
composite_agg = A(
'composite', size=chunk_size, sources={'value': A('terms', field=field)}
)
if after is not None:
agg_search.aggs['values'].after = after
result_keys = ['values', 'buckets']
after_keys = ['values', 'after_key']

if sample_probability < 1:
# keep the seed stable within a day (so repeated calls can reuse Elasticsearch
# caches) but let it change daily. Dividing by 3600 keeps the value under the
# ES seed maximum (2147483647) well beyond 2038, when raw epoch seconds would
# overflow it; the exact number isn't important.
seed = (
seed
if seed is not None
else int(
datetime.now()
.replace(hour=0, minute=0, second=0, microsecond=0)
.timestamp()
/ 3600
)
)
# if we're sampling, the core agg gets nested underneath the sampler
agg_search.aggs.bucket(
'sampling', 'random_sampler', probability=sample_probability, seed=seed
).bucket('values', composite_agg)
if after is not None:
agg_search.aggs['sampling'].aggs['values'].after = after
result_keys = ['sampling'] + result_keys
after_keys = ['sampling'] + after_keys
else:
agg_search.aggs.bucket('values', composite_agg)
if after is not None:
agg_search.aggs['values'].after = after

result = agg_search.execute().aggs.to_dict()

buckets = get_in(('values', 'buckets'), result, [])
after = get_in(('values', 'after_key'), result, None)
buckets = get_in(result_keys, result, [])
after = get_in(after_keys, result, None)
if not buckets:
break
else:
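
Usage note for this file's changes: when `sample_probability < 1` the composite terms aggregation is nested under an Elasticsearch `random_sampler` aggregation, and the composite `after` key is read through the longer `['sampling', 'values', ...]` paths. A minimal sketch of calling the updated function, assuming a local cluster; the URL, index name, and field name are illustrative only:

from elasticsearch import Elasticsearch
from elasticsearch_dsl import Search

from splitgill.utils import iter_terms

def show_terms() -> None:
    # assumed local cluster and index name; adjust for your deployment
    client = Elasticsearch('http://localhost:9200')
    search = Search(using=client, index='data-example')

    # exact counts, 50 composite buckets per request (the defaults)
    for term in iter_terms(search, 'data_types'):
        print(term.value, term.count)

    # approximate counts from a ~25% random sample; a fixed seed makes repeated
    # runs see the same sample
    for term in iter_terms(
        search, 'data_types', chunk_size=100, sample_probability=0.25, seed=7
    ):
        print(term.value, term.count)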
76 changes: 76 additions & 0 deletions tests/test_manager.py
@@ -1,6 +1,8 @@
import random
import time
from collections import Counter
from datetime import datetime, timezone
from operator import itemgetter
from typing import List
from unittest.mock import MagicMock, patch

@@ -1270,6 +1272,54 @@ def test_hierarchy(self, database: SplitgillDatabase):
check_data_fields(data_fields[10].children, [f__h_i])
assert all(field.parent.path == f__h.path for field in data_fields[9].children)

def test_random_sample(self, database: SplitgillDatabase):
# I can't get consistent records returned with this even with seeds and
# explicitly set versions and record IDs, so unfortunately there's a small
# element of randomness to this one
random.seed(1)
records = [
Record.new({field_name: 1})
for field_name in random.choices(
['a', 'b', 'c', 'd'], weights=[1, 3, 2, 4], k=1000
)
]
database.ingest(records, commit=True)
database.sync()

field_counts = sorted(
Counter(
(['_id'] * 1000) + [list(r.data.keys())[0] for r in records]
).items(),
key=itemgetter(1),
reverse=True,
)

data_fields = database.get_data_fields(sample_probability=0.5, seed=1)
assert len(data_fields) == 5
id_field = next(f for f in data_fields if f.path == '_id')
assert 900 < id_field.count < 1100
exact_counts = 0
for data_field, field in zip(data_fields, field_counts):
field_name, field_count = field
assert data_field.path == field_name
if data_field.count == field_count:
exact_counts += 1
# this is *technically* possible but very unlikely unless it's not sampling
assert exact_counts != len(data_fields)

parsed_fields = database.get_parsed_fields(sample_probability=0.5, seed=1)
assert len(parsed_fields) == 5
id_field = next(f for f in parsed_fields if f.path == '_id')
assert 900 < id_field.count < 1100
exact_counts = 0
for parsed_field, field in zip(parsed_fields, field_counts):
field_name, field_count = field
assert parsed_field.path == field_name
if parsed_field.count == field_count:
exact_counts += 1
# this is *technically* possible but very unlikely unless it's not sampling
assert exact_counts != len(parsed_fields)


def test_get_rounded_version(splitgill: SplitgillClient):
database = splitgill.get_database('test')
@@ -1426,3 +1476,29 @@ def test_resync_arcs(splitgill: SplitgillClient):
assert count == 2400
assert r_5_count == 3
assert r_780_count == 2


def test_get_field_names(splitgill: SplitgillClient):
database = splitgill.get_database('test')
records = [
Record.new({'a': 1}),
Record.new({'a': 2}),
Record.new({'b': 3}),
Record.new({'b': 'x'}),
Record.new({'b': 5}),
Record.new({'c': 'y'}),
]
database.ingest(records, commit=True)
database.sync()

field_names = database.get_field_names()
assert len(field_names) == 4
expected_fields = [
pf('_id', 3, t=1),
pf('a', 4, t=1, n=1),
pf('b', 4, t=1, n=1),
pf('c', 3, t=1),
]
for f in expected_fields:
f.type_counts[ParsedType.UNPARSED] = 1
assert field_names == expected_fields