Skip to content

Commit 3790c26

Browse files
lizhouyu authored and facebook-github-bot committed
OSS TorchRec MPZCH Modules (#3147)
Summary:
Pull Request resolved: #3147
Pull Request resolved: #3089
Pull Request resolved: #3017

### Major changes
- Copy the following files from `fb` to the corresponding location in the `torchrec` repository:
  - `fb/distributed/hash_mc_embedding.py → torchrec/distributed/hash_mc_embedding.py`
  - `fb/modules/hash_mc_evictions.py → torchrec/modules/hash_mc_evictions.py`
  - `fb/modules/hash_mc_metrics.py → torchrec/modules/hash_mc_metrics.py`
  - `fb/modules/hash_mc_modules.py → torchrec/modules/hash_mc_modules.py`
  - `fb/modules/tests/test_hash_mc_evictions.py → torchrec/modules/tests/test_hash_mc_evictions.py`
  - `fb/modules/tests/test_hash_mc_modules.py → torchrec/modules/tests/test_hash_mc_modules.py`
- Update `/modules/hash_mc_metrics.py`:
  - Replace the tensorboard module with a local file logger in the `hash_mc_metrics.py` module to avoid OSS CI test failures.
  - The original tensorboard version is kept in the `torchrec/fb` folder.
- Update the license declaration headers for the four OSS files.

Reviewed By: kausv
Differential Revision: D77558442
fbshipit-source-id: 99c00712e0f8e84ff2629943c1b2e82d64a6b392
1 parent a37ef40 commit 3790c26

File tree

5 files changed

+1860
-0
lines changed

5 files changed

+1860
-0
lines changed

torchrec/modules/hash_mc_evictions.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
# pyre-strict
10+
11+
#!/usr/bin/env python3
12+
13+
import logging
14+
import time
15+
from dataclasses import dataclass
16+
from enum import Enum, unique
17+
from typing import List, Optional, Tuple
18+
19+
import torch
20+
from pyre_extensions import none_throws
21+
22+
from torchrec.sparse.jagged_tensor import JaggedTensor
23+
24+
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
25+
26+
27+
@unique
class HashZchEvictionPolicyName(Enum):
    """Names of the eviction policies supported for the ZCH table.

    Each member's value is the same string as its name; ``@unique`` guarantees
    that no two members share a value.
    """

    # Evict based on when the ID was last seen during training, using one TTL
    # shared by all features.
    SINGLE_TTL_EVICTION = "SINGLE_TTL_EVICTION"

    # Evict based on when the ID was last seen during training, using a
    # separate TTL for each feature.
    PER_FEATURE_TTL_EVICTION = "PER_FEATURE_TTL_EVICTION"

    # Evict the least recently seen ID within the probe range.
    LRU_EVICTION = "LRU_EVICTION"
37+
38+
39+
# Configuration consumed by the eviction scorers in this module. The class is
# both a dataclass and TorchScript-compiled, so its fields are documented with
# comments rather than a docstring or attribute strings.
@torch.jit.script
@dataclass
class HashZchEvictionConfig:
    # Names of the features the eviction policy applies to.
    features: List[str]

    # TTL shared by every feature. The single-TTL scorer adds it to the current
    # time expressed in hours since the epoch, so the unit is hours.
    single_ttl: Optional[int] = None

    # One TTL per entry of `features` (same order, same length); the
    # per-feature scorer asserts the 1:1 mapping at construction time.
    per_feature_ttl: Optional[List[int]] = None
45+
46+
47+
@torch.fx.wrap
def get_kernel_from_policy(
    policy_name: Optional[HashZchEvictionPolicyName],
) -> int:
    """Translate an eviction policy into an integer kernel selector.

    Args:
        policy_name: the eviction policy in use, or ``None`` when no policy is
            configured.

    Returns:
        ``1`` when the policy is ``LRU_EVICTION``; ``0`` for every other policy
        and for ``None``. (How the consumer interprets these codes is not
        visible in this module.)
    """
    if policy_name is None:
        return 0
    if policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
        return 1
    return 0
57+
58+
59+
class HashZchEvictionScorer:
    """Base eviction scorer; acts as a no-op default.

    Subclasses override ``gen_score`` and ``gen_threshold`` to implement a
    concrete policy. The base implementation yields an empty score tensor and a
    threshold of ``-1``.
    """

    def __init__(self, config: HashZchEvictionConfig) -> None:
        """Store the eviction config for use by subclasses."""
        self._config: HashZchEvictionConfig = config

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Return an empty tensor on ``device``; ``feature`` is not inspected."""
        no_scores = torch.empty(0, device=device)
        return no_scores

    def gen_threshold(self) -> int:
        """Return ``-1``, the base scorer's fixed threshold."""
        return -1
68+
69+
70+
class HashZchSingleTtlScorer(HashZchEvictionScorer):
    """Scorer that gives every ID the same expiry: current hour + ``single_ttl``.

    Time is measured in whole hours since the epoch (``time.time() / 3600``
    truncated to ``int``).
    """

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Build a score tensor shaped like ``feature.values()``.

        Args:
            feature: jagged tensor whose values are the input IDs.
            device: device on which the score tensor is created.

        Returns:
            An ``int32`` tensor with one entry per ID, each equal to
            ``single_ttl`` plus the current epoch hour.

        Raises:
            AssertionError: if ``single_ttl`` is unset or not positive.
        """
        ttl = self._config.single_ttl
        assert (
            ttl is not None and ttl > 0
        ), "To use scorer HashZchSingleTtlScorer, a positive single_ttl is required."

        ids = feature.values()
        expiry_hour = ttl + int(time.time() / 3600)
        return torch.full_like(
            ids,
            expiry_hour,
            dtype=torch.int32,
            device=device,
        )

    def gen_threshold(self) -> int:
        """Return the current time as whole hours since the epoch."""
        current_hour = int(time.time() / 3600)
        return current_hour
86+
87+
88+
class HashZchPerFeatureTtlScorer(HashZchEvictionScorer):
    """Scorer that assigns each ID an expiry based on its feature's own TTL.

    Time is measured in whole hours since the epoch (``time.time() / 3600``
    truncated to ``int``).
    """

    def __init__(self, config: HashZchEvictionConfig) -> None:
        """Validate the config and cache the per-feature TTLs as a tensor.

        Raises:
            AssertionError: if ``per_feature_ttl`` is unset or its length
                differs from the number of configured features.
        """
        super().__init__(config)

        ttls = self._config.per_feature_ttl
        assert ttls is not None and len(self._config.features) == len(
            ttls
        ), "To use scorer HashZchPerFeatureTtlScorer, a 1:1 mapping between features and per_feature_ttl is required."

        # Built without an explicit device argument.
        self._per_feature_ttl = torch.IntTensor(ttls)

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Expand per-feature TTLs into per-ID expiry scores.

        Args:
            feature: jagged tensor whose ``weights()`` hold the feature split,
                i.e. one repeat count per configured feature.
            device: device the resulting scores are moved to.

        Returns:
            A tensor in which each feature's TTL is repeated according to the
            feature split and offset by the current epoch hour.

        Raises:
            AssertionError: if the feature split does not have one entry per
                configured TTL.
        """
        split = feature.weights()
        assert split.size(0) == self._per_feature_ttl.size(0)

        # NOTE(review): presumably `split` must be on the same device as the
        # cached TTL tensor for repeat_interleave to work; confirm with callers.
        per_id_ttl = self._per_feature_ttl.repeat_interleave(split)
        expiry = per_id_ttl + int(time.time() / 3600)

        return expiry.to(device=device)

    def gen_threshold(self) -> int:
        """Return the current time as whole hours since the epoch."""
        current_hour = int(time.time() / 3600)
        return current_hour
113+
114+
115+
@torch.fx.wrap
def get_eviction_scorer(
    policy_name: HashZchEvictionPolicyName, config: HashZchEvictionConfig
) -> HashZchEvictionScorer:
    """Instantiate the scorer that implements ``policy_name``.

    Args:
        policy_name: the eviction policy to build a scorer for. This is an enum
            member (not a plain string): the comparisons below are against
            ``HashZchEvictionPolicyName`` members.
        config: eviction config handed to the scorer's constructor.

    Returns:
        * ``HashZchSingleTtlScorer`` for ``SINGLE_TTL_EVICTION`` and
          ``LRU_EVICTION``;
        * ``HashZchPerFeatureTtlScorer`` for ``PER_FEATURE_TTL_EVICTION``;
        * the base ``HashZchEvictionScorer`` for anything else.
    """
    if policy_name == HashZchEvictionPolicyName.SINGLE_TTL_EVICTION:
        return HashZchSingleTtlScorer(config)
    elif policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION:
        return HashZchPerFeatureTtlScorer(config)
    elif policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
        # LRU shares the single-TTL scorer implementation.
        return HashZchSingleTtlScorer(config)
    else:
        # Unrecognized policy: fall back to the base scorer.
        return HashZchEvictionScorer(config)
127+
128+
129+
class HashZchThresholdEvictionModule(torch.nn.Module):
    """
    Computes eviction scores for input IDs under a threshold-based policy.

    A scorer matching the selected policy is created once at construction time
    via ``get_eviction_scorer``. Each forward call asks that scorer for a score
    per ID and for the current eviction threshold. The kernel will use this
    score to make eviction decisions.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.
        config: a config that contains information needed to run the eviction policy.

    Example::
        module = HashZchThresholdEvictionModule(policy_name, config)
        score, threshold = module(feature, device)
    """

    _eviction_scorer: HashZchEvictionScorer

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
        config: HashZchEvictionConfig,
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name
        self._config: HashZchEvictionConfig = config

        scorer = get_eviction_scorer(
            policy_name=self._policy_name,
            config=self._config,
        )
        self._eviction_scorer = scorer

        logger.info(
            f"HashZchThresholdEvictionModule: {self._policy_name=}, {self._config=}"
        )

    def forward(
        self, feature: JaggedTensor, device: torch.device
    ) -> Tuple[torch.Tensor, int]:
        """
        Args:
            feature: a jagged tensor that contains the input IDs, and their lengths and
                weights (feature split).
            device: device of the tensor.

        Returns:
            a tensor that contains the eviction score for each ID, plus an eviction threshold.
        """
        # Score first, threshold second, matching the order of the returned tuple.
        score = self._eviction_scorer.gen_score(feature, device)
        threshold = self._eviction_scorer.gen_threshold()
        return (score, threshold)
180+
181+
182+
class HashZchOptEvictionModule(torch.nn.Module):
    """
    Eviction module for policies that need no score or threshold.

    The module only records the policy name; its forward pass is a no-op.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.

    Example::
        module = HashZchOptEvictionModule(policy_name=policy_name)
        score, threshold = module(feature, device)  # (None, -1)
    """

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name

    def forward(self, feature: JaggedTensor, device: torch.device) -> Tuple[None, int]:
        """
        Scores and thresholds do not apply to this eviction policy.

        Args:
            feature: unused.
            device: unused.

        Returns:
            ``None`` in place of a score tensor, and ``-1`` in place of a threshold.
        """
        return (None, -1)
209+
210+
211+
@torch.fx.wrap
def get_eviction_module(
    policy_name: HashZchEvictionPolicyName, config: Optional[HashZchEvictionConfig]
) -> torch.nn.Module:
    """Build the eviction module that corresponds to ``policy_name``.

    Args:
        policy_name: the eviction policy to use.
        config: eviction config; must not be ``None`` for the threshold-based
            policies (it is unwrapped with ``none_throws``).

    Returns:
        A ``HashZchThresholdEvictionModule`` for ``SINGLE_TTL_EVICTION``,
        ``PER_FEATURE_TTL_EVICTION`` and ``LRU_EVICTION``; otherwise a
        ``HashZchOptEvictionModule``.
    """
    threshold_policies = (
        HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
        HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION,
        HashZchEvictionPolicyName.LRU_EVICTION,
    )

    if policy_name not in threshold_policies:
        return HashZchOptEvictionModule(policy_name)

    return HashZchThresholdEvictionModule(policy_name, none_throws(config))
223+
224+
225+
class HashZchEvictionModule(torch.nn.Module):
    """
    Front-end eviction module that delegates to a policy-specific module.

    The concrete module is chosen by ``get_eviction_module`` at construction
    time. The device given here is forwarded to that module on every call.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.
        device: device of the tensor.
        config: an optional config required if threshold based eviction is selected.

    Example::
        module = HashZchEvictionModule(
            policy_name=HashZchEvictionPolicyName.LRU_EVICTION,
            device=device,
            config=config,
        )
        score, threshold = module(feature)
    """

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
        device: torch.device,
        config: Optional[HashZchEvictionConfig],
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name
        self._device: torch.device = device

        delegate: torch.nn.Module = get_eviction_module(self._policy_name, config)
        self._eviction_module: torch.nn.Module = delegate

        logger.info(f"HashZchEvictionModule: {self._policy_name=}, {self._device=}")

    def forward(self, feature: JaggedTensor) -> Tuple[Optional[torch.Tensor], int]:
        """
        Args:
            feature: a jagged tensor that contains the input IDs, and their lengths and
                weights (feature split).

        Returns:
            For threshold eviction, a tensor that contains the eviction score for each ID, plus an eviction threshold. Otherwise None and -1.
        """
        result = self._eviction_module(feature, self._device)
        return result

0 commit comments

Comments
 (0)