[CK_BUILDER] Generalize convolution factory to build arbitrary device operations. #3116
Proposed changes
The goal of this PR is to generalize the current convolution factory in CK Builder so that it can build instances of any relevant convolution device operation. The main changes are:
- `FwdGroupConvDeviceOperation`, `BwdDataGroupConvDeviceOperation`, and `BwdWeightGroupConvDeviceOperation`, which list the device operations for which the builder should be able to build instances.
- `GroupConvDeviceOp`, which can represent a single value of the fwd, bwd weight, or bwd data device operations. This would be more naturally represented by a `std::variant` object, but we cannot use `std::variant` in NTTPs because it is not a structural type.
- `device_operation` in the `ConvSignatureDescriptor` concept, which expects a `GroupConvDeviceOp` value.
- A `ConvFactory` specialization for each device operation. When we add support for a new device operation, we'll just create a new `ConvFactory` specialization with appropriate predicates.
- The layouts (`GroupConvLayout1D`, `GroupConvLayout2D`, `GroupConvLayout3D`) now use the union-based handling, i.e., there's now a `GroupConvLayout` union struct that can hold a single value of the 1D, 2D, or 3D layouts. This simplifies the handling of the different layouts as we get rid of the templatized convolution signature.

These code changes allow developers to work more easily in parallel when adding new device operations.