
Commit fbf2088

gru test
1 parent d1473c7 commit fbf2088

File tree

98 files changed: +19590 -0 lines changed


exir/emit/test/test_emit.py

Lines changed: 110 additions & 0 deletions
@@ -967,6 +967,116 @@ def combine_fn(carry, x):
        self.assertIn("aten::select_copy", op_names)
        self.assertIn("executorch_prim::et_copy_index", op_names)

    def test_emit_scan_gru(self) -> None:
        """Test scan with a simple GRU-like computation."""
        from torch._higher_order_ops.scan import scan

        class SimpleGRU(torch.nn.Module):
            """Simple single-layer unidirectional GRU using scan."""

            def __init__(self, input_size: int, hidden_size: int):
                super().__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size

                # GRU gates: reset, update, new
                self.weight_ih = torch.nn.Parameter(
                    torch.randn(3 * hidden_size, input_size), requires_grad=False
                )
                self.weight_hh = torch.nn.Parameter(
                    torch.randn(3 * hidden_size, hidden_size), requires_grad=False
                )
                self.bias_ih = torch.nn.Parameter(
                    torch.randn(3 * hidden_size), requires_grad=False
                )
                self.bias_hh = torch.nn.Parameter(
                    torch.randn(3 * hidden_size), requires_grad=False
                )

            def forward(
                self, x: torch.Tensor, h0: torch.Tensor
            ) -> Tuple[torch.Tensor, torch.Tensor]:
                """
                Args:
                    x: Input tensor of shape [seq_len, batch, input_size]
                    h0: Initial hidden state of shape [batch, hidden_size]
                Returns:
                    output: Output tensor of shape [seq_len, batch, hidden_size]
                    h_n: Final hidden state of shape [batch, hidden_size]
                """
                weight_ih = self.weight_ih
                weight_hh = self.weight_hh
                bias_ih = self.bias_ih
                bias_hh = self.bias_hh

                def gru_cell(
                    h: torch.Tensor, x_t: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
                    # Compute gates
                    gates_ih = torch.nn.functional.linear(x_t, weight_ih, bias_ih)
                    gates_hh = torch.nn.functional.linear(h, weight_hh, bias_hh)

                    # Split into reset, update, new gates
                    r_ih, z_ih, n_ih = gates_ih.chunk(3, dim=-1)
                    r_hh, z_hh, n_hh = gates_hh.chunk(3, dim=-1)

                    r = torch.sigmoid(r_ih + r_hh)
                    z = torch.sigmoid(z_ih + z_hh)
                    n = torch.tanh(n_ih + r * n_hh)

                    h_new = (1 - z) * n + z * h
                    return h_new, h_new.clone()

                final_h, outputs = scan(gru_cell, h0, x)
                return outputs, final_h

        # Create model and inputs
        input_size = 4
        hidden_size = 8
        seq_len = 5
        batch_size = 2

        model = SimpleGRU(input_size, hidden_size)
        x = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(batch_size, hidden_size)
        inputs = (x, h0)

        # Run through eager PyTorch
        eager_outputs = model(*inputs)

        # Export and convert to edge
        module = to_edge(
            export(model, inputs, strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        et = module.to_executorch()
        program = et.executorch_program

        # Verify the program has expected operators
        op_names = [op.name for op in program.execution_plan[0].operators]

        # Should have scan control flow operators
        self.assertIn("aten::sym_size", op_names)
        self.assertIn("aten::select_copy", op_names)
        self.assertIn("executorch_prim::et_copy_index", op_names)

        # Verify we can load the program
        buffer = et.buffer
        loaded_model = _load_for_executorch_from_buffer(buffer)

        # Run through executorch
        et_outputs = loaded_model(inputs)

        # Compare outputs (with tolerance for floating point)
        self.assertTrue(
            torch.allclose(et_outputs[0], eager_outputs[0], atol=1e-5),
            f"Output mismatch: {et_outputs[0]} vs {eager_outputs[0]}",
        )
        self.assertTrue(
            torch.allclose(et_outputs[1], eager_outputs[1], atol=1e-5),
            f"Final hidden state mismatch: {et_outputs[1]} vs {eager_outputs[1]}",
        )

    def test_dim_order(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
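
For context, `scan` in the test above threads a carry (the hidden state) through the sequence and stacks each step's second return value along the leading dimension. A minimal eager sketch of those semantics, assuming the documented carry/outputs contract of `torch._higher_order_ops.scan` (`scan_reference` is a hypothetical helper for illustration, not part of the test):

import torch
from typing import Callable, List, Tuple

def scan_reference(
    combine_fn: Callable[
        [torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]
    ],
    init: torch.Tensor,
    xs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager stand-in for scan(combine_fn, init, xs) over dim 0 of xs."""
    carry = init
    ys: List[torch.Tensor] = []
    for t in range(xs.shape[0]):
        # combine_fn returns (new_carry, per_step_output).
        carry, y = combine_fn(carry, xs[t])
        ys.append(y)
    # Stacked outputs have shape [seq_len, *per_step_output.shape].
    return carry, torch.stack(ys, dim=0)

Under these semantics, `scan(gru_cell, h0, x)` yields the final hidden state and the stacked per-step hidden states, which is why the test compares both tensors against the exported program's outputs.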

preprocess.pt2

305 KB
Binary file not shown.
extension/kernel_util/make_boxed_from_unboxed_functor.h

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===----------------------------------------------------------------------===//
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
/// Defines a template that can be used to create a boxed version of an unboxed
/// functor.
/// Example usage:
/// ```
/// Tensor&
/// my_op(KernelRuntimeContext& ctx, const Tensor& self, const Tensor& other,
///     Tensor& out)
/// {
///   // ...
///   return out;
/// }
///
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op",
///     EXECUTORCH_FN(my_op));
/// static auto res = register_kernels({my_kernel});
/// ```
/// Or simply:
/// ```
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
/// ```
///
/// The trick here is to convert each EValue to the inferred argument type.
/// This uses a lot of C++17 features.
//===----------------------------------------------------------------------===//

#pragma once

#include <executorch/extension/kernel_util/meta_programming.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>

namespace executorch {
namespace runtime {
class KernelRuntimeContext; // Forward declaration
} // namespace runtime
} // namespace executorch

namespace executorch {
namespace extension {

// This extension has a lot of generic internal names like "size"; use a unique
// internal namespace to avoid conflicts with other extensions.
namespace kernel_util_internal {

// Template trait to check if a type is a non-const tensor
template <class T>
struct is_nonconst_tensor : std::false_type {};

template <>
struct is_nonconst_tensor<executorch::aten::Tensor&> : std::true_type {};

// Count the trailing run of non-const tensors (out arguments) in a typelist
template <class TypeList>
struct count_nonconst_tensors;

template <>
struct count_nonconst_tensors<typelist<>> {
  static constexpr size_t value = 0;
};

template <class T>
struct count_nonconst_tensors<typelist<T>> {
  static constexpr size_t value = 0;
};

template <>
struct count_nonconst_tensors<typelist<executorch::aten::Tensor&>> {
  static constexpr size_t value = 1;
};

template <class Head, class... Tail>
struct count_nonconst_tensors<typelist<Head, Tail...>> {
 private:
  static constexpr size_t tail_tensor_count =
      count_nonconst_tensors<typelist<Tail...>>::value;
  static constexpr size_t tail_args_count = sizeof...(Tail);
  static constexpr bool is_head_a_tensor = is_nonconst_tensor<Head>::value;
  static constexpr bool all_tail_args_are_tensor =
      tail_tensor_count == tail_args_count;

 public:
  static constexpr size_t value = (is_head_a_tensor && all_tail_args_are_tensor)
      ? tail_tensor_count + 1
      : tail_tensor_count;
};

template <class T>
struct decay_if_not_tensor final {
  using type = std::decay_t<T>;
};
template <>
struct decay_if_not_tensor<executorch::aten::Tensor&> final {
  using type = executorch::aten::Tensor&;
};
template <>
struct decay_if_not_tensor<const executorch::aten::Tensor&> final {
  using type = const executorch::aten::Tensor&;
};

template <class T>
struct evalue_to_arg final {
  static T call(executorch::runtime::EValue& v) {
    return std::move(v).to<T>();
  }
};

template <>
struct evalue_to_arg<executorch::aten::Tensor&> final {
  static executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
    return v.toTensor();
  }
};

template <>
struct evalue_to_arg<const executorch::aten::Tensor&> final {
  static const executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
    return v.toTensor();
  }
};

template <class T>
struct evalue_to_arg<std::optional<T>> final {
  static std::optional<T> call(executorch::runtime::EValue& v) {
    return v.toOptional<T>();
  }
};

template <class T>
struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
  static executorch::aten::ArrayRef<std::optional<T>> call(
      executorch::runtime::EValue& v) {
    return v.toListOptionalTensor();
  }
};

template <
    class Functor,
    size_t nonconst_tensors_to_log,
    size_t... evalue_arg_indices,
    typename... ArgTypes>
void call_functor_with_args_from_stack(
    executorch::runtime::KernelRuntimeContext& ctx,
    executorch::runtime::Span<executorch::runtime::EValue*> stack,
    std::index_sequence<evalue_arg_indices...>,
    typelist<ArgTypes...>*) {
  executorch::runtime::internal::EventTracerProfileOpScope
      event_tracer_op_scope(ctx.internal_event_tracer(), Functor::func_name_);
  EXECUTORCH_SCOPE_PROF(Functor::func_name_);
  (*Functor::func_ptr())(
      ctx,
      evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
          *stack[evalue_arg_indices])...);
  constexpr size_t num_inputs =
      std::index_sequence<evalue_arg_indices...>::size();
  for (size_t i = num_inputs - nonconst_tensors_to_log; i < num_inputs; ++i) {
    executorch::runtime::internal::event_tracer_log_evalue(
        ctx.internal_event_tracer(), *stack[i]);
  }
}

} // namespace kernel_util_internal

/**
 * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that
 * takes EValues as input and returns void. The wrapped functor will unbox all
 * inputs and forward them to the unboxed kernel.
 */
template <class FuncType>
struct WrapUnboxedIntoFunctor {
  static_assert(
      kernel_util_internal::is_compile_time_function_pointer<FuncType>::value,
      "Can't handle function other than EXECUTORCH_FN");
  using TrueType = typename FuncType::FuncType;
  using ReturnType = typename kernel_util_internal::infer_function_traits_t<
      TrueType>::return_type;
  using ArgsType = typename kernel_util_internal::infer_function_traits_t<
      TrueType>::parameter_types;
  // Check if the first argument is KernelRuntimeContext; if so, remove it.
  static constexpr bool first_arg_is_context = std::is_same<
      ::executorch::runtime::KernelRuntimeContext,
      std::remove_reference_t<
          kernel_util_internal::head_with_default_t<void, ArgsType>>>::value;
  using ContextRemovedArgsType = std::conditional_t<
      first_arg_is_context,
      kernel_util_internal::drop_if_nonempty_t<ArgsType, 1>,
      ArgsType>;

  static void call(
      ::executorch::runtime::KernelRuntimeContext& ctx,
      executorch::runtime::Span<executorch::runtime::EValue*> stack) {
    constexpr size_t num_inputs =
        kernel_util_internal::size<ContextRemovedArgsType>::value;
    constexpr size_t num_nonconst_tensors =
        kernel_util_internal::count_nonconst_tensors<
            ContextRemovedArgsType>::value;
    static_assert(num_nonconst_tensors == 1, "Invalid number of inputs");
    return kernel_util_internal::
        call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
            ctx,
            stack,
            std::make_index_sequence<num_inputs>(),
            static_cast<ContextRemovedArgsType*>(nullptr));
  }
};

template <typename FuncType>
static executorch::runtime::Kernel make_boxed_kernel(
    const char* name,
    FuncType) {
  return executorch::runtime::Kernel(
      name, WrapUnboxedIntoFunctor<FuncType>::call);
}

} // namespace extension
} // namespace executorch

// Inspired from C10_CONCATENATE
#define ET_CONCATENATE_IMPL(s1, s2) s1##s2
#define ET_CONCATENATE(s1, s2) ET_CONCATENATE_IMPL(s1, s2)
#define ET_UID __LINE__

#define EXECUTORCH_LIBRARY(ns, op_name, func) \
  _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID)

#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid)           \
  static constexpr const char ET_CONCATENATE(name_of_op_, uid)[] = \
      #ns "::" op_name;                                            \
  static auto ET_CONCATENATE(res_##ns##_, uid) =                   \
      ::executorch::runtime::register_kernel(                      \
          ::executorch::extension::make_boxed_kernel(              \
              #ns "::" op_name,                                    \
              EXECUTORCH_FN(func, ET_CONCATENATE(name_of_op_, uid))))

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::make_boxed_kernel;
using ::executorch::extension::WrapUnboxedIntoFunctor;
} // namespace executor
} // namespace torch
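
For reference, a minimal registration sketch built only on what the doc comment at the top of this header shows; the `my_ns` namespace, the `add_out` op name, and the kernel body are hypothetical:

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

namespace my_ns {
// Unboxed out-variant kernel. Note the single trailing non-const Tensor&:
// WrapUnboxedIntoFunctor::call static_asserts that exactly one such out
// tensor is present.
Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  (void)ctx; // Unused in this sketch; real kernels report errors through it.
  // ... elementwise addition into `out` would go here (hypothetical body) ...
  return out;
}
} // namespace my_ns

// Expands to a static register_kernel(make_boxed_kernel(...)) call, so the
// boxed wrapper is derived from the unboxed signature and registered at
// program load.
EXECUTORCH_LIBRARY(my_ns, "add_out", my_ns::add_out);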
