93 changes: 62 additions & 31 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -16,7 +16,6 @@
import numpy as np
import tabulate
import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
@@ -99,46 +98,72 @@ def generate_requests(
return rs


# pyre-fixme[3]: Return type must be annotated.
def _get_random_tensor(
num_ads: int,
embedding_dimension: int,
ads_tables: int,
data_type: str,
gpu_idx: int,
include_quantization: bool,
):
use_pitched: bool,
alignment: int = 256, # alignment in bytes
) -> torch.Tensor:
device = torch.device(f"cuda:{gpu_idx}")

if data_type == "FP16" or include_quantization:
result_tensor = torch.randn(
num_ads,
embedding_dimension * ads_tables,
dtype=torch.float16,
device=torch.device(f"cuda:{gpu_idx}"),
)
dtype = torch.float16
width_elems = embedding_dimension * ads_tables
elem_size = torch.finfo(dtype).bits // 8

if use_pitched:
width_bytes = width_elems * elem_size
pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
pitch_elems = pitch_bytes // elem_size
storage = torch.empty((num_ads, pitch_elems), dtype=dtype, device=device)
result_tensor = storage[:, :width_elems] # logical view
else:
result_tensor = torch.randn(
num_ads, width_elems, dtype=dtype, device=device
)

elif data_type == "INT8":
assert (
embedding_dimension % 2
) == 0, "needs to align to 2 bytes (half type size) for INT8"
result_tensor = torch.randint(
0,
255,
# 2 FP16 numbers for scale and bias, total of 4 bytes overhead
size=(num_ads, (embedding_dimension + 4) * ads_tables),
dtype=torch.uint8,
device=torch.device(f"cuda:{gpu_idx}"),
)
assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
dtype = torch.uint8
width_elems = (embedding_dimension + 4) * ads_tables
elem_size = 1

if use_pitched:
width_bytes = width_elems * elem_size
pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
pitch_elems = pitch_bytes // elem_size
storage = torch.randint(
0, 255, (num_ads, pitch_elems), dtype=dtype, device=device
)
result_tensor = storage[:, :width_elems]
else:
result_tensor = torch.randint(
0, 255, (num_ads, width_elems), dtype=dtype, device=device
)

elif data_type == "INT4":
assert (
embedding_dimension % 4
) == 0, "needs to align to 2 bytes (half type size) for INT4"
result_tensor = torch.randint(
0,
255,
# Using torch.uint8 for int4 storage
size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
dtype=torch.uint8,
device=torch.device(f"cuda:{gpu_idx}"),
)
assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
dtype = torch.uint8
width_elems = (embedding_dimension // 2 + 4) * ads_tables
elem_size = 1

if use_pitched:
width_bytes = width_elems * elem_size
pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
pitch_elems = pitch_bytes // elem_size
storage = torch.randint(
0, 255, (num_ads, pitch_elems), dtype=dtype, device=device
)
result_tensor = storage[:, :width_elems]
else:
result_tensor = torch.randint(
0, 255, (num_ads, width_elems), dtype=dtype, device=device
)

else:
raise ValueError(f"Unsupported data_type: {data_type}")
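
A minimal sketch of the pitched-allocation pattern shared by all three dtype branches above: round the logical row width up to the next multiple of alignment bytes, allocate the padded storage, then slice back to the logical width. CPU-only, and the variable names here are illustrative:

import numpy as np
import torch

alignment = 256                       # pitch alignment in bytes
num_rows, width_elems = 8, 100        # logical shape, e.g. FP16 embedding rows
elem_size = torch.finfo(torch.float16).bits // 8                 # 2 bytes
width_bytes = width_elems * elem_size                            # 200 bytes
pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)  # 256 bytes
pitch_elems = pitch_bytes // elem_size                           # 128 elements
storage = torch.empty((num_rows, pitch_elems), dtype=torch.float16)
view = storage[:, :width_elems]       # logical view into pitched storage
assert view.stride(0) == pitch_elems  # row stride is the pitch, not the width
assert not view.is_contiguous()       # downstream ops see strided input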

@@ -253,6 +278,7 @@ def benchmark( # noqa C901
num_ads: int,
embedding_dimension: int,
ads_tables: int,
use_pitched: bool,
iters: int = 10,
p2p_bw: bool = False,
dst_device: int = 0,
@@ -298,6 +324,7 @@ def benchmark( # noqa C901
data_type,
gpu_idx,
include_quantization,
use_pitched,
)
for gpu_idx in range(num_gpus)
]
@@ -485,6 +512,7 @@ def pool_func_with_quantization(
@click.option("--num_of_embeddings", default=100000, type=int)
@click.option("--pooling_factor", default=25, type=int)
@click.option("--sweep", is_flag=True, default=False)
@click.option("--use_pitched", is_flag=True, default=False)
def cli(
all_to_one_only: bool,
sum_reduce_to_one_only: bool,
@@ -500,6 +528,7 @@ def cli(
num_of_embeddings: int,
pooling_factor: int,
sweep: bool,
use_pitched: bool,
) -> None:
csv_header = (
"mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, dst_device, all_to_one_only, "
@@ -534,6 +563,7 @@ def handler(signum, frame):
num_ads,
embedding_dimension,
ads_tables,
use_pitched,
iters,
p2p_bw,
dst_device,
@@ -558,6 +588,7 @@ def handler(signum, frame):
num_ads,
embedding_dimension,
ads_tables,
use_pitched,
iters,
p2p_bw,
dst_device,
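A hypothetical invocation of the updated benchmark, using only the flags visible in this diff (any other required options are omitted):

python fbgemm_gpu/bench/merge_embeddings_benchmark.py --sweep --use_pitched
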
63 changes: 41 additions & 22 deletions fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -13,6 +13,7 @@
import fbgemm_gpu

import hypothesis.strategies as st
import numpy as np
import torch
from hypothesis import given, settings, Verbosity

@@ -32,8 +33,29 @@
typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable


@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(open_source, "Not supported in open source yet")
def make_pitched_tensor(
height: int,
width: int,
dtype: torch.dtype,
# pyre-fixme[2]: Parameter must be annotated.
device,
alignment: int = 256,
) -> torch.Tensor:
elem_size = (
torch.finfo(dtype).bits // 8
if dtype.is_floating_point
else torch.iinfo(dtype).bits // 8
)
width_bytes = width * elem_size
pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
pitch_elems = pitch_bytes // elem_size
storage = torch.randn((height, pitch_elems), dtype=dtype, device=device)  # float dtypes only
view = storage[:, :width]  # logical view; row stride is still pitch_elems
return view.contiguous() if alignment == 0 else view  # pitched (non-contiguous) view
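
A small usage sketch for this helper, with the strides implied by the default 256-byte alignment (fp32 elements are 4 bytes, so a 20-element row pads to a 64-element pitch):

t = make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
assert t.shape == (10, 20)    # logical shape
assert t.stride(0) == 64      # 256-byte pitch / 4 bytes per fp32 element
assert not t.is_contiguous()  # padding remains between rows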


# @unittest.skipIf(open_source, "Not supported in open source yet")
@unittest.skipIf(*typed_gpu_unavailable)
class MergePooledEmbeddingsTest(unittest.TestCase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
@@ -51,16 +73,11 @@ class MergePooledEmbeddingsTest(unittest.TestCase):
@settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
def test_merge(
self,
# pyre-fixme[2]: Parameter must be annotated.
num_ads,
# pyre-fixme[2]: Parameter must be annotated.
embedding_dimension,
# pyre-fixme[2]: Parameter must be annotated.
ads_tables,
# pyre-fixme[2]: Parameter must be annotated.
num_gpus,
# pyre-fixme[2]: Parameter must be annotated.
non_default_stream,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
num_gpus: int,
non_default_stream: bool,
# pyre-fixme[2]: Parameter must be annotated.
r,
dim: int,
@@ -107,27 +124,32 @@ def ref(pooled_ad_embeddings, batch_indices):
torch.testing.assert_close(output_ref, output_cpu)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
# 10)` to decorator factory `hypothesis.given`.
@given(
num_inputs=st.integers(min_value=1, max_value=10),
num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
r=st.randoms(use_true_random=False),
use_pitched=st.booleans(),
)
# Can instantiate up to 8 CUDA contexts, which takes a long time.
@settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
def test_all_to_one_device(
self,
# pyre-fixme[2]: Parameter must be annotated.
num_inputs,
# pyre-fixme[2]: Parameter must be annotated.
num_gpus,
num_inputs: int,
num_gpus: int,
# pyre-fixme[2]: Parameter must be annotated.
r,
use_pitched: bool,
) -> None:
dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
with torch.cuda.device(dst_device):
inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
if use_pitched:
inputs = [
make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
for _ in range(num_inputs)
]
else:
inputs = [torch.randn(10, 20) for _ in range(num_inputs)]

cuda_inputs = [
input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
]
@@ -150,8 +172,6 @@ def test_merge_pooled_embeddings_gpu_to_cpu(self) -> None:
torch.testing.assert_close(output, ref_output)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
# 10)` to decorator factory `hypothesis.given`.
@given(
num_inputs=st.integers(min_value=1, max_value=8),
num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
@@ -234,7 +254,6 @@ def test_sum_reduce_to_one(
cuda_output.cpu(), torch.stack(inputs).sum(dim=0)
)

@unittest.skipIf(*typed_gpu_unavailable)
def test_merge_pooled_embeddings_meta(self) -> None:
"""
Test that merge_pooled_embeddings works with meta tensor and