[STABLE ABI] Stable forced_align on cpu #4022

Draft: wants to merge 25 commits into base: stable_accessor
Commits (25)
2062dc7
Make alphas_a standard C array
samanklesaria Aug 5, 2025
e70113c
Convert backptr to standard array
samanklesaria Aug 5, 2025
4039399
Create Accessor class
samanklesaria Aug 5, 2025
b733629
Add MutAccessor
samanklesaria Aug 5, 2025
9beb34a
Fix multidimensional indexing bug
samanklesaria Aug 5, 2025
11d1e21
Use strides rather than computing standard strides from dims
samanklesaria Aug 5, 2025
b47c053
Merge Accessor and MutAccessor
samanklesaria Aug 6, 2025
7a94b04
Move Accessor to its own file and add tests
samanklesaria Aug 6, 2025
75d246a
Add comment about original indexing
samanklesaria Aug 6, 2025
30ed519
Add requested comment about scalar_t
samanklesaria Aug 6, 2025
be13f64
WIP
samanklesaria Aug 6, 2025
258ca00
Merge branch 'main' into forced_align_accessors
samanklesaria Aug 6, 2025
77fd1ad
Use stable tensors throughout forced_align code
samanklesaria Aug 6, 2025
ced6124
Free alphas_a array
samanklesaria Aug 7, 2025
d27a416
Merge branch 'stable_forced_align' into forced_align_backptr
samanklesaria Aug 7, 2025
71ce212
Free backPtr_a
samanklesaria Aug 7, 2025
eb50150
Merge branch 'forced_align_backptr' into forced_align_accessors
samanklesaria Aug 7, 2025
9629864
Fix merge conflict
samanklesaria Aug 7, 2025
847b726
Correct dimensionality of path variable
samanklesaria Aug 7, 2025
2663def
Use 1d indexing in original layout for alphas_a
samanklesaria Aug 8, 2025
5fa467d
Merge branch 'stable_forced_align' into forced_align_backptr
samanklesaria Aug 8, 2025
724606a
Merge branch 'forced_align_backptr' into forced_align_accessors
samanklesaria Aug 8, 2025
b1595b1
Merge branch 'stable_accessor' into stable_forced_align_cpu
samanklesaria Aug 19, 2025
768b2b5
Merge branch 'stable_accessor' into stable_forced_align_cpu
samanklesaria Aug 19, 2025
86f7557
Use c-style dtype API
samanklesaria Aug 19, 2025
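
Several of the commits above ("Create Accessor class", "Add MutAccessor", "Merge Accessor and MutAccessor", "Move Accessor to its own file and add tests") introduce the Accessor template from libtorchaudio/accessor.h that the diff below uses in place of torch::Tensor::accessor. That header is not shown in this diff, so the following is only a minimal, self-contained sketch of the idea: the class in the PR is constructed directly from a torch::stable::Tensor and reads its data pointer and strides, whereas this sketch takes both explicitly so it compiles on its own. Only the index/set_index call shapes are taken from the diff; everything else is illustrative.

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <type_traits>

// Hypothetical stand-in for the Accessor in libtorchaudio/accessor.h.
// The class added by this PR is constructed from a torch::stable::Tensor
// (taking its data pointer and strides); here both are passed in directly
// so the sketch compiles on its own.
template <int64_t N, typename T, bool is_const>
class Accessor {
  using ptr_t = typename std::conditional<is_const, const T*, T*>::type;

 public:
  Accessor(ptr_t data, const std::array<int64_t, N>& strides)
      : data_(data), strides_(strides) {}

  // Read the element at (i0, ..., iN-1) using the tensor's own strides,
  // so non-contiguous layouts index correctly.
  template <typename... Ix>
  T index(Ix... ix) const {
    return data_[offset(ix...)];
  }

  // Write `value` at (i0, ..., iN-1); mirrors paths_a.set_index(...) below.
  template <typename... Ix>
  void set_index(T value, Ix... ix) {
    static_assert(!is_const, "set_index requires a mutable accessor");
    data_[offset(ix...)] = value;
  }

 private:
  template <typename... Ix>
  int64_t offset(Ix... ix) const {
    static_assert(sizeof...(Ix) == N, "wrong number of indices");
    const int64_t idx[sizeof...(Ix)] = {static_cast<int64_t>(ix)...};
    int64_t off = 0;
    for (int64_t d = 0; d < N; ++d) {
      off += idx[d] * strides_[d];
    }
    return off;
  }

  ptr_t data_;
  std::array<int64_t, N> strides_;
};

int main() {
  // A 2x3 row-major buffer, so the strides are {3, 1}.
  int64_t buf[6] = {0, 1, 2, 3, 4, 5};
  Accessor<2, int64_t, false> a(buf, {3, 1});
  a.set_index(42, 1, 2);               // buf[1 * 3 + 2 * 1] = 42
  std::cout << a.index(1, 2) << "\n";  // prints 42
  return 0;
}
```

The template parameters mirror the call sites in the diff: rank, element type, and whether the underlying tensor is read-only (Accessor<3, scalar_t, true> for logProbs, Accessor<2, target_t, false> for paths).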
148 changes: 92 additions & 56 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -4,23 +4,28 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <libtorchaudio/accessor.h>
#include <torch/headeronly/util/Half.h>


using namespace std;

namespace torchaudio {
namespace alignment {
namespace cpu {

using torch::stable::Tensor;

// Inspired from
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
template <typename scalar_t, at::ScalarType target_scalar_type>
template <typename scalar_t, typename target_t>
void forced_align_impl(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const int64_t blank,
torch::Tensor& paths) {
const Tensor logProbs,
const Tensor targets,
target_t blank,
Tensor paths) {
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto batchIndex =
0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
@@ -37,12 +42,12 @@ void forced_align_impl(
backPtr_a[i] = -1;
}

auto logProbs_a = logProbs.accessor<scalar_t, 3>();
auto targets_a = targets.accessor<target_t, 2>();
auto paths_a = paths.accessor<target_t, 2>();
auto logProbs_a = Accessor<3, scalar_t, true>(logProbs);
auto targets_a = Accessor<2, target_t, true>(targets);
auto paths_a = Accessor<2, target_t, false>(paths);
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
++R;
}
}
@@ -57,22 +62,23 @@ void forced_align_impl(
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0, i]
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
alphas_a[i] = logProbs_a.index(batchIndex, 0, labelIdx);

}
for (auto t = 1; t < T; t++) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
targets_a[batchIndex][start / 2] !=
targets_a[batchIndex][start / 2 + 1]) {
targets_a.index(batchIndex, start / 2) !=
targets_a.index(batchIndex, start / 2 + 1)) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L &&
targets_a[batchIndex][end / 2 - 1] !=
targets_a[batchIndex][end / 2]) {
targets_a.index(batchIndex, end / 2 - 1) !=
targets_a.index(batchIndex, end / 2)) {
end = end + 1;
}
end = end + 1;
@@ -85,8 +91,8 @@ void forced_align_impl(
}
if (start == 0) {
alphas_a[curIdxOffset * S] =
alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
alphas_a[prevIdxOffset * S] + logProbs_a.index(batchIndex, t, blank);
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
startloop += 1;
}

@@ -95,14 +101,14 @@ void forced_align_impl(
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity();

auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);

// In CTC, the optimal path may optionally choose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not
// currently on a blank_label, and we're not on a repeat letter
// (i != 1) just ensures we don't access targets[i - 2] if i < 2
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
targets_a.index(batchIndex, i / 2) != targets_a.index(batchIndex, i / 2 - 1)) {
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
@@ -116,7 +122,8 @@ void forced_align_impl(
result = x0;
backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
}
alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]

alphas_a[curIdxOffset * S + i] = result + logProbs_a.index(batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
}
}
auto idx1 = (T - 1) % 2;
@@ -125,31 +132,35 @@ void forced_align_impl(
delete[] alphas_a;
// path stores the token index for each time step after force alignment.
for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
paths_a.set_index(lbl_idx, batchIndex, t);
ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ltrIdx]
}
delete[] backPtr_a;
}

std::tuple<torch::Tensor, torch::Tensor> compute(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
std::tuple<Tensor, Tensor> compute(
const Tensor& logProbs,
const Tensor& targets,
const Tensor& inputLengths,
const Tensor& targetLengths,
const int64_t blank) {
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
TORCH_CHECK(
logProbs.device() == targets.device(),
logProbs.get_device() == targets.get_device(),
"log_probs and targets need to be on the same device");
int32_t logprobs_dtype;
aoti_torch_get_dtype(logProbs.get(), &logprobs_dtype);
TORCH_CHECK(
logProbs.dtype() == torch::kFloat64 ||
logProbs.dtype() == torch::kFloat32 ||
logProbs.dtype() == torch::kFloat16,
logprobs_dtype == aoti_torch_dtype_float64() ||
logprobs_dtype == aoti_torch_dtype_float32() ||
logprobs_dtype == aoti_torch_dtype_float16(),
"log_probs must be float64, float32 or float16 (half) type");
int32_t targets_dtype;
aoti_torch_get_dtype(targets.get(), &targets_dtype);
TORCH_CHECK(
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
targets_dtype == aoti_torch_dtype_int32() || targets_dtype == aoti_torch_dtype_int64(),
"targets must be int32 or int64 type");
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -172,39 +183,64 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");

TORCH_CHECK(
logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch");
// TODO: Requires port of `max` and `item` operators.
// TORCH_CHECK(
// logProbs.size(1) == at::max(inputLengths).item().toInt(),
// "input length mismatch");
// TORCH_CHECK(
// targets.size(1) == at::max(targetLengths).item().toInt(),
// "target length mismatch");

const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
auto paths = torch::zeros(
{B, T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
forced_align_impl<scalar_t, torch::kInt64>(
logProbs, targets, blank, paths);
} else {
forced_align_impl<scalar_t, torch::kInt32>(
logProbs, targets, blank, paths);
}
});

int64_t paths_size[2] = {B, T};
int64_t paths_stride[2] = {T, 1};
AtenTensorHandle paths_h;
int32_t targets_device;
aoti_torch_get_device_type(targets.get(), &targets_device);
aoti_torch_empty_strided(2, paths_size, paths_stride, targets_dtype, targets_device, targets.get_device(), &paths_h);
auto paths = Tensor(paths_h);


if (targets_dtype == aoti_torch_dtype_int64()) {
if (logprobs_dtype == aoti_torch_dtype_float64()) {
forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
} else if (logprobs_dtype == aoti_torch_dtype_float32()) {
forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
} else if (logprobs_dtype == aoti_torch_dtype_float16()) {
forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
}
} else if (targets_dtype == aoti_torch_dtype_int32()) {
if (logprobs_dtype == aoti_torch_dtype_float64()) {
forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
} else if (logprobs_dtype == aoti_torch_dtype_float32()) {
forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
} else if (logprobs_dtype == aoti_torch_dtype_float16()) {
forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
}
}
return std::make_tuple(
paths,
logProbs
);
}


void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor t1(to<AtenTensorHandle>(stack[0]));
Tensor t2(to<AtenTensorHandle>(stack[1]));
Tensor t3(to<AtenTensorHandle>(stack[2]));
Tensor t4(to<AtenTensorHandle>(stack[3]));
int64_t blank = to<int64_t>(stack[4]);
auto result = compute(
std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank);
stack[0] = from(std::get<0>(result));
stack[1] = from(std::get<1>(result));
}


TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &compute);
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &boxed_compute);
}

} // namespace cpu
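
The TODO in compute() leaves the input/target length checks commented out because at::max and Tensor::item are not yet available through the stable ABI. This PR does not address that; purely as an illustration, a stopgap could scan the (CPU, contiguous) length tensors directly. The sketch below assumes the lengths are int64 and that the stable Tensor wrapper exposes data_ptr() and numel(); an int32 branch would be needed for int32 lengths.

```cpp
// Hypothetical helper, not part of this PR: max of a contiguous int64 CPU
// tensor read straight from its buffer, avoiding at::max / item().
#include <torch/csrc/stable/tensor.h>
#include <algorithm>
#include <cstdint>
#include <limits>

static int64_t max_int64_cpu(const torch::stable::Tensor& t) {
  // data_ptr() and numel() are assumed to exist on the stable wrapper.
  const auto* data = static_cast<const int64_t*>(t.data_ptr());
  int64_t result = std::numeric_limits<int64_t>::min();
  for (int64_t i = 0; i < t.numel(); ++i) {
    result = std::max(result, data[i]);
  }
  return result;
}

// The commented-out checks could then be reinstated as:
//   TORCH_CHECK(logProbs.size(1) == max_int64_cpu(inputLengths),
//               "input length mismatch");
//   TORCH_CHECK(targets.size(1) == max_int64_cpu(targetLengths),
//               "target length mismatch");
```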