@@ -40,10 +40,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use MXFP8 linear layers."
             )
         torchao_version = version("torchao")
-        mxfp8_min_version = "0.11.0"
-        if torchao_version < mxfp8_min_version:
+
+        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
+        is_nightly_build = torchao_version.startswith("0.13.0")
+        if not is_nightly_build:
             raise ImportError(
-                f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again"
+                f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
             )
 
         # Can be removed if we enable the emulated versions
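The new check gates on a version-string prefix instead of an ordered comparison. Since the last tagged torchao release is 0.12.0, a nightly wheel reports a version beginning with `0.13.0` (e.g. `0.13.0+git...`), while release wheels do not; this also sidesteps the removed code's lexicographic string comparison, which misorders versions such as `"0.9.0" < "0.11.0"`. A minimal, self-contained sketch of the gate, assuming only that torchao is installed:

```python
# Sketch of the nightly-build gate; the "0.13.0" prefix is tied to the
# current release cycle and must be bumped when torchao tags a new release.
from importlib.metadata import version

torchao_version = version("torchao")

# Release wheels report e.g. "0.12.0"; nightly wheels report e.g.
# "0.13.0+gitabc1234", so a prefix check distinguishes the two.
if not torchao_version.startswith("0.13.0"):
    raise ImportError(
        f"torchao version {torchao_version} is too old, "
        "please install torchao nightly build and try again"
    )
```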
@@ -56,12 +58,17 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.filter_fqns = mx_job_config.filter_fqns
 
         # Configure MXFP8
-        from torchao.prototype.mx_formats.config import MXLinearConfig
+        from torchao.prototype.mx_formats.config import (
+            MXFP8Dim1CastKernelChoice,
+            MXLinearConfig,
+        )
 
         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
-        config.use_fp8_dim1_cast_triton_kernel = (
-            mx_job_config.use_fp8_dim1_cast_triton_kernel
-        )
+
+        # String to enum
+        config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
+            mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
+        ]
         self.config = config
 
         logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")