
Commit ef991d0

lizhouyu authored and facebook-github-bot committed
Create an example for MPZCH (#3063)
Summary:
Pull Request resolved: #3063

### Major changes
- Create a `mpzch` folder under the `torchrec/github/examples` folder
- Implement a simple SparseArch module with a flag to switch between the original and MPZCH managed collision modules
- Profile the running time and QPS for model training (GPU) / inference (CPU)
- Create a notebook tutorial for ZCH basics and the use of ZCH modules in TorchRec

### ToDos for OSS
When the internal torchrec MPZCH module is OSS:
- Remove the `BUCK` file
- Replace all `from torchrec.fb.modules` imports in `sparse_arch.py` with `from torchrec.modules`

### Potential improvements
- Add a hash collision counter
- Show profiling results in the Readme file
- Add multi-batch profiling

Reviewed By: aporialiao

Differential Revision: D75570684

fbshipit-source-id: 832119f6524c7c126384033f39276264022e0fe3
1 parent 0d6ef90 commit ef991d0

6 files changed: +863, −0 lines

examples/zch/Readme.md

Lines changed: 74 additions & 0 deletions
# Managed Collision Hash Example
This example demonstrates the usage of the managed collision hash feature in TorchRec, which is designed to efficiently handle hash collisions in embedding tables. We include two implementations of the feature: sorted managed collision hash (MCH) and MPZCH (Multi-Probe Zero Collision Hash).
## Folder Structure
```
zch/
├── Readme.md                          # This documentation file
├── __init__.py                        # Python package marker
├── docs/
│   └── mpzch_module_dataflow.png      # Module data-flow diagram referenced below
├── main.py                            # Main script to run the benchmark
├── sparse_arch.py                     # Implementation of the sparse architecture with managed collision
└── zero_collision_hash_tutorial.ipynb # Jupyter notebook on the motivation for zero collision hash and the use of zero collision hash modules in TorchRec
```
### Introduction to MPZCH
Multi-Probe Zero Collision Hash (MPZCH) is a technique that reduces the collision rate of embedding table lookups. For the concept of hash collisions and why they need to be managed, please refer to the [zero collision hash tutorial](zero_collision_hash_tutorial.ipynb).
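For quick intuition, here is a toy illustration (not TorchRec code) of why collisions arise with a plain modulo hash:

```python
# Toy illustration of a hash collision with a plain modulo hash.
table_size = 1000
id_a, id_b = 1234, 2234
slot_a = id_a % table_size  # 234
slot_b = id_b % table_size  # 234
assert slot_a == slot_b  # two different IDs collide into the same embedding row
```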
An MPZCH module contains two essential tables: the identity table and the metadata table.
The identity table records the mapping from input hash values to remapped IDs: the value stored in each identity table slot is an input hash value, and that hash value's remapped ID is the index of the slot.
The metadata table has the same length as the identity table. The time when a hash value is inserted into an identity table slot is recorded in the same-indexed slot of the metadata table.
Specifically, MPZCH performs the following two steps (a minimal sketch of the probe-and-evict loop follows the list):
1. **First Probe**: Check whether there are available or evictable slots in the identity table.
2. **Second Probe**: Check whether the slot indexed by the input hash value is occupied. If not, directly insert the input hash value into that slot. Otherwise, perform a linear probe to find the next available slot. If all slots are occupied, find the next evictable slot, i.e. one whose value has stayed in the table longer than a threshold, and replace the expired hash value with the input one.
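The following is a minimal pure-Python sketch of that logic, assuming a TTL-based eviction criterion and a bounded number of probes. It is only an illustration, not the actual TorchRec implementation:

```python
import time

EMPTY = -1  # sentinel for an unoccupied identity-table slot

def remap(identity, metadata, hash_value, ttl_seconds, max_probes=128):
    """Toy MPZCH-style remap: return the slot index (remapped ID) for hash_value."""
    size = len(identity)
    start = hash_value % size  # slot indexed by the input hash value
    evictable = None
    for probe in range(max_probes):
        slot = (start + probe) % size  # linear probe
        if identity[slot] == hash_value:  # already mapped: reuse the slot
            return slot
        if identity[slot] == EMPTY:  # available slot: insert here
            identity[slot] = hash_value
            metadata[slot] = time.time()
            return slot
        # occupied: remember the first expired slot as an eviction candidate
        if evictable is None and time.time() - metadata[slot] > ttl_seconds:
            evictable = slot
    if evictable is not None:  # no free slot: evict an expired hash value
        identity[evictable] = hash_value
        metadata[evictable] = time.time()
        return evictable
    return start  # nothing free or evictable: an unresolved collision

identity, metadata = [EMPTY] * 8, [0.0] * 8
print(remap(identity, metadata, hash_value=42, ttl_seconds=3600))  # -> 2 (42 % 8)
```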
The use of the MPZCH module `HashZchManagedCollisionModule` is introduced with detailed comments in the [sparse_arch.py](sparse_arch.py) file.
The module can be configured to use different eviction policies and parameters.
The detailed function calls are shown in the diagram below.
![MPZCH Module Data Flow](docs/mpzch_module_dataflow.png)
#### Relationship among Important Parameters
The `HashZchManagedCollisionModule` module has two important parameters for initialization:
- `num_embeddings`: the number of embeddings in the embedding table
- `num_buckets`: the number of buckets in the hash table

The `num_buckets` is used as the minimal sharding unit for the embedding table. Because MPZCH performs linear probing, when resharding the embedding table we want to avoid separating the remapped index of an input feature ID from its hash value onto different ranks. So we make sure they stay in the same bucket, and we move whole buckets during resharding.
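To make the bucket arithmetic concrete, here is a sketch under the assumption that the table is split into equal-sized contiguous buckets (the exact layout is internal to the module):

```python
# Sketch of a bucketed ZCH layout; assumes equal-sized contiguous buckets.
num_embeddings = 1000  # slots in the identity/embedding table
num_buckets = 4        # minimal resharding unit
bucket_size = num_embeddings // num_buckets  # 250 slots per bucket

def bucket_of(hash_value: int) -> int:
    # each input hash value is first assigned to one bucket
    return hash_value % num_buckets

def slots_of(bucket: int) -> range:
    # linear probing stays inside the bucket, so a hash value and its
    # remapped ID always live in the same bucket and move together
    return range(bucket * bucket_size, (bucket + 1) * bucket_size)

h = 12345
print(f"hash {h} -> bucket {bucket_of(h)}, slots {slots_of(bucket_of(h))}")
```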
## Usage
We also prepare a profiling example of a SparseArch module implemented with different ZCH techniques.
To run the profiling example with sorted ZCH:
```bash
python main.py
```

To run the profiling example with MPZCH:

```bash
python main.py --use_mpzch
```

You can also specify the `batch_size`, `num_iters`, and `num_embeddings_per_table`:

```bash
python main.py --use_mpzch --batch_size 8 --num_iters 100 --num_embeddings_per_table 1000
```
The example allows you to compare the QPS of embedding operations with sorted ZCH and MPZCH, where QPS is computed as `num_iters * batch_size / duration`. On our server with an A100 GPU, the initial QPS benchmark results with `batch_size=8`, `num_iters=100`, and `num_embeddings_per_table=1000` are presented in the table below:
| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 1371.69 |
| MPZCH | 2750.44 |
With `batch_size=1024`, `num_iters=1000`, and `num_embeddings_per_table=1000`, the results are:
| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 263827.55 |
| MPZCH | 551306.97 |

examples/zch/__init__.py

Whitespace-only changes.
examples/zch/docs/mpzch_module_dataflow.png

Binary file added (911 KB).

examples/zch/main.py

Lines changed: 136 additions & 0 deletions
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse
import time

import torch

from torchrec import EmbeddingConfig, KeyedJaggedTensor
from tqdm import tqdm

from .sparse_arch import SparseArch


def main(args: argparse.Namespace) -> None:
    """
    Benchmark a SparseArch module with or without the MPZCH feature.

    Arguments:
        args: argparse.Namespace with the fields use_mpzch, num_iters,
            batch_size, and num_embeddings_per_table
    Prints:
        qps: number of remapped queries processed per second
        duration: time for the forward passes of the SparseArch module
    """
    print(f"Using MPZCH: {args.use_mpzch}")

    # check available devices
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    # create an embedding configuration
    embedding_config = [
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
        EmbeddingConfig(
            name="table_1",
            feature_names=["feature_1"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
    ]

    # generate the list of input KeyedJaggedTensors, one per iteration
    input_kjt_list = []
    for _ in range(args.num_iters):
        input_kjt_single = KeyedJaggedTensor.from_lengths_sync(
            keys=["feature_0", "feature_1"],
            # pick 3 * batch_size random ids from [0, num_embeddings_per_table):
            # one per sample for feature_0 and two per sample for feature_1
            values=torch.randint(
                0, args.num_embeddings_per_table, (3 * args.batch_size,)
            ),
            lengths=torch.LongTensor([1] * args.batch_size + [2] * args.batch_size),
            weights=None,
        )
        input_kjt_single = input_kjt_single.to(device)
        input_kjt_list.append(input_kjt_single)

    num_requests = args.num_iters * args.batch_size

    # make the model
    model = SparseArch(
        tables=embedding_config,
        device=device,
        return_remapped=True,
        use_mpzch=args.use_mpzch,
        buckets=1,
    )

    # do the forward passes
    if device.type == "cuda":
        torch.cuda.synchronize()
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

        # record the start time
        starter.record()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx].to(device)
            ec_out, remapped_ids_out = model(input_kjt)
        # record the end time
        ender.record()
        # wait for the end time to be recorded
        torch.cuda.synchronize()
        duration = starter.elapsed_time(ender) / 1000.0  # convert to seconds
    else:
        # in CPU mode, MPZCH can only run in inference mode, so we profile the model in eval mode
        model.eval()
        if args.use_mpzch:
            # when using MPZCH modules, we need to manually set the modules to inference mode
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_0"
            ].reset_inference_mode()
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_1"
            ].reset_inference_mode()

        start_time = time.time()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx].to(device)
            ec_out, remapped_ids_out = model(input_kjt)
        end_time = time.time()
        duration = end_time - start_time

    # report queries per second and total duration
    qps = num_requests / duration
    print(f"qps: {qps}")
    print(f"duration: {duration} seconds")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_mpzch", action="store_true", default=False)
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_embeddings_per_table", type=int, default=1000)
    args: argparse.Namespace = parser.parse_args()
    main(args)

examples/zch/sparse_arch.py

Lines changed: 142 additions & 0 deletions
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from torchrec import (
    EmbeddingCollection,
    EmbeddingConfig,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)

# For MPZCH
from torchrec.modules.hash_mc_evictions import (
    HashZchEvictionConfig,
    HashZchEvictionPolicyName,
)

# For MPZCH
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection

# For original MC
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)


class SparseArch(nn.Module):
    """
    Class SparseArch
    An example SparseArch with two tables, each serving one feature.
    It looks up the corresponding embeddings for an incoming KeyedJaggedTensor
    with two features and returns them.

    Parameters:
        tables (List[EmbeddingConfig]): list of EmbeddingConfigs that define the embedding tables
        device (torch.device): device on which the embedding tables should be placed
        buckets (int): number of buckets for each table
        input_hash_size (int): input hash size for each table
        return_remapped (bool): whether to return remapped features. If so, the return will be
            a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)); otherwise, the
            return will be a tuple of (Embedding(KeyedTensor), None)
        is_inference (bool): whether to use inference mode. In inference mode, the module will
            not update the embedding tables
        use_mpzch (bool): whether to use MPZCH. If True, the module will use the MPZCH managed
            collision module; otherwise, it will use the original MC managed collision module
    """

    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int = 4,
        input_hash_size: int = 4000,
        return_remapped: bool = False,
        is_inference: bool = False,
        use_mpzch: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        mc_modules = {}

        if use_mpzch:
            # if using the MPZCH module, create a HashZchManagedCollisionModule for each table
            mc_modules["table_0"] = HashZchManagedCollisionModule(
                is_inference=is_inference,
                # the ZCH size, that is, the size of the local embedding table;
                # it should match the size of the embedding table
                zch_size=(tables[0].num_embeddings),
                # the input hash size, that is, the size of the input ID space
                input_hash_size=input_hash_size,
                # the device on which the embedding table should be placed
                device=device,
                # the number of buckets; a detailed explanation of buckets is in the readme file
                total_num_buckets=buckets,
                # the eviction policy; this example uses the single-TTL eviction policy, which
                # treats an ID as evictable once it has been in the table longer than the TTL
                # (time to live)
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                # the eviction config specifies, per feature, how long an ID can stay in the
                # table before it becomes evictable
                eviction_config=HashZchEvictionConfig(
                    # this table only serves one feature, "feature_0", so only its TTL is given
                    features=["feature_0"],
                    # the TTL unit is hours; a TTL of 1 means an ID is evictable after it has
                    # been in the table for more than one hour
                    single_ttl=1,
                ),
            )
            mc_modules["table_1"] = HashZchManagedCollisionModule(
                is_inference=is_inference,
                zch_size=(tables[1].num_embeddings),
                device=device,
                input_hash_size=input_hash_size,
                total_num_buckets=buckets,
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                eviction_config=HashZchEvictionConfig(
                    features=["feature_1"],
                    single_ttl=1,
                ),
            )
        else:
            # if not using the MPZCH module, create an MCHManagedCollisionModule for each table
            mc_modules["table_0"] = MCHManagedCollisionModule(
                zch_size=(tables[0].num_embeddings),
                input_hash_size=input_hash_size,
                device=device,
                eviction_interval=2,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )
            mc_modules["table_1"] = MCHManagedCollisionModule(
                zch_size=(tables[1].num_embeddings),
                device=device,
                input_hash_size=input_hash_size,
                eviction_interval=1,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )

        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                EmbeddingCollection(
                    tables=tables,
                    device=device,
                ),
                ManagedCollisionCollection(
                    managed_collision_modules=mc_modules,
                    embedding_configs=tables,
                ),
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        return self._mc_ec(kjt)
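As a quick usage sketch of the SparseArch above (mirroring main.py; the CPU path assumes MPZCH runs in inference mode, hence the `reset_inference_mode` calls):

```python
# Minimal usage sketch of SparseArch on CPU (mirrors main.py).
import torch
from torchrec import EmbeddingConfig, KeyedJaggedTensor
from .sparse_arch import SparseArch

tables = [
    EmbeddingConfig(name="table_0", feature_names=["feature_0"], embedding_dim=8, num_embeddings=1000),
    EmbeddingConfig(name="table_1", feature_names=["feature_1"], embedding_dim=8, num_embeddings=1000),
]
model = SparseArch(
    tables=tables,
    device=torch.device("cpu"),
    return_remapped=True,
    use_mpzch=True,
    buckets=1,
)
# on CPU, MPZCH modules must be switched to inference mode (as in main.py)
model.eval()
for name in ("table_0", "table_1"):
    model._mc_ec._managed_collision_collection._managed_collision_modules[
        name
    ].reset_inference_mode()

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.randint(0, 1000, (3,)),
    lengths=torch.LongTensor([1, 2]),  # one id for feature_0, two for feature_1
)
embeddings, remapped_ids = model(kjt)  # (per-feature embeddings, remapped KJT)
```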
