Intel® Extension for PyTorch* v1.13.0+cpu Release Notes
We are pleased to announce the release of Intel® Extension for PyTorch* 1.13.0+cpu, which accompanies PyTorch 1.13. This release highlights several usability features that help users get good performance and accuracy on CPU with less effort, along with a couple of performance features. See the feature summary below.
- Usability Features
  - Automatic channels last format conversion: Channels last conversion is now applied automatically to PyTorch modules by `ipex.optimize` by default. Users no longer have to explicitly convert inputs and weights for CV models.
  - Code-free optimization (experimental): `ipex.optimize` is applied automatically to PyTorch modules, with no code changes, when the PyTorch program is started with the IPEX launcher using the new `--auto_ipex` option.
  - Graph capture mode of `ipex.optimize` (experimental): A new boolean flag `graph_mode` (default off) was added to `ipex.optimize`. When turned on, it converts the eager-mode PyTorch module into graph(s) to get the best of graph optimization.
  - INT8 quantization accuracy autotune (experimental): A new quantization API, `ipex.quantization.autotune`, was added to refine the default IPEX quantization recipe via autotuning algorithms for better accuracy.
  - Hypertune (experimental): A new tool built on top of the IPEX launcher that automatically identifies configurations with the best throughput via hyper-parameter tuning.
  - `ipexrun`: The counterpart of `torchrun`, added as a shortcut for invoking the IPEX launcher.
- Performance Features
  - Packed MKL SGEMM is now the default kernel option for FP32 Linear, bringing up to 20% geomean speedup for real-time NLP tasks.
  - The DL compiler is now turned on by default with oneDNN fusion and gives an additional performance boost for INT8 models.
Highlights
- Automatic channels last format conversion: Channels last conversion is now applied automatically to PyTorch modules by `ipex.optimize` by default, for both training and inference scenarios. Users no longer have to explicitly convert inputs and weights for CV models.
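As background on what the conversion changes, here is a minimal pure-Python sketch of how element strides differ between the default NCHW layout and channels last (NHWC) storage for a 4-D tensor. The helper names are illustrative only and are not part of IPEX or PyTorch.

```python
def contiguous_strides(shape):
    """Element strides for the default contiguous (NCHW) layout."""
    n, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape):
    """Element strides when the same NCHW-shaped tensor is stored as NHWC."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


shape = (2, 3, 4, 5)  # (N, C, H, W), illustrative values
print(contiguous_strides(shape))     # (60, 20, 5, 1)
print(channels_last_strides(shape))  # (60, 1, 15, 3)
```

In channels last storage the channel dimension has stride 1, so all channels of one pixel sit next to each other in memory while the logical shape stays (N, C, H, W).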
```python
import torch
import intel_extension_for_pytorch as ipex

# Explicit format conversion is no longer needed:
# m = m.to(memory_format=torch.channels_last)
# x = x.to(memory_format=torch.channels_last)

# for inference
m = ipex.optimize(m)
m(x)

# for training (pass the optimizer by keyword; the second positional argument is dtype)
m, optimizer = ipex.optimize(m, optimizer=optimizer)
m(x)
```
- Code-free optimization (experimental): `ipex.optimize` is applied automatically to PyTorch modules, with no code changes, when the PyTorch program is started with the IPEX launcher using the new `--auto_ipex` option.
Example: QA case in HuggingFace
```bash
# original command
ipexrun --use_default_allocator --ninstance 2 --ncore_per_instance 28 run_qa.py \
  --model_name_or_path bert-base-uncased --dataset_name squad --do_eval \
  --per_device_train_batch_size 12 --learning_rate 3e-5 --num_train_epochs 2 \
  --max_seq_length 384 --doc_stride 128 --output_dir /tmp/debug_squad/

# automatically apply bfloat16 optimization (--auto_ipex --dtype bfloat16)
ipexrun --use_default_allocator --ninstance 2 --ncore_per_instance 28 --auto_ipex --dtype bfloat16 run_qa.py \
  --model_name_or_path bert-base-uncased --dataset_name squad --do_eval \
  --per_device_train_batch_size 12 --learning_rate 3e-5 --num_train_epochs 2 \
  --max_seq_length 384 --doc_stride 128 --output_dir /tmp/debug_squad/
```
- Graph capture mode of `ipex.optimize` (experimental): A new boolean flag `graph_mode` (default off) was added to `ipex.optimize`. When turned on, it converts the eager-mode PyTorch module into graph(s) to get the best of graph optimization. Under the hood, it combines TorchScript tracing and TorchDynamo to capture as large a graph scope as possible. Currently, it supports only FP32 and BF16 inference; INT8 inference and training support are under way.
```python
import torch
import intel_extension_for_pytorch as ipex

model = ...
model.load_state_dict(torch.load(PATH))
model.eval()
optimized_model = ipex.optimize(model, graph_mode=True)
```
- INT8 quantization accuracy autotune (experimental): A new quantization API, `ipex.quantization.autotune`, was added to refine the default IPEX quantization recipe via autotuning algorithms for better accuracy. It is an optional API, invoked after `prepare` and before `convert`, for cases where the accuracy of the default IPEX quantization recipe does not meet the requirement. The current implementation is powered by Intel Neural Compressor (INC).
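The autotune example in this section passes `accuracy_criterion={'relative': 0.01}`. A minimal sketch of one way to read that setting, assuming it bounds the relative accuracy loss against a baseline (the release notes do not spell out the exact semantics, and the function name here is hypothetical):

```python
def meets_relative_criterion(baseline_acc, tuned_acc, tolerance=0.01):
    """Return True if the relative accuracy drop stays within `tolerance`.

    Assumption: 'relative' means (baseline - tuned) / baseline <= tolerance.
    """
    if baseline_acc <= 0:
        raise ValueError("baseline accuracy must be positive")
    relative_drop = (baseline_acc - tuned_acc) / baseline_acc
    return relative_drop <= tolerance


# Illustrative accuracy values, not measured results.
print(meets_relative_criterion(0.760, 0.755))  # True: drop is about 0.66%
print(meets_relative_criterion(0.760, 0.740))  # False: drop is about 2.6%
```

Under this reading, the `eval_func` passed to `autotune` supplies the accuracy value that is compared against the baseline.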
```python
import torch
import intel_extension_for_pytorch as ipex

# Calibrate the model
qconfig = ipex.quantization.default_static_qconfig
calibrated_model = ipex.quantization.prepare(model_to_be_calibrated, qconfig, example_inputs=example_inputs)
for data in calibration_data_set:
    calibrated_model(data)

# Autotune the model
calib_dataloader = torch.utils.data.DataLoader(...)

def eval_func(model):
    # Return accuracy value
    ...
    return accuracy

tuned_model = ipex.quantization.autotune(
    calibrated_model, calib_dataloader, eval_func,
    sampling_sizes=[100], accuracy_criterion={'relative': 0.01}, tuning_time=0
)

# Convert the model to jit model
quantized_model = ipex.quantization.convert(tuned_model)
with torch.no_grad():
    traced_model = torch.jit.trace(quantized_model, example_inputs)
    traced_model = torch.jit.freeze(traced_model)

# Do inference
y = traced_model(x)
```
- Hypertune (experimental): A new tool built on top of the IPEX launcher that automatically identifies configurations with the best throughput via hyper-parameter tuning.
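To illustrate the general idea of tuning launcher settings for throughput, here is a small grid-search sketch. It is not Hypertune's implementation: the search space, the parameter names (borrowed from the launcher options used earlier in these notes), and the toy throughput function are all assumptions made for illustration.

```python
from itertools import product


def grid_search(search_space, measure_throughput):
    """Evaluate every combination in `search_space`; return the best config and score."""
    names = list(search_space)
    best_config, best_score = None, float("-inf")
    for values in product(*(search_space[name] for name in names)):
        config = dict(zip(names, values))
        score = measure_throughput(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


# Hypothetical search space over two launcher settings.
search_space = {"ninstance": [1, 2, 4], "ncore_per_instance": [7, 14, 28]}


def fake_throughput(config):
    """Toy stand-in for running the workload on an assumed 56-core machine."""
    cores = config["ninstance"] * config["ncore_per_instance"]
    if cores > 56:
        return 0.0  # oversubscribed configurations are treated as invalid
    # Pretend larger instances scale slightly worse per core.
    return cores * (1.0 - 0.005 * config["ncore_per_instance"])


best_config, best_score = grid_search(search_space, fake_throughput)
print(best_config)  # {'ninstance': 4, 'ncore_per_instance': 14}
```

A real tuner would replace `fake_throughput` with an actual run of the user script and a parsed throughput metric.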
```bash
python -m intel_extension_for_pytorch.cpu.launch.hypertune --conf_file <your_conf_file> <your_python_script> [args]
```
Known Issues
Please check the Known Issues webpage.