6767 MULTI_MODAL_TEXT_GENERATION_MODELS ,
6868 OV_XML_FILE_NAME ,
6969 _get_input_info ,
70- _get_dynamic_shapes_info ,
71- _normalize_dummy_inputs ,
7270 _get_open_clip_submodels_fn_and_export_configs ,
73- get_model_dtype ,
7471 allow_skip_tracing_check ,
7572 clear_class_registry ,
7673 remove_none_from_dummy_inputs ,
@@ -428,7 +425,6 @@ def export_pytorch(
428425 patched_forward = patcher .patched_forward
429426 dummy_input_keys = list (dummy_inputs .keys ())
430427
431- < << << << HEAD
432428 @functools .wraps (patched_forward )
433429 def ts_patched_forward (* args , ** kwargs ):
434430 ordered_example_inputs = [
@@ -446,158 +442,14 @@ def ts_patched_forward(*args, **kwargs):
446442 kwargs [input_name ] = input_dict
447443 outputs = patched_forward (** kwargs )
448444 return tuple ([value if not isinstance (value , list ) else tuple (value ) for value in outputs .values ()])
449- == == == =
450- try :
451- # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
452- # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
453- # To handle it, additional wrapper on patcher forward applied.
454- # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
455- patcher = config .patch_model_for_export (model , model_kwargs = model_kwargs )
456- #patched_forward = patcher.orig_forward
457- import inspect
458- from optimum .exporters .onnx .model_patcher import override_arguments
459-
460- if is_transformers_version (">=" , "4.48" ):
461- from transformers .cache_utils import DynamicCache , EncoderDecoderCache
462-
463- @functools .wraps (patcher .orig_forward )
464- def patched_forward (* args , ** kwargs ):
465- signature = inspect .signature (patcher .orig_forward )
466- args , kwargs = override_arguments (args , kwargs , signature , model_kwargs = patcher .model_kwargs )
467-
468- if is_transformers_version (">=" , "4.48" ):
469- if "past_key_values" in signature .parameters :
470- pkv_index = list (signature .parameters .keys ()).index ("past_key_values" )
471-
472- if (
473- pkv_index < len (args ) # pkv is in args
474- and isinstance (args [pkv_index ], (list , tuple ))
475- and isinstance (args [pkv_index ][0 ], (list , tuple ))
476- ):
477- if len (args [pkv_index ][0 ]) == 2 :
478- args [pkv_index ] = DynamicCache .from_legacy_cache (args [pkv_index ])
479- elif len (args [pkv_index ][0 ]) == 4 :
480- args [pkv_index ] = EncoderDecoderCache .from_legacy_cache (args [pkv_index ])
481- else :
482- raise ValueError (
483- f"past_key_values should have either 2 or 4 elements, but it has { len (args [pkv_index ][0 ])} elements"
484- )
485- elif (
486- "past_key_values" in kwargs # pkv is in kwargs
487- and isinstance (kwargs ["past_key_values" ], (list , tuple ))
488- and isinstance (kwargs ["past_key_values" ][0 ], (list , tuple ))
489- ):
490- if len (kwargs ["past_key_values" ][0 ]) == 2 :
491- kwargs ["past_key_values" ] = DynamicCache .from_legacy_cache (kwargs ["past_key_values" ])
492- elif len (kwargs ["past_key_values" ][0 ]) == 4 :
493- kwargs ["past_key_values" ] = EncoderDecoderCache .from_legacy_cache (
494- kwargs ["past_key_values" ]
495- )
496- else :
497- raise ValueError (
498- f"past_key_values should have either 2 or 4 elements, but it has { len (kwargs ['past_key_values' ][0 ])} elements"
499- )
500-
501- outputs = patcher .orig_forward (* args , ** kwargs )
502-
503- # This code block handles different cases of the filterd_outputs input to align it with the expected
504- # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
505- # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
506- # contains the output names of the model. In the case of Timm classification models, the output
507- # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
508- # match the outputs in order.
509- filterd_outputs = {}
510- if isinstance (outputs , dict ):
511- for name , value in outputs .items ():
512- filterd_outputs [name ] = value
513- elif isinstance (outputs , (list , tuple )):
514- outputs_list = list (config .outputs .keys ())
515- filterd_outputs = dict (zip (outputs_list , outputs ))
516- else :
517- if len (config .outputs ) > 1 :
518- num_outputs = len (config .outputs )
519- outputs_str = ", " .join (config .outputs .keys ())
520- raise ValueError (
521- f"config.outputs should have only one outputs, but it has { num_outputs } keys: { outputs_str } "
522- )
523- else :
524- name = list (config .outputs .keys ())[0 ]
525- filterd_outputs [name ] = outputs
526- name = list (config .outputs .keys ())[0 ]
527- filterd_outputs [name ] = outputs
528-
529- if is_transformers_version (">=" , "4.48" ):
530- if isinstance (filterd_outputs .get ("past_key_values" ), (DynamicCache , EncoderDecoderCache )):
531- filterd_outputs ["past_key_values" ] = outputs ["past_key_values" ].to_legacy_cache ()
532-
533- return filterd_outputs
534- >> >> >> > cfde44f ([POC ] Use torch .export for converting )
535445
536446 patcher .patched_forward = ts_patched_forward
537447
538- < << << << HEAD
539448 ts_decoder_kwargs = {}
540449 model_config = getattr (model , "config" , {})
541450 model_type = getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
542451 if allow_skip_tracing_check (library_name , model_type ):
543452 ts_decoder_kwargs ["trace_kwargs" ] = {"check_trace" : False }
544- == == == =
545- patcher .patched_forward = ts_patched_forward
546-
547- ts_decoder_kwargs = {}
548- model_config = getattr (model , "config" , {})
549- model_type = getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
550- if allow_skip_tracing_check (library_name , model_type ):
551- ts_decoder_kwargs ["trace_kwargs" ] = {"check_trace" : False }
552-
553- with patcher :
554- use_export = True
555- check_dummy_inputs_are_allowed (model , dummy_inputs )
556- input_info = _get_input_info (model , config , dummy_inputs )
557- if use_export :
558- if hasattr (torch .ops , "_prepare_4d_causal_attention_mask_for_sdpa" ):
559- # patch_everywhere breaks torch.ops namespace
560- del torch .ops ._prepare_4d_causal_attention_mask_for_sdpa
561- dynamic_shapes = _get_dynamic_shapes_info (model , config , dummy_inputs )
562- _export_kwargs = {"args" : tuple (), "kwargs" : _normalize_dummy_inputs (dummy_inputs , get_model_dtype (model ))}
563- _export_kwargs ["dynamic_shapes" ] = dynamic_shapes
564-
565- try :
566- from nncf .torch .dynamic_graph .patch_pytorch import disable_patching
567- # nncf patching breaks export
568- with disable_patching ():
569- ep = torch .export .export_for_training (model , ** _export_kwargs )
570- except ImportError :
571- ep = torch .export .export_for_training (model , ** _export_kwargs )
572-
573- ov_model = convert_model (ep )
574- else :
575- if patch_16bit_model :
576- from openvino .frontend .pytorch .patch_model import __make_16bit_traceable
577-
578- __make_16bit_traceable (model )
579-
580- ts_decoder = TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
581- ov_model = convert_model (
582- ts_decoder ,
583- example_input = dummy_inputs ,
584- input = [(item .shape , item .type ) for item in input_info ],
585- )
586-
587- except Exception as ex :
588- logger .warning (f"Export model to OpenVINO directly failed with: \n " , exc_info = ex )
589- raise ex
590- logger .warning ("\n Model will be exported to ONNX" )
591-
592- if stateful :
593- # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
594- # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
595- logger .warning (
596- "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
597- "A stateless model will be exported instead. It may result in sub-optimal inference performance."
598- "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
599- )
600- >> >> >> > cfde44f ([POC ] Use torch .export for converting )
601453
602454 with patcher :
603455 if patch_16bit_model :
0 commit comments