Support printing tensors with custom names or titles; support delimiters for multi-dim tensors; simplify code #2

Open · wants to merge 57 commits into base: main
Commits: 57 commits by fzyzcjy, all dated Jun 26, 2025 (0e87bc5 through 797597a); nearly every commit message is "more", with one "Revert \"more\"" (9a4b5ff) and one "morew" (7e6e9e6).
260 changes: 166 additions & 94 deletions csrc/debug_print.cu
@@ -15,116 +15,176 @@
TYPE, NAME, \
AT_DISPATCH_CASE_FLOATING_AND_REDUCED_FLOATING_TYPES(__VA_ARGS__))

template <typename float_t>
__global__ void PrintFloatTensor1D(float_t *__restrict__ x,
const size_t stride_0, const size_t n,
const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
__device__ void PrintCommon(void* x, const char* name_ptr, const bool print_ptr) {
if (name_ptr != nullptr) {
printf("name=%s, ", name_ptr);
}
for (size_t i = 0; i < n; ++i) {
printf("%.4f, ", float(x[i * stride_0]));
if (print_ptr) {
printf("addr=%lld, ", x);
}
printf("\n");
}

template <typename int_t>
__global__ void PrintIntTensor1D(int_t *__restrict__ x, const size_t stride_0,
const size_t n, const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
}
for (size_t i = 0; i < n; ++i) {
printf("%lld, ", int64_t(x[i * stride_0]));
}
printf("\n");
template <typename T>
struct is_my_floating_point : std::is_floating_point<T> {};

template <>
struct is_my_floating_point<c10::Half> : std::true_type {};

template <>
struct is_my_floating_point<c10::BFloat16> : std::true_type {};

template <typename T>
struct always_false : std::false_type {};

template <typename scalar_t>
__device__ void PrintElem(scalar_t value) {
if constexpr (is_my_floating_point<scalar_t>::value) {
printf("%.4f, ", float(value));
} else if constexpr (std::is_integral<scalar_t>::value) {
printf("%lld, ", static_cast<long long>(value));
} else {
static_assert(always_false<scalar_t>::value, "PrintElem: unsupported scalar_t type");
}
}
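
The is_my_floating_point trait simply widens std::is_floating_point to cover PyTorch's reduced-precision types, so PrintElem can pick a printf format at compile time. A minimal illustration of how the trait is expected to resolve (illustrative only, not part of the diff):

// Illustrative only, not part of the PR: expected trait resolution.
static_assert(is_my_floating_point<float>::value, "plain float uses the %.4f branch");
static_assert(is_my_floating_point<c10::Half>::value, "half is treated as floating point");
static_assert(is_my_floating_point<c10::BFloat16>::value, "bfloat16 is treated as floating point");
static_assert(!is_my_floating_point<int32_t>::value, "integers fall through to the %lld branch");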

template <typename float_t>
__global__ void PrintFloatTensor2D(float_t *__restrict__ x,
const size_t shape_0, const size_t stride_1,
const size_t stride_0, const size_t n,
const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
__global__ void PrintTensor1D(
float_t *__restrict__ x,
const size_t shape_0,
const size_t stride_0,
const char* name_ptr, const bool print_ptr, const bool print_shape
) {
PrintCommon(x, name_ptr, print_ptr);
if (print_shape) {
printf("shape=(%d), stride=(%d)", (int) shape_0, (int) stride_0);
}
for (size_t i = 0; i < n; ++i) {
printf("%.4f, ",
float(x[(i / shape_0) * stride_1 + (i % shape_0) * stride_0]));
printf("\n[");
for (size_t index_0 = 0; index_0 < shape_0; ++index_0) {
PrintElem(x[index_0 * stride_0]);
}
printf("\n");
printf("]\n");
}
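
As a rough illustration of the new kernel signature (not code from the PR; x_ptr, name_dev, and stream are assumed to be an existing device data pointer, device name buffer, and CUDA stream), a direct 1-D launch would look like:

// Hypothetical launch, not part of the diff: a contiguous 3-element float
// tensor named "x", with print_ptr=false and print_shape=true. The output is
// expected to look roughly like:
//   name=x, shape=(3), stride=(1)
//   [0.1000, 0.2000, 0.3000, ]
// (exact spacing and line breaks depend on PrintCommon and the final kernel).
PrintTensor1D<float><<<1, 1, 0, stream>>>(
    x_ptr,                           // device pointer to the float data (assumed)
    /*shape_0=*/3, /*stride_0=*/1,
    name_dev,                        // device pointer to a NUL-terminated name (assumed)
    /*print_ptr=*/false, /*print_shape=*/true);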

template <typename int_t>
__global__ void PrintIntTensor2D(int_t *__restrict__ x, const size_t shape_0,
const size_t stride_1, const size_t stride_0,
const size_t n, const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
template <typename float_t>
__global__ void PrintTensor2D(
float_t *__restrict__ x,
const size_t shape_0, const size_t shape_1,
const size_t stride_0, const size_t stride_1,
const char* name_ptr, const bool print_ptr, const bool print_shape
) {
PrintCommon(x, name_ptr, print_ptr);
if (print_shape) {
printf("shape=(%d, %d), stride=(%d, %d)", (int) shape_0, (int) shape_1, (int) stride_0, (int) stride_1);
}
for (size_t i = 0; i < n; ++i) {
printf("%lld, ",
int64_t(x[(i / shape_0) * stride_1 + (i % shape_0) * stride_0]));
printf("\n[");
for (size_t index_0 = 0; index_0 < shape_0; ++index_0) {
printf("[");
for (size_t index_1 = 0; index_1 < shape_1; ++index_1) {
PrintElem(x[index_0 * stride_0 + index_1 * stride_1]);
}
printf("], ");
}
printf("\n");
printf("]\n");
}
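
Together with PrintCommon, these nested loops produce the bracketed, comma-delimited layout the PR title refers to. For a 2x3 row-major float tensor named my_tensor (a made-up example; exact whitespace and line breaks may differ), the console output should look roughly like:

name=my_tensor, shape=(2, 3), stride=(3, 1)
[[0.0000, 1.0000, 2.0000, ], [3.0000, 4.0000, 5.0000, ], ]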

template <typename float_t>
__global__ void PrintFloatTensor3D(float_t *__restrict__ x,
const size_t shape_1, const size_t shape_0,
const size_t stride_2, const size_t stride_1,
const size_t stride_0, const size_t n,
const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
__global__ void PrintTensor3D(
float_t *__restrict__ x,
const size_t shape_0, const size_t shape_1, const size_t shape_2,
const size_t stride_0, const size_t stride_1, const size_t stride_2,
const char* name_ptr, const bool print_ptr, const bool print_shape
) {
PrintCommon(x, name_ptr, print_ptr);
if (print_shape) {
printf("shape=(%d, %d, %d), stride=(%d, %d, %d)", (int) shape_0, (int) shape_1, (int) shape_2, (int) stride_0, (int) stride_1, (int) stride_2);
}
for (size_t i = 0; i < n; ++i) {
printf("%.4f, ", float(x[(i / shape_0 / shape_1) * stride_2 +
((i / shape_0) % shape_1) * stride_1 +
(i % shape_0) * stride_0]));
printf("\n[");
for (size_t index_0 = 0; index_0 < shape_0; ++index_0) {
printf("[");
for (size_t index_1 = 0; index_1 < shape_1; ++index_1) {
printf("[");
for (size_t index_2 = 0; index_2 < shape_2; ++index_2) {
PrintElem(x[index_0 * stride_0 + index_1 * stride_1 + index_2 * stride_2]);
}
printf("], ");
}
printf("], ");
}
printf("\n");
printf("]\n");
}

template <typename int_t>
__global__ void PrintIntTensor3D(int_t *__restrict__ x, const size_t shape_1,
const size_t shape_0, const size_t stride_2,
const size_t stride_1, const size_t stride_0,
const size_t n, const bool print_ptr) {
if (print_ptr) {
printf("addr: %lld\n", x);
template <typename float_t>
__global__ void PrintTensor4D(
float_t *__restrict__ x,
const size_t shape_0, const size_t shape_1, const size_t shape_2, const size_t shape_3,
const size_t stride_0, const size_t stride_1, const size_t stride_2, const size_t stride_3,
const char* name_ptr, const bool print_ptr, const bool print_shape
) {
PrintCommon(x, name_ptr, print_ptr);
if (print_shape) {
printf("shape=(%d, %d, %d, %d), stride=(%d, %d, %d, %d)", (int) shape_0, (int) shape_1, (int) shape_2, (int) shape_3, (int) stride_0, (int) stride_1, (int) stride_2, (int) stride_3);
}
for (size_t i = 0; i < n; ++i) {
printf("%lld, ", int64_t(x[(i / shape_0 / shape_1) * stride_2 +
((i / shape_0) % shape_1) * stride_1 +
(i % shape_0) * stride_0]));
printf("\n[");
for (size_t index_0 = 0; index_0 < shape_0; ++index_0) {
printf("[");
for (size_t index_1 = 0; index_1 < shape_1; ++index_1) {
printf("[");
for (size_t index_2 = 0; index_2 < shape_2; ++index_2) {
printf("[");
for (size_t index_3 = 0; index_3 < shape_3; ++index_3) {
PrintElem(x[index_0 * stride_0 + index_1 * stride_1 + index_2 * stride_2 + index_3 * stride_3]);
}
printf("], ");
}
printf("], ");
}
printf("], ");
}
printf("\n");
printf("]\n");
}

void PrintTensor(torch::Tensor x, bool print_ptr) {
void PrintTensor(torch::Tensor x, std::optional<torch::Tensor> name_buffer, bool print_ptr, bool print_shape) {
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(x.device().index());
TORCH_CHECK(x.is_cuda(), "The input tensor should be a CUDA tensor");

const char* name_ptr = name_buffer.has_value() ? reinterpret_cast<char*>(name_buffer->data_ptr<uint8_t>()) : nullptr;

if (x.is_floating_point()) {
if (x.dim() == 1) {
AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES(
x.scalar_type(), "PrintFloatTensor1D", ([&] {
PrintFloatTensor1D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.stride(0), x.numel(), print_ptr);
x.scalar_type(), "PrintTensor1D", ([&] {
PrintTensor1D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.stride(0),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 2) {
AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES(
x.scalar_type(), "PrintFloatTensor2D", ([&] {
PrintFloatTensor2D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.size(1), x.stride(0), x.stride(1),
x.numel(), print_ptr);
x.scalar_type(), "PrintTensor2D", ([&] {
PrintTensor2D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.stride(0), x.stride(1),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 3) {
AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES(
x.scalar_type(), "PrintFloatTensor3D", ([&] {
PrintFloatTensor3D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.size(1), x.size(2), x.stride(0),
x.stride(1), x.stride(2), x.numel(), print_ptr);
x.scalar_type(), "PrintTensor3D", ([&] {
PrintTensor3D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.size(2), x.stride(0), x.stride(1), x.stride(2),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 4) {
AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES(
x.scalar_type(), "PrintTensor4D", ([&] {
PrintTensor4D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.size(2), x.size(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3),
name_ptr, print_ptr, print_shape
);
}));
} else {
// NOTE(Zihao): I'm just too lazy to do this, codegen for higher
@@ -133,37 +193,49 @@ void PrintTensor(torch::Tensor x, bool print_ptr) {
}
cudaError_t status = cudaGetLastError();
TORCH_CHECK(status == cudaSuccess,
"PrintFloatTensor failed with error " +
"PrintTensor failed with error " +
std::string(cudaGetErrorString(status)));
} else {
if (x.dim() == 1) {
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor1D", ([&] {
PrintIntTensor1D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.stride(0),
x.numel(), print_ptr);
}));
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor1D", ([&] {
PrintTensor1D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.stride(0),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 2) {
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor2D", ([&] {
PrintIntTensor2D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.size(1),
x.stride(0), x.stride(1), x.numel(),
print_ptr);
}));
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor2D", ([&] {
PrintTensor2D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.stride(0), x.stride(1),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 3) {
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor3D", ([&] {
PrintIntTensor3D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(), x.size(1),
x.size(2), x.stride(0), x.stride(1),
x.stride(2), x.numel(), print_ptr);
}));
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor3D", ([&] {
PrintTensor3D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.size(2), x.stride(0), x.stride(1), x.stride(2),
name_ptr, print_ptr, print_shape
);
}));
} else if (x.dim() == 4) {
AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor4D", ([&] {
PrintTensor4D<<<1, 1, 0, stream>>>(
x.data_ptr<scalar_t>(),
x.size(0), x.size(1), x.size(2), x.size(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3),
name_ptr, print_ptr, print_shape
);
}));
} else {
// NOTE(Zihao): I'm just too lazy to do this, codegen for higher
// dimensions should be a better idea
TORCH_CHECK(false, "Input dimension not supported.");
}
cudaError_t status = cudaGetLastError();
TORCH_CHECK(status == cudaSuccess,
"PrintIntTensor failed with error " +
"PrintTensor failed with error " +
std::string(cudaGetErrorString(status)));
}
}
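
For context, a host-side caller along the following lines would exercise the new signature. This is a hedged sketch, not code from the PR: the helper MakeNameBuffer, the example tensor, and the assumption that name_buffer is a NUL-terminated uint8 tensor living on the same CUDA device are all illustrative.

// Illustrative usage sketch (not part of the PR).
#include <cstring>
#include <string>
#include <torch/torch.h>

// Copy the name (plus a trailing '\0') into a uint8 tensor on the target
// device, so the kernel can reinterpret its data pointer as a C string.
torch::Tensor MakeNameBuffer(const std::string& name, torch::Device device) {
  auto cpu = torch::empty({static_cast<int64_t>(name.size()) + 1},
                          torch::dtype(torch::kUInt8));
  std::memcpy(cpu.data_ptr<uint8_t>(), name.c_str(), name.size() + 1);
  return cpu.to(device);
}

void Example() {
  torch::Tensor x = torch::arange(6, torch::dtype(torch::kFloat32).device(torch::kCUDA))
                        .reshape({2, 3});
  torch::Tensor name = MakeNameBuffer("my_tensor", x.device());
  // Expected to print something like:
  //   name=my_tensor, shape=(2, 3), stride=(3, 1)
  //   [[0.0000, 1.0000, 2.0000, ], [3.0000, 4.0000, 5.0000, ], ]
  PrintTensor(x, name, /*print_ptr=*/false, /*print_shape=*/true);
}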