Skip to content

Commit d1d18b9

Browse files
committed
revert prelu impl
1 parent cdc31d7 commit d1d18b9

File tree

1 file changed

+1
-8
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+1
-8
lines changed

py/torch_tensorrt/dynamo/conversion/impl/prelu.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
5-
from torch_tensorrt.dynamo.conversion import impl
65
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
76
from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
87
from torch_tensorrt.dynamo.types import TRTTensor
@@ -16,12 +15,6 @@ def prelu(
1615
input: TRTTensor,
1716
weight: TRTTensor,
1817
) -> TRTTensor:
19-
# TRT requires that the slopes tensor must be unidirectional broadcastable to the input tensor:
20-
# the rank of the two tensors must be the same, and all dimensions of the slopes tensor must
21-
# either equal the input tensor or be 1. The output tensor has the same shape as the input tensor.
22-
input, weight = impl.elementwise.broadcast(
23-
ctx, input, weight, f"{name}_broadcast_input", f"{name}_broadcast_weight"
24-
)
25-
layer = ctx.net.add_parametric_relu(input, slopes=weight)
18+
layer = ctx.net.add_parametric_relu(input, weight)
2619
set_layer_name(layer, target, name, source_ir)
2720
return layer.get_output(0)

0 commit comments

Comments (0)