TF32 POC in Conv3d on MI30x platform #2763
base: develop
Compare: a2506cc to 1187441
...de/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
.gitignore
@@ -70,4 +70,3 @@ build*/
__pycache__/
Misclick?
Yes. It seems VS Code automatically deleted a space. I will restore it.
example/01_gemm/common.hpp
@@ -310,10 +310,14 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
return true;
}

template <typename DataType>
template <typename DataType, typename GemmType = DataType>
Can you change it to ComputeType to keep the naming convention?
Sure. I will use ComputeDataType to align with device_gemm_xdl_cshuffle_lds_direct_load.hpp#L61.
example/01_gemm/run_gemm_example.inc
@@ -4,6 +4,11 @@
#pragma once
#include "ck/library/utility/validation_common.hpp"

// use macro to minimize code change
#ifndef EXAMPLE_WITH_GEMM_DATATYPE
using GemmDataType = AccDataType;
ComputeType
Done.
@@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol()
}
}

template <typename DataType>
template <typename DataType, typename GemmType = DataType>
ComputeType
Done.
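For readers following the rename, here is a minimal sketch of how a tolerance helper might take a second, defaulted compute-type parameter. It is not the code from this PR: `tf32_tag`, `get_rtol_sketch`, and every numeric value are hypothetical placeholders, chosen only to show that a reduced-precision compute type should widen the tolerance.

```cpp
#include <type_traits>

// Hypothetical tag type standing in for a TF32 compute type (illustrative name).
struct tf32_tag
{
};

// Sketch of a relative-tolerance helper with a defaulted compute type.
// All numeric values below are placeholders, not the values used in the PR.
template <typename DataType, typename ComputeDataType = DataType>
constexpr double get_rtol_sketch()
{
    // When the math runs in TF32, loosen the tolerance regardless of the
    // storage type, because fewer mantissa bits take part in the multiply.
    return std::is_same<ComputeDataType, tf32_tag>::value ? 5e-3
           : std::is_same<DataType, float>::value         ? 1e-3
                                                          : 1e-6;
}
```

Existing call sites such as `get_rtol_sketch<float>()` keep compiling because the second parameter defaults to `DataType`; only TF32 callers need to spell out `get_rtol_sketch<float, tf32_tag>()`.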
#ifndef EXAMPLE_WITH_GEMM_DATATYPE
using GemmDataType = AccDataType;
#endif
ComputeDataType
Done.
@@ -111,8 +111,9 @@ template <typename ALayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeA = CDataType,
ADataType
ok.
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_dynamic_op_f32_tf32_instances = std::tuple<
We probably don't need the dynamic op instances, since they have not been integrated with MIOpen.
ok
@@ -553,6 +565,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
Do we need something like CK_ENABLE_TF32?
The CK API uses separate template parameters, ComputeDataTypeA/B, to distinguish TF32 from FP32 compute, so incorrect usage cannot occur. MIOpen uses MIOPEN_TF32_OVERRIDE (compare NVIDIA_TF32_OVERRIDE) to disable TF32 mode, in which case MIOpen selects a different CK kernel. That should be enough.
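The caller-side selection described in this reply could look roughly like the sketch below. It is only an illustration under stated assumptions: `ComputeMode`, `select_compute_mode`, and `compute_mode_from_env` are hypothetical names, and treating a value of exactly "0" as "TF32 disabled" (with TF32 enabled otherwise) is an assumption modelled on how NVIDIA_TF32_OVERRIDE=0 behaves, not something confirmed for MIOpen in this thread.

```cpp
#include <cstdlib>

// Hypothetical switch between the two kernel families discussed in the thread.
enum class ComputeMode
{
    FP32,
    TF32
};

// Assumption: a value of exactly "0" disables TF32; an unset variable or any
// other value leaves TF32 enabled. The real MIOpen semantics may differ.
constexpr ComputeMode select_compute_mode(const char* override_value)
{
    return (override_value != nullptr && override_value[0] == '0' &&
            override_value[1] == '\0')
               ? ComputeMode::FP32
               : ComputeMode::TF32;
}

// Reads the environment variable named in the thread and maps it to a mode.
inline ComputeMode compute_mode_from_env()
{
    return select_compute_mode(std::getenv("MIOPEN_TF32_OVERRIDE"));
}
```

Keeping the decision outside the kernel library matches the point made above: the library only exposes kernels distinguished by their compute-type template parameters, and the caller picks which family to instantiate.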
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
Please don't extend the gndhwc layout, since it is not widely used.
ok
...uped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp
Looks good to me.
Proposed changes
Demonstrate a TF32 (XF32 in the CDNA3 ISA) kernel in conv3d. Also add many instances for MIOpen.
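As background on what TF32 compute means numerically, the sketch below emulates the precision loss on the host. It assumes the usual TF32 layout (1 sign bit, 8 exponent bits, 10 explicit mantissa bits) and simply clears the low 13 mantissa bits of an IEEE-754 binary32 value. The helper names are hypothetical, this is not a CK function, and real hardware may round rather than truncate.

```cpp
#include <cstdint>
#include <cstring>

// Keep the sign bit, the 8 exponent bits and the top 10 mantissa bits of a
// binary32 bit pattern; clear the low 13 mantissa bits.
constexpr std::uint32_t tf32_truncate_bits(std::uint32_t bits)
{
    return bits & 0xFFFFE000u;
}

// Illustrative host-side helper: return x as if it had been truncated to
// TF32 precision. Note that a NaN whose payload lies only in the low 13 bits
// would turn into infinity here; a production helper should special-case it.
inline float tf32_truncate(float x)
{
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // type-pun without undefined behaviour
    bits = tf32_truncate_bits(bits);
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}
```

Applying such a helper to the inputs of an FP32 reference implementation is one way to see why a TF32 kernel needs a looser comparison tolerance than a pure FP32 kernel.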
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered