
Commit b2e3bd9

Merge branch 'main' into io_params
2 parents 62db20b + fc435fa commit b2e3bd9


54 files changed: +990, -721 lines

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-9b498d3bb28b8e3411ce464dd2755c5b96d92c8f
+7cda4017ddda554752e89069ae205be5e8388f59

.ci/scripts/check_c10_sync.sh

Lines changed: 1 addition & 1 deletion

@@ -12,4 +12,4 @@ pushd pytorch
 git checkout "$pytorch_pin"
 popd
 "$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/c10 pytorch/c10
-"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/standalone pytorch/torch/standalone
+"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/headeronly pytorch/torch/headeronly

.github/workflows/trunk.yml

Lines changed: 3 additions & 3 deletions

@@ -240,11 +240,11 @@ jobs:
 
 cxx_flags="-fno-exceptions -fno-rtti -Wall -Werror -Wno-int-in-bool-context -DET_HAVE_PREAD=0"
 setup_script_args=""
-if [[ ${{ matrix.os}} == "bare_metal" ]]; then
+if [[ ${{ matrix.os}} == "bare_metal" ]]; then
   toolchain_prefix=arm-none-eabi-
-  threshold="103268" # ~100KiB
+  threshold="104000" # should be ~103.7KB, set threshold to 104KB.
   toolchain_cmake=examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
-elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
+elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
   setup_script_args="--target-toolchain zephyr"
   toolchain_prefix=arm-zephyr-eabi-
   threshold="133120" # should be ~125KB, set threshold to 130KB

CMakeLists.txt

Lines changed: 5 additions & 1 deletion

@@ -490,7 +490,7 @@ install(
   INCLUDES
   DESTINATION ${_common_include_directories}
 )
-install(FILES tools/cmake/executorch-config.cmake
+install(FILES tools/cmake/Utils.cmake tools/cmake/executorch-config.cmake
   DESTINATION lib/cmake/ExecuTorch
 )
 
@@ -732,4 +732,8 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()
 
+if(EXECUTORCH_BUILD_ANDROID_JNI)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/android)
+endif()
+
 include(Test.cmake)

backends/cadence/aot/compiler.py

Lines changed: 12 additions & 28 deletions

@@ -8,7 +8,7 @@
 
 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
 
@@ -262,14 +261,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)
 
+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )
 
     return edge_prog_manager
 
@@ -311,14 +316,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -329,14 +327,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -373,15 +364,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
 
     # Print some information to terminal
     print_ops_info(
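Taken together, export_to_edge now runs torch-level passes before lowering and can forward a list of ops to keep out of core-ATen decomposition. A minimal usage sketch; the model, inputs, and exempted op are illustrative, and it assumes _lower_ep_to_edge accepts the forwarded list as its fourth parameter, as the call above implies:

import torch
from executorch.backends.cadence.aot.compiler import export_to_edge

model = torch.nn.Linear(4, 4)
inputs = (torch.randn(1, 4),)
edge_prog_manager = export_to_edge(
    model,
    inputs,
    core_aten_exceptions=[torch.ops.aten.linear.default],  # illustrative choice
)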

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions

@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
     FuseCascadedTransposeOrPermuteOps,
     FuseCascadedViewOps,
     FuseQuantDequantToRequantizePass,
+    FuseMulTensorIntoQuantPass,
    FuseMulTensorIntoDequantPass,
     FuseMulScalarIntoDequantPass,
     FuseFullThenReshapePass,
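FuseMulTensorIntoQuantPass is registered here; its implementation is not shown in this hunk. Assuming it mirrors the existing FuseMulTensorIntoDequantPass, the arithmetic that makes the fusion sound is easy to check (a sketch, not the pass's actual code):

import torch

# quantize(x, scale=s) computes round(x / s) + zero_point, so a preceding
# multiply by a constant c can be folded by shrinking the scale to s / c:
# x * c / s == x / (s / c).
x, c, s = torch.randn(8), 0.5, 0.1
assert torch.allclose(x * c / s, x / (s / c))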

backends/cadence/aot/pass_utils.py

Lines changed: 37 additions & 14 deletions

@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(
 
 def get_arg(
     node: torch.fx.Node,
-    arg_index: int,
     kwarg_name: str,
-    *,
-    default: torch.fx.node.Argument = None,
 ) -> torch.fx.node.Argument:
     """
-    Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
-    return default.
+    Get the arg with arg_name of the node, returns default value if not set.
     """
-    if arg_index < len(node.args):
-        return node.args[arg_index]
-    elif kwarg_name in node.kwargs:
+    # Try to get the arg from kwargs first since this is faster
+    if kwarg_name in node.kwargs:
         return node.kwargs[kwarg_name]
-    else:
-        return default
+
+    # If it's not found in kwargs, try to normalize the args
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"get_arg: Node {node} does not support normalization of arguments"
+        )
+
+    return normalized_args.kwargs[kwarg_name]
 
 
 def set_arg(
-    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
+    node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
 ) -> None:
     """
-    Set the arg at arg_index if it exists, otherwise set the kwarg.
+    Set the node's arg with its name to the given value.
     """
-    if arg_index < len(node.args):
-        node.update_arg(arg_index, value)
+    # Try to set the arg if it is present in kwargs first since this is faster
+    if kwarg_name in node.kwargs:
+        node.update_kwarg(kwarg_name, value)
+        return
+
+    # If it's not found in kwargs, try to normalize the args and set the arg
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"set_arg: Node {node} does not support normalization of arguments"
+        )
+
+    kwargs = normalized_args.kwargs
+    if kwarg_name not in kwargs:
+        raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
+
+    idx = list(kwargs.keys()).index(kwarg_name)
+    if idx < len(node.args):
+        node.update_arg(idx, value)
     else:
         node.update_kwarg(kwarg_name, value)
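A short usage sketch of the new name-based accessors (assuming the standard executorch import path; torch.cat receives dim as a kwarg here, so both calls take the fast path):

import torch
from executorch.backends.cadence.aot.pass_utils import get_arg, set_arg

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x], dim=1)

gm = torch.fx.symbolic_trace(f)
cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)

assert get_arg(cat_node, "dim") == 1  # found directly in node.kwargs
set_arg(cat_node, "dim", 0)           # updated in place by name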

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions

@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional
 
 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]
 
 
-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of the torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
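Note the filtered_pass=filtered_pass default argument in the lambda above: it binds each pass class eagerly, sidestepping Python's late-binding closure pitfall. A sketch of how the two entry points divide the work, mirroring the compiler.py changes in this commit (model and inputs are illustrative):

import torch
from executorch.backends.cadence.aot.passes import apply_torch_ops_passes

# Torch-level passes run while the program is still in torch.ops IR.
ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(1, 4),))
ep = apply_torch_ops_passes(ep)

# After lowering to edge IR, apply_exir_ops_passes(opt_level, edge_prog_manager)
# runs the opt-level-filtered Cadence passes via EdgeProgramManager.transform.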

backends/cadence/aot/remove_ops.py

Lines changed: 9 additions & 9 deletions

@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
         for slice_copy_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
         ):
-            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
-            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
-            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
-            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
+            start_idx = cast(int, get_arg(slice_copy_node, "start"))
+            end_idx = cast(int, get_arg(slice_copy_node, "end"))
+            step = cast(int, get_arg(slice_copy_node, "step"))
 
             if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue
 
             # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
+            cat_dim = cast(Node, get_arg(cat_node, "dim"))
             if cat_dim != slice_dim:
                 continue
 
@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
                 end_idx += cat_output_shape[cat_dim]
 
             offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
+            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape
 
                 # Check if the slice range overlaps with the cat input range.
                 if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
                     slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
-                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
+                    set_arg(slice_copy_node, "start", start_idx - offset)
+                    set_arg(slice_copy_node, "end", end_idx - offset)
                     break
 
                 offset += cat_input_shape[cat_dim]
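A worked instance of the rewrite above: when the slice window along the cat dimension falls entirely inside one cat input, the slice can read from that input directly with the offset subtracted:

import torch

# cat inputs of sizes 2 and 3 along dim 0; slicing rows 2:5 lands entirely in
# the second input (offset 2), so start/end are rewritten to 0 and 3.
a, b = torch.randn(2, 4), torch.randn(3, 4)
assert torch.equal(torch.cat([a, b])[2:5], b[0:3])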

backends/nxp/runtime/TARGETS

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+load("targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
