1212// limitations under the License.
1313
1414//! # Multi-turn dialogue
15- //! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
15+ //! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or
16+ //! [GODEL](https://github.com/microsoft/GODEL).
1617//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
1718//! The DialoGPT's page states that
1819//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
5556//! from the 3rd party utilization of the pretrained system.
5657use crate :: common:: error:: RustBertError ;
5758use crate :: gpt2:: GPT2Generator ;
59+ use crate :: t5:: T5Generator ;
5860use crate :: pipelines:: common:: { ModelType , TokenizerOption } ;
5961use crate :: pipelines:: generation_utils:: private_generation_utils:: PrivateLanguageGenerator ;
6062use crate :: pipelines:: generation_utils:: { GenerateConfig , LanguageGenerator } ;
@@ -695,12 +697,14 @@ impl Default for ConversationManager {
695697pub enum ConversationOption {
696698 /// Conversation based on GPT2 model
697699 GPT2 ( GPT2Generator ) ,
700+ T5 ( T5Generator ) ,
698701}
699702
700703impl ConversationOption {
701704 pub fn new ( config : ConversationConfig ) -> Result < Self , RustBertError > {
702705 match config. model_type {
703706 ModelType :: GPT2 => Ok ( ConversationOption :: GPT2 ( GPT2Generator :: new ( config. into ( ) ) ?) ) ,
707+ ModelType :: T5 => Ok ( ConversationOption :: T5 ( T5Generator :: new ( config. into ( ) ) ?) ) ,
704708 _ => Err ( RustBertError :: InvalidConfigurationError (
705709 "GPT2 is currently the only supported model for conversation generation"
706710 . to_string ( ) ,
@@ -717,6 +721,10 @@ impl ConversationOption {
717721 config. into ( ) ,
718722 tokenizer,
719723 ) ?) ) ,
724+ ModelType :: T5 => Ok ( ConversationOption :: T5 ( T5Generator :: new_with_tokenizer (
725+ config. into ( ) ,
726+ tokenizer,
727+ ) ?) ) ,
720728 _ => Err ( RustBertError :: InvalidConfigurationError (
721729 "GPT2 is currently the only supported model for conversation generation"
722730 . to_string ( ) ,
@@ -729,27 +737,33 @@ impl ConversationOption {
729737 Self :: GPT2 ( model_ref) => {
730738 Ok ( * model_ref. get_eos_ids ( ) . as_ref ( ) . unwrap ( ) . first ( ) . unwrap ( ) )
731739 }
740+ Self :: T5 ( model_ref) => {
741+ Ok ( * model_ref. get_eos_ids ( ) . as_ref ( ) . unwrap ( ) . first ( ) . unwrap ( ) )
742+ }
732743 }
733744 }
734745
735746 /// Get a reference to the model tokenizer.
736747 pub fn get_tokenizer ( & self ) -> & TokenizerOption {
737748 match self {
738749 Self :: GPT2 ( model_ref) => model_ref. _get_tokenizer ( ) ,
750+ Self :: T5 ( model_ref) => model_ref. _get_tokenizer ( ) ,
739751 }
740752 }
741753
742754 /// Get a mutable reference to the model tokenizer.
743755 pub fn get_tokenizer_mut ( & mut self ) -> & TokenizerOption {
744756 match self {
745757 Self :: GPT2 ( model_ref) => model_ref. _get_tokenizer_mut ( ) ,
758+ Self :: T5 ( model_ref) => model_ref. _get_tokenizer_mut ( ) ,
746759 }
747760 }
748761
749762 /// Returns the `ModelType` for this ConversationOption
750763 pub fn model_type ( & self ) -> ModelType {
751764 match * self {
752765 Self :: GPT2 ( _) => ModelType :: GPT2 ,
766+ Self :: T5 ( _) => ModelType :: T5 ,
753767 }
754768 }
755769
@@ -765,6 +779,11 @@ impl ConversationOption {
765779 . into_iter ( )
766780 . map ( |output| output. indices )
767781 . collect ( ) ,
782+ Self :: T5 ( ref model) => model
783+ . generate_from_ids_and_past ( input_ids, attention_mask, None )
784+ . into_iter ( )
785+ . map ( |output| output. indices )
786+ . collect ( ) ,
768787 }
769788 }
770789}
0 commit comments