Skip to content

Commit 4e0235f

Browse files
yushangdi and facebook-github-bot
authored and committed
[nativert] Move auto_functionalize_kernel
Summary: Torch Native Runtime RFC: pytorch/rfcs#72 As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementation to torch/: fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels Copied from original auto_functionalize Diff Summary D53776805: This is a non-functional kernel implementation for auto_functionalize. In AutoFunctionalizeKernel, I directly call the underlying target without making a clone of the mutating inputs. This would mutate the input tensors in place, which is unsafe in general. However, Sigmoid is not doing any graph optimization or node reordering at the moment, so it's OK to take this shortcut. The proper functional implementation will make a clone of each mutating input tensor and return these new instances of tensors as the AutoFunctionalizeKernel output. If the original exported program has some "bufferMutation" or "userInputMutation" fields, it will also need to honor such mutations in Sigmoid. Test Plan: See internal for test plan Differential Revision: D76926383
1 parent c60d818 commit 4e0235f

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ libtorch_nativert_sources = [
608608
"torch/nativert/common/FileUtil.cpp",
609609
"torch/nativert/detail/ITree.cpp",
610610
"torch/nativert/kernels/C10Kernel.cpp",
611+
"torch/nativert/kernels/AutoFunctionalizeKernel.cpp",
611612
]
612613

613614
torch_mobile_tracer_sources = [
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>
2+
3+
#include <fmt/format.h>
4+
5+
#include <c10/util/Enumerate.h>
6+
7+
namespace torch::nativert {
8+
9+
// Builds a kernel that runs the target op of an auto_functionalize node
// directly, without cloning its mutating inputs first (hence "Unsafe":
// the inputs are mutated in place instead of being functionalized).
UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
    : OpKernel(node),
      // attributes()[0] is assumed to hold the target op name as a string
      // — NOTE(review): confirm this invariant against the graph builder.
      op_(getOperatorForTarget(
          std::get<std::string>(node->attributes()[0].value))),
      schema_(op_.schema()),
      // Static (constant) arguments are baked into the stack once here;
      // dynamic ones are filled per-execution in computeInternal().
      arguments_(prefillStackWithStaticArgs(node, schema_)) {
  // Record the input Values whose schema arguments the op declares it
  // writes to (alias_info()->isWrite()); their post-call values are copied
  // to the node's extra outputs in computeInternal().
  for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
    if (schemaArg.alias_info() != nullptr &&
        schemaArg.alias_info()->isWrite()) {
      mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
    }
  }

  // Number of actual (non-mutating-input) returns declared by the schema.
  numOutputs_ = schema_.returns().size();
}
24+
25+
// Executes the underlying op and publishes both its real returns and the
// post-mutation values of its mutating inputs to the execution frame.
//
// NOTE: this is the "unsafe" non-functional implementation — mutating
// inputs are NOT cloned first, so the op mutates them in place. This is
// only safe because the runtime currently performs no graph optimization
// or node reordering.
void UnsafeAutoFunctionalizeKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // Make a copy of the stack so the cached static arguments in arguments_
  // are never clobbered by callBoxed().
  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();

  fillDynamicInputs(executionFrame, arguments_, stack);

  // Call the op with the prepared stack.
  try {
    op_.callBoxed(stack);
  } catch (const std::exception& ex) {
    // TODO: this eats the original exception type. ATen returns different
    // exception types that correspond to different Python errors (e.g.
    // IndexError, ValueError). If retaining this information is important
    // to us, we'll have to change this up a little.
    auto stackTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Original Python stacktrace:\n{}\n{}",
        stackTrace ? *stackTrace : "<no stack trace>",
        ex.what()));
  }

  const auto& outputValues = node_->outputs();

  // The first numOutputs_ node outputs are the op's actual returns; after
  // callBoxed() they occupy the first numOutputs_ stack slots.
  for (int i = 0; i < numOutputs_; ++i) {
    executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
  }

  // Copy over mutating inputs to outputs. When the op returns nothing,
  // mutating-input outputs start at node-output index 1 rather than 0 —
  // presumably output 0 is a placeholder slot in that case; NOTE(review):
  // confirm against the auto_functionalize HOP's output convention.
  // Use size_t for the index so the comparison against outputValues.size()
  // is not a signed/unsigned mismatch.
  const size_t mutatingArgStartIndex =
      (numOutputs_ == 0) ? 1 : static_cast<size_t>(numOutputs_);
  for (size_t i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
    executionFrame.setIValue(
        outputValues[i]->id(),
        executionFrame.getIValue(
            mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
            true /* allowNone */));
  }
}
63+
64+
} // namespace torch::nativert
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <ATen/core/function_schema.h>
5+
#include <c10/core/Device.h>
6+
7+
#include <torch/nativert/executor/ExecutionFrame.h> // @manual
8+
#include <torch/nativert/executor/OpKernel.h> // @manual
9+
10+
namespace torch::nativert {
11+
12+
class UnsafeAutoFunctionalizeKernel : public OpKernel {
13+
public:
14+
UnsafeAutoFunctionalizeKernel() = delete; // deleted default constructor
15+
UnsafeAutoFunctionalizeKernel(const Node* node);
16+
17+
void computeInternal(ExecutionFrame& executionFrame) const override final;
18+
19+
private:
20+
c10::OperatorHandle op_;
21+
c10::FunctionSchema schema_;
22+
23+
Arguments arguments_;
24+
25+
std::vector<Value*> mutatingInputArgs_;
26+
int numOutputs_;
27+
};
28+
29+
} // namespace torch::nativert

0 commit comments

Comments
 (0)