diff --git a/onnxruntime/contrib_ops/cpu/layer_norm.cc b/onnxruntime/contrib_ops/cpu/layer_norm.cc index c949fcddad093..b4c4f475461b7 100644 --- a/onnxruntime/contrib_ops/cpu/layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/layer_norm.cc @@ -17,6 +17,12 @@ namespace contrib { .TypeConstraint("V", DataTypeImpl::GetTensorType()), \ LayerNorm); \ ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kOnnxDomain, 1, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("V", DataTypeImpl::GetTensorType()), \ + LayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kMSDomain, 1, T, kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()) \ diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 36d6fc378d45e..027825871295e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -148,6 +148,13 @@ class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_MLFloat16, class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_MLFloat16, LayerNormalization); class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_float, LayerNormalization); class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, BFloat16_float_BFloat16, LayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); +// For backward compatibility: SimplifiedLayerNormalization was previously incorrectly registered to the ONNX domain class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); @@ -384,6 +391,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // For backward compatibility: SimplifiedLayerNormalization was previously incorrectly registered to the ONNX domain BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index fad9ce7a6bf7d..06fb8d4fb846f 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -19,6 +19,12 @@ namespace cuda { .TypeConstraint("V", DataTypeImpl::GetTensorType()), \ onnxruntime::cuda::LayerNorm); \ ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kOnnxDomain, 1, T##_##U##_##V, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("V", DataTypeImpl::GetTensorType()), \ + onnxruntime::cuda::LayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kMSDomain, 1, T##_##U##_##V, kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()) \ diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 36a6f9bd87013..e3b4aa0377d02 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -21,6 +21,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiH class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); @@ -53,6 +54,8 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo); +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", onnxruntime::js::JsepSupportedFloatTypes()) + .TypeConstraint("U", onnxruntime::js::JsepSupportedFloatTypes()), + onnxruntime::js::LayerNorm); + } // namespace js } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7dbb24463961e..a122f5e55084e 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -112,6 +112,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); @@ -272,6 +277,11 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc index 8997e8698d96d..791b7ac84b670 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc @@ -31,6 +31,14 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), onnxruntime::webgpu::LayerNorm); +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 25cc13b3ea1df..e8216866a85d3 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -23,9 +23,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Ma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); template <> @@ -52,6 +52,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo};