diff --git a/merlin_standard_lib/schema/schema.py b/merlin_standard_lib/schema/schema.py
index 0052c08064..ee343b3ef1 100644
--- a/merlin_standard_lib/schema/schema.py
+++ b/merlin_standard_lib/schema/schema.py
@@ -16,6 +16,7 @@
 import json
 import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from warnings import warn
 
 from google.protobuf import json_format, text_format
 from google.protobuf.message import Message as ProtoMessage
@@ -216,6 +217,24 @@ class Schema(_Schema):
     """A collection of column schemas for a dataset."""
 
     feature: List["ColumnSchema"] = betterproto.message_field(1)
+    # Class-level flag: True until the deprecation warning has been emitted once.
+    _is_first_init = True
+
+    def __post_init__(self):
+        """Run the parent post-init, then warn about the deprecation on first use only."""
+        super().__post_init__()
+        # Read and write the flag on the class itself: assigning through ``self`` would
+        # only shadow it on one instance, so every new Schema would warn again.
+        if Schema._is_first_init:
+            # TODO: Make description more descriptive. What version are we planning to remove it?
+            warn(
+                # Adjacent literals form ONE message; a comma between them would pass
+                # the second string as ``category`` and make ``warn`` raise TypeError.
+                "Schema from `merlin_standard_lib` is deprecated, "
+                "use Schema from `merlin.schema` instead.",
+                DeprecationWarning,
+            )
+            Schema._is_first_init = False
 
     @classmethod
     def create(
diff --git a/tests/merlin_standard_lib/schema/test_schema.py b/tests/merlin_standard_lib/schema/test_schema.py
index 123d396942..890fc6c6bd 100644
--- a/tests/merlin_standard_lib/schema/test_schema.py
+++ b/tests/merlin_standard_lib/schema/test_schema.py
@@ -13,6 +13,8 @@
 # limitations under the License.
# +from warnings import catch_warnings + import pytest from merlin_standard_lib import categorical_cardinalities @@ -37,16 +39,20 @@ def test_column_schema(): def test_schema(): - s = schema.Schema( - [ - schema.ColumnSchema.create_continuous("con_1"), - schema.ColumnSchema.create_continuous("con_2_int", is_float=False), - schema.ColumnSchema.create_categorical("cat_1", 1000), - schema.ColumnSchema.create_categorical( - "cat_2", 100, value_count=schema.ValueCount(1, 20) - ), - ] - ) + with catch_warnings(record=True) as w: + schema.Schema._is_first_init = True + s = schema.Schema( + [ + schema.ColumnSchema.create_continuous("con_1"), + schema.ColumnSchema.create_continuous("con_2_int", is_float=False), + schema.ColumnSchema.create_categorical("cat_1", 1000), + schema.ColumnSchema.create_categorical( + "cat_2", 100, value_count=schema.ValueCount(1, 20) + ), + ] + ) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) assert len(s.select_by_type(schema.FeatureType.INT).column_names) == 3 assert len(s.select_by_name(lambda x: x.startswith("con")).column_names) == 2