54 changes: 52 additions & 2 deletions rectools/dataset/dataset.py
@@ -78,6 +78,14 @@ class SparseFeaturesSchema(BaseFeaturesSchema):
cat_n_stored_values: int


class InteractionsFeaturesSchema(BaseConfig):
"""Interactions features schema."""

cat_feature_names: tp.List[str]
cat_feature_names_w_values: tp.List[tp.Tuple[str, str]]
direct_feature_names: tp.List[str]


FeaturesSchema = tp.Union[DenseFeaturesSchema, SparseFeaturesSchema]


@@ -102,6 +110,7 @@ class DatasetSchema(BaseConfig):
n_interactions: int
users: EntitySchema
items: EntitySchema
interactions: tp.Optional[InteractionsFeaturesSchema] = None


@attr.s(slots=True, frozen=True)
@@ -135,6 +144,7 @@ class Dataset:
interactions: Interactions = attr.ib()
user_features: tp.Optional[Features] = attr.ib(default=None)
item_features: tp.Optional[Features] = attr.ib(default=None)
interactions_schema: tp.Optional[InteractionsFeaturesSchema] = attr.ib(default=None)

@staticmethod
def _get_feature_schema(features: tp.Optional[Features]) -> tp.Optional[FeaturesSchema]:
@@ -170,6 +180,7 @@ def get_schema(self) -> DatasetSchemaDict:
n_interactions=self.interactions.df.shape[0],
users=user_schema,
items=item_schema,
interactions=self.interactions_schema,
)
return schema.model_dump(mode="json")

@@ -206,7 +217,7 @@ def get_hot_item_features(self) -> tp.Optional[Features]:
return self.item_features.take(range(self.n_hot_items))

@classmethod
def construct(
def construct( # pylint: disable=too-many-locals
cls,
interactions_df: pd.DataFrame,
user_features_df: tp.Optional[pd.DataFrame] = None,
@@ -216,6 +227,8 @@ def construct(
cat_item_features: tp.Iterable[str] = (),
make_dense_item_features: bool = False,
keep_extra_cols: bool = False,
interactions_cat_features: tp.Iterable[str] = (),
interactions_direct_features: tp.Iterable[str] = (),
) -> "Dataset":
"""Class method for convenient `Dataset` creation.

Expand Down Expand Up @@ -249,6 +262,10 @@ def construct(
- if ``True``, `DenseFeatures.from_dataframe` method will be used.
keep_extra_cols: bool, default ``False``
Flag to keep all columns from interactions besides the default ones.
interactions_cat_features: tp.Iterable[str], default ``()``
List of categorical feature names in interactions dataframe.
interactions_direct_features: tp.Iterable[str], default ``()``
List of direct (non-categorical) feature names in interactions dataframe.

Returns
-------
@@ -258,6 +275,32 @@ def construct(
for col in (Columns.User, Columns.Item):
if col not in interactions_df:
raise KeyError(f"Column '{col}' must be present in `interactions_df`")

# Validate interactions features
cat_features = set(interactions_cat_features)
direct_features = set(interactions_direct_features)
required_columns = cat_features | direct_features
actual_columns = set(interactions_df.columns)
if not required_columns.issubset(actual_columns):
raise KeyError(f"Missing columns {required_columns - actual_columns} in `interactions_df`")

# Create interactions feature schema
cat_feature_names_w_values = []
for cat_feature in cat_features:
values = interactions_df[cat_feature].unique() # TODO: decide how to handle NaN values
for value in values:
cat_feature_names_w_values.append((cat_feature, value))

interactions_schema = (
InteractionsFeaturesSchema(
cat_feature_names=list(cat_features),
direct_feature_names=list(direct_features),
cat_feature_names_w_values=cat_feature_names_w_values,
)
if cat_features or direct_features
else None
)

user_id_map = IdMap.from_values(interactions_df[Columns.User].values)
item_id_map = IdMap.from_values(interactions_df[Columns.Item].values)
interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map, keep_extra_cols)
@@ -278,7 +321,14 @@
Columns.Item,
"item",
)
return cls(user_id_map, item_id_map, interactions, user_features, item_features)
return cls(
user_id_map=user_id_map,
item_id_map=item_id_map,
interactions=interactions,
user_features=user_features,
item_features=item_features,
interactions_schema=interactions_schema,
)

@staticmethod
def _make_features(
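Note: below is a minimal usage sketch of the new `construct` arguments. The column names and values are illustrative, not from this PR; the `Columns.Weight`/`Columns.Datetime` columns follow the standard rectools interactions format.

```python
import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset

# Illustrative interactions with two extra context columns.
interactions_df = pd.DataFrame(
    {
        Columns.User: [10, 10, 20],
        Columns.Item: [1, 2, 1],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: ["2021-09-01", "2021-09-02", "2021-09-03"],
        "platform": ["ios", "android", "ios"],  # categorical context feature
        "hour": [9, 21, 13],  # direct (numeric) context feature
    }
)

dataset = Dataset.construct(
    interactions_df,
    keep_extra_cols=True,  # keep "platform" and "hour" in interactions
    interactions_cat_features=["platform"],
    interactions_direct_features=["hour"],
)

# The schema records every (feature, value) pair for categorical features.
# `construct` iterates over a set, so pair order is not guaranteed:
schema = dataset.get_schema()
# schema["interactions"]["cat_feature_names_w_values"]
# -> [["platform", "ios"], ["platform", "android"]]
```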
2 changes: 1 addition & 1 deletion rectools/models/nn/item_net.py
@@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
@property
def out_dim(self) -> int:
"""Return item net constructor output dimension."""
return self.item_net_blocks[0].out_dim # type: ignore[return-value]
return self.item_net_blocks[0].out_dim
41 changes: 37 additions & 4 deletions rectools/models/nn/transformers/base.py
@@ -38,6 +38,7 @@
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .context_net import CatFeaturesContextNet, ContextNetBase
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .lightning import TransformerLightningModule, TransformerLightningModuleBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
@@ -117,6 +118,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

ContextNetType = tpe.Annotated[
tp.Type[ContextNetBase],
BeforeValidator(_get_class_obj),
PlainSerializer(
func=get_class_or_function_full_path,
return_type=str,
when_used="json",
),
]

TransformerDataPreparatorType = tpe.Annotated[
tp.Type[TransformerDataPreparatorBase],
BeforeValidator(_get_class_obj),
@@ -216,6 +227,7 @@ class TransformerModelConfig(ModelConfig):
negative_sampler_type: TransformerNegativeSamplerType = CatalogUniformSampler
similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
backbone_type: TransformerBackboneType = TransformerTorchBackbone
context_net_type: ContextNetType = CatFeaturesContextNet
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
@@ -228,6 +240,7 @@ class TransformerModelConfig(ModelConfig):
negative_sampler_kwargs: tp.Optional[InitKwargs] = None
similarity_module_kwargs: tp.Optional[InitKwargs] = None
backbone_kwargs: tp.Optional[InitKwargs] = None
context_net_kwargs: tp.Optional[InitKwargs] = None


TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -278,6 +291,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
context_net_type: tp.Type[ContextNetBase] = CatFeaturesContextNet,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
@@ -290,6 +304,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
backbone_kwargs: tp.Optional[InitKwargs] = None,
context_net_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(verbose=verbose)
@@ -321,6 +336,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.lightning_module_type = lightning_module_type
self.negative_sampler_type = negative_sampler_type
self.backbone_type = backbone_type
self.context_net_type = context_net_type
self.get_val_mask_func = get_val_mask_func
self.get_trainer_func = get_trainer_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
@@ -333,7 +349,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.negative_sampler_kwargs = negative_sampler_kwargs
self.similarity_module_kwargs = similarity_module_kwargs
self.backbone_kwargs = backbone_kwargs
self.context_net_kwargs = context_net_kwargs

self._init_data_preparator()
self._init_trainer()

@@ -392,6 +408,16 @@ def _construct_item_net(self, dataset: Dataset) -> ItemNetBase:
**self._get_kwargs(self.item_net_constructor_kwargs),
)

def _construct_context_net(self, dataset_schema: DatasetSchema) -> tp.Optional[ContextNetBase]:
if dataset_schema.interactions is None:
return None
return self.context_net_type.from_dataset_schema(
dataset_schema,
self.n_factors,
self.dropout_rate,
**self._get_kwargs(self.context_net_kwargs),
)

def _construct_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetBase:
return self.item_net_constructor_type.from_dataset_schema(
dataset_schema,
@@ -421,14 +447,17 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
def _init_similarity_module(self) -> SimilarityModuleBase:
return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))

def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase:
def _init_torch_model(
self, item_model: ItemNetBase, context_net: tp.Optional[ContextNetBase]
) -> TransformerBackboneBase:
pos_encoding_layer = self._init_pos_encoding_layer()
transformer_layers = self._init_transformer_layers()
similarity_module = self._init_similarity_module()
return self.backbone_type(
n_heads=self.n_heads,
dropout_rate=self.dropout_rate,
item_model=item_model,
context_net=context_net,
pos_encoding_layer=pos_encoding_layer,
transformer_layers=transformer_layers,
similarity_module=similarity_module,
@@ -464,7 +493,10 @@ def _init_lightning_model(
def _build_model_from_dataset(self, dataset: Dataset) -> None:
self.data_preparator.process_dataset_train(dataset)
item_model = self._construct_item_net(self.data_preparator.train_dataset)
torch_model = self._init_torch_model(item_model)
context_net = self._construct_context_net(
DatasetSchema.model_validate(self.data_preparator.train_dataset.get_schema())
)
torch_model = self._init_torch_model(item_model, context_net)

dataset_schema = self.data_preparator.train_dataset.get_schema()
item_external_ids = self.data_preparator.train_dataset.item_id_map.external_ids
@@ -589,7 +621,8 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:

# Init and update torch model and lightning model
item_model = loaded._construct_item_net_from_dataset_schema(dataset_schema)
torch_model = loaded._init_torch_model(item_model)
context_net = loaded._construct_context_net(dataset_schema)
torch_model = loaded._init_torch_model(item_model, context_net)
loaded._init_lightning_model(
torch_model=torch_model,
dataset_schema=dataset_schema,
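Note: a sketch of how a model would opt into the context net. `SASRecModel` and `model.fit` are from the existing rectools API; only `context_net_type` and `context_net_kwargs` come from this PR, and it is assumed here that the subclass forwards them to `TransformerModelBase.__init__` like the other transformer options.

```python
from rectools.models import SASRecModel
from rectools.models.nn.transformers.context_net import CatFeaturesContextNet

model = SASRecModel(
    n_factors=64,
    context_net_type=CatFeaturesContextNet,  # the default introduced in this PR
    context_net_kwargs={"batch_key": "context_cat_inputs"},
)

# During fit, `_build_model_from_dataset` validates the dataset schema and
# passes the constructed context net (or None) into the backbone:
#   context_net = self._construct_context_net(DatasetSchema.model_validate(...))
#   torch_model = self._init_torch_model(item_model, context_net)
model.fit(dataset)  # `dataset` built with `interactions_cat_features`, as above
```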
14 changes: 7 additions & 7 deletions rectools/models/nn/transformers/bert4rec.py
@@ -36,7 +36,7 @@
ValMaskCallable,
)
from .constants import MASKING_VALUE, PADDING_VALUE
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
@@ -128,7 +128,7 @@ def _mask_session(

def _collate_fn_train(
self,
batch: List[Tuple[List[int], List[float]]],
batch: List[BatchElement],
) -> Dict[str, torch.Tensor]:
"""
Mask session elements to receive `x`.
@@ -141,7 +141,7 @@ def _collate_fn_train(
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
masked_session, target = self._mask_session(ses)
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,12 @@ def _collate_fn_train(
)
return batch_dict

def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_val(self, batch: List[BatchElement]) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # while only leave-one-out strategy is supported
yw = np.zeros((batch_size, 1)) # while only leave-one-out strategy is supported
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
session = input_session.copy()

@@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
)
return batch_dict

def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_recommend(self, batch: List[BatchElement]) -> Dict[str, torch.Tensor]:
"""
Right truncation, left padding to `session_max_len`.
During inference the model uses (`session_max_len` - 1) interactions,
and one extra "MASK" token is appended for making predictions.
"""
x = np.zeros((len(batch), self.session_max_len))
for i, (ses, _) in enumerate(batch):
for i, (ses, _, _) in enumerate(batch):
session = ses.copy()
session = session + [self.extra_token_ids[MASKING_VALUE]]
x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
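Note: the collate functions now unpack a third element per session. `BatchElement` is defined in `data_preparator.py`, which is not shown in this diff; the alias below is an assumption for illustration only.

```python
import typing as tp

# Assumed shape of BatchElement: session item ids, weights, and extra
# per-interaction payload (presumably the new context features).
BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.List[tp.Any]]

# BERT4Rec's collators ignore the extra slot for now:
def session_lengths(batch: tp.List[BatchElement]) -> tp.List[int]:
    return [len(ses) for ses, _weights, _extras in batch]
```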
81 changes: 81 additions & 0 deletions rectools/models/nn/transformers/context_net.py
@@ -0,0 +1,81 @@
import typing as tp
import warnings

import torch
import typing_extensions as tpe
from torch import nn

from rectools.dataset.dataset import DatasetSchema

# TODO: support non-string values in feature names/values


class ContextNetBase(nn.Module):
"""Base class for context nets that add per-interaction context features to sequence embeddings."""

def __init__(self, n_factors: int, dropout_rate: float, **kwargs: tp.Any):
super().__init__()

def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
"""TODO."""
raise NotImplementedError

@classmethod
def from_dataset_schema(
cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any
) -> tp.Optional[tpe.Self]:
"""Construct ItemNet from Dataset schema."""
raise NotImplementedError()

@property
def out_dim(self) -> int:
"""Return item embedding output dimension."""
raise NotImplementedError()


class CatFeaturesContextNet(ContextNetBase):
"""TODO."""

def __init__(
self,
n_factors: int,
dropout_rate: float,
n_cat_feature_values: int,
batch_key: str = "context_cat_inputs",
**kwargs: tp.Any,
) -> None:
super().__init__(n_factors, dropout_rate, **kwargs)
self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors, mode="sum")
self.dropout = nn.Dropout(dropout_rate)
self.batch_key = batch_key

@classmethod
def from_dataset_schema( # TODO: decide about target-aware schema
cls, dataset_schema: DatasetSchema, n_factors: int, dropout_rate: float, **kwargs: tp.Any
) -> tp.Optional[tpe.Self]:
"""TODO."""
if dataset_schema.interactions is None:
warnings.warn("No interactions schema found in dataset schema, context net will not be constructed")
return None
if dataset_schema.interactions.direct_feature_names:
warnings.warn("Direct features are not supported in context net")
if len(dataset_schema.interactions.cat_feature_names_w_values) == 0:
warnings.warn("No categorical features found in dataset schema, context net will not be constructed")
return None
n_cat_feature_values = len(dataset_schema.interactions.cat_feature_names_w_values)
return cls(n_factors=n_factors, dropout_rate=dropout_rate, n_cat_feature_values=n_cat_feature_values)

def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
"""TODO."""
batch_size, session_max_len, n_factors = seqs.shape
inputs = batch[self.batch_key].view(batch_size * session_max_len, -1)
context_embs = self.embedding_bag(input=inputs)
context_embs = self.dropout(context_embs)
context_embs = context_embs.view(batch_size, session_max_len, n_factors)
return context_embs

@property
def out_dim(self) -> int:
"""Return output dimension."""
return self.embedding_bag.embedding_dim
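Note: a shape sketch for `CatFeaturesContextNet.forward`, assuming one embedding-bag index per position. Tensor values are illustrative; the batch key matches the class default.

```python
import torch

from rectools.models.nn.transformers.context_net import CatFeaturesContextNet

net = CatFeaturesContextNet(n_factors=64, dropout_rate=0.1, n_cat_feature_values=7)

batch_size, session_max_len = 2, 5
seqs = torch.randn(batch_size, session_max_len, 64)
batch = {
    # Indices into the flattened (feature, value) vocabulary. forward() reshapes
    # this to (batch_size * session_max_len, -1) before the sum-mode EmbeddingBag.
    "context_cat_inputs": torch.randint(0, 7, (batch_size, session_max_len, 1)),
}

context_embs = net(seqs, batch)
assert context_embs.shape == seqs.shape  # (2, 5, 64), same shape as seqs
```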