Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions dlt/common/schema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,17 @@
"""Known hints of a column"""
COLUMN_HINTS: Set[TColumnHint] = set(get_args(TColumnHint))

# How a column property is treated when one column schema is merged into another:
# - "replace": the incoming value overwrites the existing one (the default for every prop)
# - "remove_if_empty": the property is removed from the existing column when the incoming
#   column does not provide it (only when the merge is asked to respect merge types)
TColumnPropMergeType = Literal[
    "replace",
    "remove_if_empty",
]


class TColumnPropInfo(NamedTuple):
    """Static description of a single column property."""

    # name of the column property
    name: Union[TColumnProp, str]
    # values regarded as defaults for this property
    # NOTE(review): presumably a column holding one of these counts as "not set" - confirm
    defaults: Tuple[Any, ...] = (None,)
    # NOTE(review): presumably marks properties that are column hints - confirm
    is_hint: bool = False
    # merge behavior of the property; "replace" unless declared otherwise
    merge_type: TColumnPropMergeType = "replace"


_ColumnPropInfos = [
Expand All @@ -117,10 +123,10 @@ class TColumnPropInfo(NamedTuple):
TColumnPropInfo("variant", (False, None)),
TColumnPropInfo("partition", (False, None)),
TColumnPropInfo("cluster", (False, None)),
TColumnPropInfo("primary_key", (False, None)),
TColumnPropInfo("primary_key", (False, None), False, "remove_if_empty"),
TColumnPropInfo("sort", (False, None)),
TColumnPropInfo("unique", (False, None)),
TColumnPropInfo("merge_key", (False, None)),
TColumnPropInfo("merge_key", (False, None), False, "remove_if_empty"),
TColumnPropInfo("row_key", (False, None)),
TColumnPropInfo("parent_key", (False, None)),
TColumnPropInfo("root_key", (False, None)),
Expand Down Expand Up @@ -149,6 +155,10 @@ class TColumnPropInfo(NamedTuple):
]
TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]]

# column prop infos declared with the "remove_if_empty" merge type, keyed by prop name
RemoveIfEmptyPropInfos = {
    info.name: info for info in _ColumnPropInfos if info.merge_type == "remove_if_empty"
}


class TColumnType(TypedDict, total=False):
data_type: Optional[TDataType]
Expand Down
62 changes: 51 additions & 11 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VERSION_TABLE_NAME,
PIPELINE_STATE_TABLE_NAME,
ColumnPropInfos,
TColumnPropMergeType,
TColumnName,
TFileFormat,
TPartialTableSchema,
Expand Down Expand Up @@ -154,6 +155,22 @@ def has_default_column_prop_value(prop: str, value: Any) -> bool:
return value in (None, False)


def has_merge_type(prop: str, merge_type: TColumnPropMergeType = "remove_if_empty") -> bool:
    """Tells whether column property `prop` is registered with the given `merge_type`.

    Properties that are not present in `ColumnPropInfos` never match and yield False.
    """
    if prop not in ColumnPropInfos:
        return False
    prop_info = ColumnPropInfos[prop]
    return prop_info.merge_type == merge_type


def remove_column_props_with_merge_type(
    column_schema: TColumnSchema, merge_type: TColumnPropMergeType = "remove_if_empty"
) -> TColumnSchema:
    """Removes from `column_schema`, in place, every property registered with `merge_type`.

    Returns the same `column_schema` object so calls can be chained.
    """
    # collect the names first so the mapping is not resized while being iterated
    matching_props = [name for name in column_schema if has_merge_type(name, merge_type)]
    for name in matching_props:
        del column_schema[name]  # type: ignore
    return column_schema


def remove_column_defaults(column_schema: TColumnSchema) -> TColumnSchema:
"""Removes default values from `column_schema` in place, returns the input for chaining"""
# remove hints with default values
Expand Down Expand Up @@ -420,15 +437,28 @@ def diff_table_references(


def merge_column(
    col_a: TColumnSchema,
    col_b: TColumnSchema,
    merge_defaults: bool = True,
    respect_merge_type: bool = False,
) -> TColumnSchema:
    """Merges `col_b` into `col_a` in place and returns `col_a`.

    Args:
        col_a: Column schema that receives the changes; it is modified in place.
        col_b: Column schema providing the new values; it is never modified.
        merge_defaults: If True, every property of `col_b` is written to `col_a`.
            If False, properties of `col_b` that hold default values are stripped
            (on a copy, via `remove_column_defaults`) before merging.
        respect_merge_type: If True, properties of `col_a` with the "remove_if_empty"
            merge type are removed when `col_b` (after default stripping) does not
            provide them.

    Returns:
        The updated `col_a`.
    """
    # only copy `col_b` when defaults must be stripped; it is read-only below
    col_b_clean = col_b if merge_defaults else remove_column_defaults(copy(col_b))

    if respect_merge_type:
        # iterate over a snapshot of the keys because `col_a` is shrunk in the loop
        for prop in list(col_a.keys()):
            if prop not in col_b_clean and has_merge_type(prop, "remove_if_empty"):
                col_a.pop(prop)  # type: ignore

    # existing keys keep their position, new keys are appended in `col_b` order
    for prop, value in col_b_clean.items():
        col_a[prop] = value  # type: ignore

    return col_a

Expand All @@ -438,6 +468,7 @@ def merge_columns(
columns_b: TTableSchemaColumns,
merge_columns: bool = False,
columns_partial: bool = True,
respect_merge_type: bool = False,
) -> TTableSchemaColumns:
"""Merges `columns_a` with `columns_b`. `columns_a` is modified in place.

Expand All @@ -458,14 +489,19 @@ def merge_columns(
if column_a and not is_complete_column(column_a):
columns_a.pop(col_name)
if column_a and merge_columns:
column_b = merge_column(column_a, column_b)
column_b = merge_column(
column_a, column_b, merge_defaults=True, respect_merge_type=respect_merge_type
)
# set new or updated column
columns_a[col_name] = column_b
return columns_a


def diff_table(
schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema
schema_name: str,
tab_a: TTableSchema,
tab_b: TPartialTableSchema,
respect_merge_type: bool = False,
) -> TPartialTableSchema:
"""Creates a partial table that contains properties found in `tab_b` that are not present or different in `tab_a`.
The name is always present in returned partial.
Expand All @@ -480,18 +516,22 @@ def diff_table(
ensure_compatible_tables(schema_name, tab_a, tab_b, ensure_columns=False)

# get new columns, changes in the column data type or other properties are not allowed
tab_a_columns = tab_a["columns"]
tab_a_columns = copy(tab_a["columns"])
new_columns: List[TColumnSchema] = []
for col_b_name, col_b in tab_b["columns"].items():
if col_b_name in tab_a_columns:
col_a = tab_a_columns[col_b_name]
col_a = tab_a_columns.pop(col_b_name)
# all other properties can change
merged_column = merge_column(copy(col_a), col_b)
merged_column = merge_column(copy(col_a), col_b, respect_merge_type=respect_merge_type)
if merged_column != col_a:
new_columns.append(merged_column)
else:
new_columns.append(col_b)

# if respect_merge_type:
# for col_a in tab_a_columns.values():
# remove_column_props_with_merge_type(col_a, "remove_if_empty")

# return partial table containing only name and properties that differ (column, filters etc.)
table_name = tab_a["name"]

Expand Down
5 changes: 3 additions & 2 deletions dlt/extract/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ def _compute_and_update_tables(
computed_table["x-normalizer"] = {"evolve-columns-once": True}
existing_table = self.schema.tables.get(table_name, None)
if existing_table:
# TODO: revise this. computed table should overwrite certain hints (ie. primary and merge keys) completely
diff_table = utils.diff_table(self.schema.name, existing_table, computed_table)
diff_table = utils.diff_table(
self.schema.name, existing_table, computed_table, respect_merge_type=True
)
else:
diff_table = computed_table

Expand Down
2 changes: 1 addition & 1 deletion tests/normalize/test_model_item_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_selected_column_names_normalized(
parsed_norm_select_query = sqlglot.parse_one(normalized_select_query, read=dialect)

# Ensure the normalized model query contains a subquery in the FROM clause
from_clause = parsed_norm_select_query.args.get("from")
from_clause = parsed_norm_select_query.find(sqlglot.exp.From)
assert isinstance(from_clause, sqlglot.exp.From)
assert isinstance(from_clause.this, sqlglot.exp.Subquery)
assert isinstance(from_clause.this.this, sqlglot.exp.Select)
Expand Down
74 changes: 73 additions & 1 deletion tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import random
import shutil
import threading
import yaml
from time import sleep
from typing import Any, List, Tuple, cast
from typing import Any, List, Tuple, cast, Union
from tenacity import retry_if_exception, Retrying, stop_after_attempt
from unittest.mock import patch
import pytest
Expand Down Expand Up @@ -1837,6 +1838,77 @@ def infer():
# print(pipeline.default_schema.to_pretty_yaml())


@pytest.mark.parametrize(
    "empty_value",
    ["", []],
    ids=["empty_string", "empty_list"],
)
def test_apply_hints_with_empty_values(empty_value: Union[str, List[Any]]) -> None:
    """An empty `primary_key` hint must remove the primary key set by an earlier run."""

    @dlt.resource
    def some_data():
        yield {"id": 1, "val": "some_data"}

    s = some_data()
    pipeline = dlt.pipeline(pipeline_name="empty_value_hints", destination=DUMMY_COMPLETE)

    def run_and_get_id_column() -> Any:
        # load the resource and read back the stored schema of the `id` column
        pipeline.run(s)
        return pipeline.default_schema.get_table("some_data")["columns"]["id"]

    base_column = {"name": "id", "data_type": "bigint"}

    # no hints applied yet: a plain nullable column
    assert run_and_get_id_column() == {**base_column, "nullable": True}

    # the primary key hint flags the column and makes it non-nullable
    s.apply_hints(primary_key=["id"])
    assert run_and_get_id_column() == {**base_column, "nullable": False, "primary_key": True}

    # an empty hint value removes the primary key flag, nullability stays as it was
    s.apply_hints(primary_key=empty_value)
    assert run_and_get_id_column() == {**base_column, "nullable": False}


def test_apply_hints_with_empty_values_with_schema() -> None:
    """Checks that `blocks.number` is a primary key after a run with an explicit schema
    where only the `write_disposition` hint was applied to the resource.
    """
    schema_path = "tests/common/cases/schemas/eth/ethereum_schema_v11.yml"
    blocks_path = "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json"

    pipeline = dlt.pipeline(pipeline_name="empty_value_hints_with_schema", destination=DUMMY_COMPLETE)

    with open(schema_path, "r", encoding="utf-8") as schema_file:
        schema = dlt.Schema.from_dict(yaml.safe_load(schema_file))

    @dlt.source(schema=schema)
    def ethereum():
        @dlt.resource  # type: ignore[call-overload]
        def blocks():
            with open(blocks_path, "r", encoding="utf-8") as blocks_file:
                yield json.load(blocks_file)

        return blocks()

    source = ethereum()
    source.blocks.apply_hints(write_disposition="replace")
    pipeline.run(source)

    number_column = pipeline.default_schema.get_table("blocks")["columns"]["number"]
    assert number_column.get("primary_key") is True


def test_invalid_data_edge_cases() -> None:
# pass lambda directly to run, allowed now because functions can be extracted too
pipeline = dlt.pipeline(pipeline_name="invalid", destination=DUMMY_COMPLETE)
Expand Down
Loading