Skip to content

Commit e4c9f6d

Browse files
yushangdi authored and pytorchmergebot committed
[nativert] Move c10_kernel (pytorch#156208)
Summary: Torch Native Runtime RFC: pytorch/rfcs#72. As part of the effort to open-source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementation to torch/: fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels. Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/cpp/nativert:c10_kernel_test ``` Differential Revision: D76825830. Pull Request resolved: pytorch#156208. Approved by: https://github.com/zhxchen17
1 parent f402eed commit e4c9f6d

File tree

5 files changed

+428
-0
lines changed

5 files changed

+428
-0
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ libtorch_nativert_sources = [
606606
"torch/nativert/executor/memory/FunctionSchema.cpp",
607607
"torch/nativert/common/FileUtil.cpp",
608608
"torch/nativert/detail/ITree.cpp",
609+
"torch/nativert/kernels/C10Kernel.cpp",
609610
]
610611

611612
torch_mobile_tracer_sources = [

test/cpp/nativert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ set(NATIVERT_TEST_SRCS
1717
${TORCH_ROOT}/torch/nativert/executor/ExecutionPlanner.cpp
1818
${TORCH_ROOT}/torch/nativert/detail/ITree.cpp
1919
${TORCH_ROOT}/torch/nativert/executor/ExecutionFrame.cpp
20+
${TORCH_ROOT}/torch/nativert/kernels/C10Kernel.cpp
2021
)
2122

2223
add_executable(test_nativert
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include <ATen/core/op_registration/op_registration.h>
2+
#include <gtest/gtest.h>
3+
#include <torch/nativert/executor/ExecutionFrame.h>
4+
#include <torch/nativert/graph/Graph.h>
5+
#include <torch/nativert/kernels/C10Kernel.h>
6+
#include <torch/torch.h>
7+
8+
namespace torch::nativert {
9+
10+
// Trivial binary kernel registered as "test::foo" below: returns the
// element-wise sum of the two input tensors.
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  auto sum = a + b;
  return sum;
}
13+
14+
// Verifies that C10Kernel can resolve a freshly registered custom op
// ("test::foo") through the dispatcher, execute it against tensors held in
// an ExecutionFrame, and write the result into the graph output slot.
TEST(C10KernelTest, computeInternal) {
  auto registrar = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);

  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);
  // Advance past the first node to reach the test.foo call node.
  const auto& nodeList = graph->nodes();
  auto nodeIt = nodeList.begin();
  ++nodeIt;
  const Node& fooNode = *nodeIt;

  c10::Device device = torch::Device(torch::kCPU, 0);

  auto lhs = at::randn({6, 6, 6});
  auto rhs = at::randn({6, 6, 6});

  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  auto kernel = C10Kernel(&fooNode, device);
  kernel.computeInternal(frame);

  at::Tensor expected = lhs + rhs;
  EXPECT_TRUE(
      torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}
47+
48+
// Verifies that ScalarBinaryOpKernel evaluates "_operator.add" on two
// integer inputs stored in an ExecutionFrame.
TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);
  // Advance past the first node to reach the _operator.add call node.
  const auto& nodeList = graph->nodes();
  auto nodeIt = nodeList.begin();
  ++nodeIt;
  const Node& addNode = *nodeIt;

  auto lhs = 1;
  auto rhs = 2;

  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  auto kernel = ScalarBinaryOpKernel(&addNode);
  kernel.computeInternal(frame);

  auto expected = lhs + rhs;
  EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected);
}
75+
76+
} // namespace torch::nativert
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
#include <torch/nativert/kernels/C10Kernel.h>
2+
3+
#include <fmt/ostream.h>
4+
5+
#include <c10/util/Enumerate.h>
6+
7+
#ifdef __SIGRID_USE_GPU__
8+
#include <ATen/cuda/CUDAContext.h>
9+
#include <ATen/cuda/Exceptions.h>
10+
#endif
11+
12+
namespace torch::nativert {
13+
14+
// Resolves the dispatcher operator for `node`'s target up front and
// pre-fills the static (constant) portion of the argument stack, so that
// computeInternal() only has to splice in dynamic inputs per execution.
C10Kernel::C10Kernel(
    const Node* node,
    c10::Device device,
    OpKernelKind kind,
    AliasingSpec&& aliasingSpec)
    : OpKernel(node, device, kind),
      // Look up the boxed operator handle once at construction time.
      op_(getOperatorForTarget(node->target(), node)),
      // NOTE(review): uses the base-class member kind_ rather than the `kind`
      // parameter — presumably equivalent once OpKernel's ctor ran; confirm.
      schema_(op_.schema(), std::move(aliasingSpec), kind_),
      // Constant stack slots are filled once here; dynamic ones are filled
      // per call in computeInternal().
      arguments_(prefillStackWithStaticArgs(node, op_.schema())) {}
23+
24+
// Executes the resolved operator against the current frame: copies the
// pre-filled static arguments, splices in the dynamic inputs, invokes the
// boxed call, and moves each produced value into its output slot.
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
  // Start from a fresh copy of the static argument stack.
  std::vector<c10::IValue> argStack = arguments_.getStackWithStaticArgs();
  fillDynamicInputs(executionFrame, arguments_, argStack);

  // Invoke the op; on failure, rethrow with the node, its arguments, and the
  // original Python stack trace (when recorded) for easier debugging.
  try {
    op_.callBoxed(argStack);
  } catch (const std::exception& ex) {
    auto stackTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Exception while executing node: {}\n"
        "with args:\n{}\n"
        "{}\n"
        "Original Python stacktrace:\n{}",
        fmt::streamed(*node_),
        readableArgs(op_.schema(), argStack),
        ex.what(),
        stackTrace ? *stackTrace : "<no stack trace>"));
  }

  // Write out results, one stack entry per declared output.
  // TODO: we store intermediates in a single table (symint and tensor alike).
  // This can theoretically lead to name collisions, although based on how
  // these are named I don't think it will ever happen in practice. We need to
  // enforce it though.
  const auto& outputValues = node_->outputs();
  TORCH_CHECK_EQ(outputValues.size(), argStack.size())
      << "Output size mismatch for " << node_->toString();
  for (size_t i = 0; i < argStack.size(); ++i) {
    executionFrame.setIValue(outputValues[i]->id(), std::move(argStack[i]));
  }
}
58+
59+
namespace {
60+
std::unordered_map<std::string, c10::IValue> getSymInputs(
61+
const ExecutionFrame& executionFrame,
62+
const Node& node) {
63+
std::unordered_map<std::string, c10::IValue> inputs;
64+
for (const auto& input : node.inputs()) {
65+
const auto& val = executionFrame.getIValue(input.value->id());
66+
if (val.isInt() || val.isDouble() || val.isBool()) {
67+
inputs[input.name] = val;
68+
} else {
69+
throw std::runtime_error("unsupported type for symbolic input");
70+
}
71+
}
72+
for (const auto& attribute : node.attributes()) {
73+
if (std::holds_alternative<int64_t>(attribute.value)) {
74+
inputs[attribute.name] = std::get<int64_t>(attribute.value);
75+
} else if (std::holds_alternative<double>(attribute.value)) {
76+
inputs[attribute.name] = std::get<double>(attribute.value);
77+
} else if (std::holds_alternative<bool>(attribute.value)) {
78+
inputs[attribute.name] = std::get<bool>(attribute.value);
79+
} else {
80+
throw std::runtime_error("unsupported type for symbolic input");
81+
}
82+
}
83+
return inputs;
84+
}
85+
86+
template <typename T>
87+
void computeScalarBinaryOp(
88+
ExecutionFrame& executionFrame,
89+
const Node& node,
90+
std::enable_if_t<true, T> a,
91+
std::enable_if_t<true, T> b) {
92+
std::string_view target = node.target();
93+
T out;
94+
95+
if (target == "_operator.add") {
96+
out = a + b;
97+
} else if (target == "_operator.sub") {
98+
out = a - b;
99+
} else if (target == "_operator.mul") {
100+
out = a * b;
101+
} else if (target == "_operator.pow") {
102+
out = std::pow(a, b);
103+
} else {
104+
throw std::runtime_error(
105+
fmt::format("unsupported operator for symbolic values: {}", target));
106+
}
107+
108+
executionFrame.setIValue(node.outputs()[0]->id(), out);
109+
VLOG(2) << fmt::format(
110+
"Completed executing node: {} with a={}, b={}, out={}",
111+
fmt::streamed(node),
112+
a,
113+
b,
114+
out);
115+
}
116+
117+
} // namespace
118+
119+
// Dispatches a scalar binary op to computeScalarBinaryOp<T>: int64_t when
// both operands are ints, otherwise double (int operands are widened; any
// other operand type throws).
void ScalarBinaryOpKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  const auto& lhs = inputs.at("a");
  const auto& rhs = inputs.at("b");

  if (lhs.isInt() && rhs.isInt()) {
    computeScalarBinaryOp<int64_t>(
        executionFrame, *node_, lhs.toInt(), rhs.toInt());
    return;
  }

  // Mixed or floating operands: promote everything to double.
  auto asDouble = [](const c10::IValue& v) -> double {
    if (v.isDouble()) {
      return v.toDouble();
    }
    if (v.isInt()) {
      return static_cast<double>(v.toInt());
    }
    throw std::runtime_error("unsupported type for symbolic input");
  };

  computeScalarBinaryOp<double>(
      executionFrame, *node_, asDouble(lhs), asDouble(rhs));
}
144+
145+
// Evaluates integer symbolic ops (torch.sym_float, _operator.floordiv,
// _operator.mod, torch.sym_max, torch.sym_min) on inputs "a" (and "b" for
// the binary ones), writing the result into the node's single output slot.
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  int64_t a = inputs.at("a").toInt();
  std::string_view target = node_->target();
  // torch.sym_float is unary — handle it before reading "b".
  if (target == "torch.sym_float") {
    double out = static_cast<double>(a);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
    VLOG(2) << fmt::format(
        "Completed executing node: {} with a={}, out={}",
        fmt::streamed(*node_),
        a,
        out);
    return;
  }
  int64_t b = inputs.at("b").toInt();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t out;

  if (target == "_operator.floordiv") {
    // Python floordiv rounds toward negative infinity, while C++ integer
    // division truncates toward zero; adjust by one when the signs differ
    // and the division is inexact (previously this was a plain `a / b`,
    // which is wrong for negative operands).
    out = a / b;
    if ((a % b != 0) && ((a < 0) != (b < 0))) {
      --out;
    }
  } else if (target == "_operator.mod") {
    // Python mod takes the sign of the divisor; C++ % takes the sign of the
    // dividend. Shift a nonzero remainder into the divisor's sign range.
    out = a % b;
    if (out != 0 && ((out < 0) != (b < 0))) {
      out += b;
    }
  } else if (target == "torch.sym_max") {
    out = std::max(a, b);
  } else if (target == "torch.sym_min") {
    out = std::min(a, b);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymInt: {}", node_->target()));
  }

  executionFrame.setIValue(node_->outputs()[0]->id(), out);
  VLOG(2) << fmt::format(
      "Completed executing node: {} with a={}, b={}, out={}",
      fmt::streamed(*node_),
      a,
      b,
      out);
}
185+
186+
// Evaluates boolean symbolic ops (torch.sym_not, the integer comparisons
// ge/le/eq/gt/lt, and logical and_) and stores the bool result into the
// node's single output slot.
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  // Small accessors to cut down on repeated inputs.at(...).toX() noise.
  auto intArg = [&inputs](const char* name) {
    return inputs.at(name).toInt();
  };
  auto boolArg = [&inputs](const char* name) {
    return inputs.at(name).toBool();
  };

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool out;

  const std::string_view target = node_->target();
  if (target == "torch.sym_not") {
    out = !boolArg("a");
  } else if (target == "_operator.ge") {
    out = intArg("a") >= intArg("b");
  } else if (target == "_operator.le") {
    out = intArg("a") <= intArg("b");
  } else if (target == "_operator.eq") {
    out = intArg("a") == intArg("b");
  } else if (target == "_operator.gt") {
    out = intArg("a") > intArg("b");
  } else if (target == "_operator.lt") {
    out = intArg("a") < intArg("b");
  } else if (target == "_operator.and_") {
    // Both operands are fetched unconditionally, matching the original
    // (no short-circuit on the frame lookups).
    bool lhs = boolArg("a");
    bool rhs = boolArg("b");
    out = lhs && rhs;
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymBool: {}", node_->target()));
  }

  executionFrame.setIValue(node_->outputs()[0]->id(), out);
}
227+
228+
// Evaluates floating-point symbolic ops (math.trunc, torch._sym_sqrt,
// _operator.neg, _operator.truediv) and stores the result in the node's
// single output slot.
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  const std::string_view target = node_->target();
  if (target == "math.trunc") {
    // NOTE: math.trunc reads its operand under the name "x" (the other ops
    // here use "a"); this mirrors the Python signature math.trunc(x).
    double x = inputs.at("x").toDouble();
    // std::trunc (qualified) instead of the unqualified global `trunc`,
    // which <cmath> is not required to declare in the global namespace.
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    int64_t out = std::trunc(x);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "torch._sym_sqrt") {
    double a = inputs.at("a").toDouble();
    double out = std::sqrt(a);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.neg") {
    // Negation preserves the operand's type (int stays int, double stays
    // double); any other scalar type is rejected.
    auto a = inputs.at("a");
    c10::IValue out;
    if (a.isInt()) {
      out = -a.toInt();
    } else if (a.isDouble()) {
      out = -a.toDouble();
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.truediv") {
    // Python truediv always yields a float, so widen int operands first.
    auto ia = inputs.at("a");
    double a = ia.isInt() ? static_cast<double>(ia.toInt()) : ia.toDouble();
    auto ib = inputs.at("b");
    double b = ib.isInt() ? static_cast<double>(ib.toInt()) : ib.toDouble();
    double out = a / b;
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymFloat: {}", node_->target()));
  }
}
264+
265+
} // namespace torch::nativert

0 commit comments

Comments
 (0)