
Commit 0c48e85

wip: add feast_utils, feast_replication_utils and tests
1 parent 956e133 commit 0c48e85

File tree

4 files changed: +1292 -0 lines changed

4 files changed

+1292
-0
lines changed
Lines changed: 389 additions & 0 deletions
@@ -0,0 +1,389 @@
import contextlib
import functools
import operator
from datetime import datetime
from pathlib import (
    Path,
)

import dask
import feast
import feast.repo_operations
import feast.utils as utils
import toolz
from attr import (
    field,
    frozen,
)
from attr.validators import (
    instance_of,
)
from feast.infra.offline_stores.offline_utils import (
    DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
)

import xorq.api as xo


@dask.base.normalize_token.register(dask.utils.methodcaller)
def normalize_methodcaller(mc):
    return dask.base.normalize_token(
        (
            dask.utils.methodcaller,
            mc.method,
        )
    )


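# Why the registration above exists: dask has no built-in normalizer for
# dask.utils.methodcaller, so tokens for equal instances are not guaranteed
# to be stable; normalizing on the method name fixes that.  A quick sketch
# ("execute" is an arbitrary method name):
def _check_methodcaller_token_stability():
    token = dask.base.tokenize(dask.utils.methodcaller("execute"))
    assert token == dask.base.tokenize(dask.utils.methodcaller("execute"))

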
@frozen
class Store:
    path = field(validator=instance_of(Path), converter=Path)

    def __attrs_post_init__(self):
        assert self.path.exists()

    @property
    @functools.cache
    def store(self):
        return feast.FeatureStore(self.path)

    @property
    def config(self):
        return self.store.config

    @property
    def provider(self):
        return self.store._get_provider()

    @property
    @functools.cache
    def repo_contents(self):
        with contextlib.chdir(self.path):
            return feast.repo_operations._get_repo_contents(
                self.path, self.project_name
            )

    @property
    def registry(self):
        return self.store._registry

    @property
    def entities(self):
        return self.store.list_entities()

    @property
    def project_name(self):
        return self.config.project

    def apply(self, skip_source_validation=False):
        with contextlib.chdir(self.path):
            return feast.repo_operations.apply_total(
                self.config, self.path, skip_source_validation=skip_source_validation
            )

    def teardown(self):
        return self.store.teardown()

    def list_on_demand_feature_view_names(self):
        return tuple(el.name for el in self.repo_contents.on_demand_feature_views)

    def get_on_demand_feature_view(self, on_demand_feature_view_name):
        return self.registry.get_on_demand_feature_view(
            on_demand_feature_view_name, self.store.project
        )

    def list_feature_view_names(self):
        return tuple(el.name for el in self.repo_contents.feature_views)

    def get_feature_view(self, feature_view_name):
        return self.registry.get_feature_view(feature_view_name, self.store.project)

    def get_feature_refs(self, features):
        return utils._get_features(self.registry, self.store.project, list(features))

    def get_feature_views_to_use(self, features):
        (all_feature_views, all_on_demand_feature_views) = (
            utils._get_feature_views_to_use(
                self.registry,
                self.store.project,
                list(features),
            )
        )
        return (all_feature_views, all_on_demand_feature_views)

    def get_grouped_feature_views(self, features):
        feature_refs = self.get_feature_refs(features)
        (all_feature_views, all_on_demand_feature_views) = (
            self.get_feature_views_to_use(features)
        )
        fvs, odfvs = utils._group_feature_refs(
            feature_refs,
            all_feature_views,
            all_on_demand_feature_views,
        )
        (feature_views, on_demand_feature_views) = (
            tuple(view for view, _ in gen) for gen in (fvs, odfvs)
        )
        return feature_views, on_demand_feature_views

    def validate_entity_expr(self, entity_expr, features, full_feature_names=False):
        (_, on_demand_feature_views) = self.get_grouped_feature_views(features)
        if self.store.config.coerce_tz_aware:
            # FIXME: pass entity_expr back out
            # entity_df = utils.make_df_tzaware(typing.cast(pd.DataFrame, entity_df))
            pass
        # every request-data column an odfv needs must be present on the entity expr
        bad_pairs = (
            (feature_name, odfv.name)
            for odfv in on_demand_feature_views
            for feature_name in odfv.get_request_data_schema().keys()
            if feature_name not in entity_expr.columns
        )
        if pair := next(bad_pairs, None):
            from feast.errors import RequestDataNotFoundInEntityDfException

            (feature_name, feature_view_name) = pair
            raise RequestDataNotFoundInEntityDfException(
                feature_name=feature_name,
                feature_view_name=feature_view_name,
            )
        utils._validate_feature_refs(
            self.get_feature_refs(features),
            full_feature_names,
        )

    def get_historical_features(self, entity_expr, features, full_feature_names=False):
        self.validate_entity_expr(
            entity_expr, features, full_feature_names=full_feature_names
        )
        (odfv_dct, fv_dct) = group_features(self, features)
        entity_expr, all_join_keys = process_all_feature_views(
            self, entity_expr, fv_dct
        )
        expr = process_odfvs(entity_expr, odfv_dct)
        return expr
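
    # get_historical_features above is the deferred xorq path: bucket refs by
    # view, point-in-time join each batch source, then run on-demand
    # transforms; get_historical_features_feast below delegates to feast's
    # own implementation, which is useful for comparing the two.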

    def get_historical_features_feast(
        self, entity_df, features, full_feature_names=False
    ):
        return self.store.get_historical_features(
            entity_df=entity_df,
            features=features,
            full_feature_names=full_feature_names,
        )

    def get_online_features(self, features, entity_rows):
        return self.store.get_online_features(
            features=features,
            entity_rows=entity_rows,
        ).to_dict()

    def list_feature_service_names(self):
        return tuple(el.name for el in self.store.list_feature_services())

    def get_feature_service(self, feature_service_name):
        return self.store.get_feature_service(feature_service_name)

    def list_data_source_names(self):
        return tuple(
            el.name for el in self.registry.list_data_sources(self.project_name)
        )

    def get_data_source(self, data_source_name):
        return self.registry.get_data_source(data_source_name, self.project_name)

    @classmethod
    def make_applied_materialized(cls, path, end_date=None):
        end_date = end_date or datetime.now()
        store = cls(path)
        store.apply()
        store.store.materialize_incremental(end_date=end_date)
        return store


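# Usage sketch (hypothetical names throughout: the "feature_repo" directory,
# "entity_df.parquet", and the "driver_stats:conv_rate" ref; xo.connect() is
# assumed to return a default backend):
def _example_get_historical_features():
    store = Store.make_applied_materialized("feature_repo")
    con = xo.connect()
    entity_expr = xo.deferred_read_parquet("entity_df.parquet", con=con)
    expr = store.get_historical_features(entity_expr, ("driver_stats:conv_rate",))
    return expr.execute()

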
def process_one_feature_view(
    entity_expr, store, feature_view, feature_names, all_join_keys
):
    def _read_mapped(
        con,
        store,
        feature_view,
        feature_names,
        right_entity_key_columns,
        ets,
        ts,
        full_feature_names=False,
    ):
        def maybe_rename(expr, dct):
            return (
                expr.rename({to_: from_ for from_, to_ in dct.items() if from_ in expr})
                if dct
                else expr
            )

        if full_feature_names:
            raise ValueError
        expr = (
            xo.deferred_read_parquet(
                store.config.repo_path.joinpath(feature_view.batch_source.path), con=con
            )
            .pipe(maybe_rename, feature_view.batch_source.field_mapping)
            .pipe(maybe_rename, feature_view.projection.join_key_map)
            .select(list(right_entity_key_columns) + list(feature_names))
        )
        if ts == ets:
            # avoid a name collision with the entity event timestamp column
            new_ts = f"__{ts}"
            expr, ts = expr.pipe(maybe_rename, {ts: new_ts}), new_ts
        return expr, ts

    def _merge(entity_expr, feature_expr, join_keys):
        return entity_expr.join(
            feature_expr, predicates=join_keys, how="left", rname="{name}__"
        )

    def _normalize_timestamp(expr, *tss):
        casts = {
            ts: xo.expr.datatypes.Timestamp(timezone="UTC")
            for ts in tss
            if ts in expr and expr[ts].type().timezone is None
        }
        return expr.cast(casts) if casts else expr

    def _filter_ttl(expr, ttl, ets, ts):
        isna_condition = expr[ts].isnull()
        le_condition = expr[ts] <= expr[ets]
        if ttl and ttl.total_seconds() != 0:
            ge_condition = (
                expr[ets] - xo.interval(seconds=ttl.total_seconds())
            ) <= expr[ts]
            time_condition = ge_condition & le_condition
        else:
            time_condition = le_condition
        condition = isna_condition | time_condition
        return expr[condition]

    def _drop_duplicates(expr, join_keys, ets, ts, cts):
        order_by = tuple(
            expr[ts].desc(nulls_first=False)
            # cts desc first: most recent update
            # ts desc: closest to the event ts
            for ts in (cts, ts)
            if ts in expr
        )
        ROW_NUM = "row_num"
        expr = (
            expr.mutate(
                **{
                    ROW_NUM: (
                        xo.row_number().over(
                            group_by=list(join_keys) + [ets],
                            order_by=order_by,
                        )
                    ),
                }
            )
            # row_number is 0-based here, so 0 keeps the top-ranked row
            .filter(xo._[ROW_NUM] == 0)
            .drop(ROW_NUM)
        )
        return expr

    ets = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
    assert ets in entity_expr
    con = entity_expr._find_backend()

    ts, cts = (
        feature_view.batch_source.timestamp_field,
        feature_view.batch_source.created_timestamp_column,
    )
    join_keys = tuple(
        feature_view.projection.join_key_map.get(entity_column.name, entity_column.name)
        for entity_column in feature_view.entity_columns
    )
    all_join_keys = all_join_keys + [
        join_key for join_key in join_keys if join_key not in all_join_keys
    ]
    right_entity_key_columns = list(filter(None, [ts, cts] + list(join_keys)))

    entity_expr = _normalize_timestamp(entity_expr, ets)

    feature_expr, ts = _read_mapped(
        con, store, feature_view, feature_names, right_entity_key_columns, ets, ts
    )
    expr = _merge(entity_expr, feature_expr, join_keys)
    expr = _normalize_timestamp(expr, ts, cts)
    expr = _filter_ttl(expr, feature_view.ttl, ets, ts)
    expr = _drop_duplicates(expr, all_join_keys, ets, ts, cts)
    return expr, all_join_keys


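# Worked example of the window _filter_ttl enforces: with an entity event at
# 12:00 and a one-hour ttl, only feature rows stamped within [11:00, 12:00]
# (plus null-stamped rows from the left join) survive.  The same predicate in
# plain datetimes:
def _ttl_keeps(event_ts, feature_ts, ttl):
    return feature_ts <= event_ts and event_ts - ttl <= feature_ts

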
def process_all_feature_views(store, entity_expr, fv_dct):
    # fold each view's point-in-time join over the entity expr, accumulating
    # the join keys seen so far (dedup groups by all of them)
    all_join_keys = []
    for feature_view, feature_names in fv_dct.items():
        entity_expr, all_join_keys = process_one_feature_view(
            entity_expr, store, feature_view, feature_names, all_join_keys
        )
    return entity_expr, all_join_keys


@toolz.curry
def apply_odfv_dct(df, odfv_udfs):
    for other in (udf(df) for udf in odfv_udfs):
        df = df.join(other)
    return df


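# The udfs consumed above follow feast's pandas-mode on-demand transform
# contract: take the enriched entity dataframe, return a dataframe of just
# the new feature columns, which apply_odfv_dct joins back on the index.  A
# toy udf with hypothetical column names:
def _example_odfv_udf(df):
    import pandas as pd

    return pd.DataFrame({"conv_rate_plus_one": df["conv_rate"] + 1.0})

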
def make_uniform_timestamps(expr, timezone="UTC", scale=6):
    import xorq.vendor.ibis.expr.datatypes as dt

    casts = {
        name: dt.Timestamp(timezone=timezone, scale=scale)
        for name, typ in expr.schema().items()
        if isinstance(typ, dt.Timestamp)
    }
    return expr.cast(casts) if casts else expr


def calc_odfv_schema_append(odfv_dct):
    fields = (field for odfv in odfv_dct for field in odfv.features)
    schema_append = {field.name: field.dtype.name for field in fields}
    return schema_append


def process_odfvs(entity_expr, odfv_dct, full_feature_names=False):
    if full_feature_names:
        raise ValueError
    entity_expr = make_uniform_timestamps(entity_expr)
    odfv_udfs = tuple(odfv.feature_transformation.udf for odfv in odfv_dct.keys())
    schema_in = entity_expr.schema()
    schema_append = calc_odfv_schema_append(odfv_dct)
    udxf = xo.expr.relations.flight_udxf(
        process_df=apply_odfv_dct(odfv_udfs=odfv_udfs),
        maybe_schema_in=schema_in,
        maybe_schema_out=schema_in | schema_append,
        name="process_odfvs",
    )
    return udxf(entity_expr)


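# flight_udxf (as used above) lifts a dataframe -> dataframe function into a
# deferred relational op with declared input and output schemas.  A minimal
# identity sketch mirroring the same call shape:
def _example_identity_udxf(expr):
    schema = expr.schema()
    return xo.expr.relations.flight_udxf(
        process_df=toolz.identity,
        maybe_schema_in=schema,
        maybe_schema_out=schema,
        name="identity",
    )(expr)

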
def group_features(store, feature_names):
    # refs must all be of the form "view_name:feature_name"
    splat = tuple(feature_name.split(":") for feature_name in feature_names)
    assert (2,) == tuple(set(map(len, splat)))
    name_to_use_to_view = {
        view.projection.name_to_use(): view
        for view in store.store.list_all_feature_views()
    }
    dct = toolz.groupby(
        operator.itemgetter(0),
        splat,
    )
    view_to_feature_names = {
        name_to_use_to_view[feature_view_name]: tuple(
            feature_name for _, feature_name in pairs
        )
        for feature_view_name, pairs in dct.items()
    }
    is_odfv = toolz.flip(isinstance)(feast.OnDemandFeatureView)
    (odfv_dct, fv_dct) = (
        toolz.keyfilter(f, view_to_feature_names)
        for f in (is_odfv, toolz.complement(is_odfv))
    )
    return odfv_dct, fv_dct
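
# E.g. (hypothetical refs) ("driver_stats:conv_rate", "driver_stats:acc_rate")
# buckets to {<driver_stats view>: ("conv_rate", "acc_rate")}, landing in
# odfv_dct or fv_dct according to whether the view is an OnDemandFeatureView.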
