
Commit 7e00a22

Added GODEL support
1 parent 9fd7983 commit 7e00a22

2 files changed: 35 additions & 1 deletion

src/pipelines/conversation.rs

Lines changed: 20 additions & 1 deletion
@@ -12,7 +12,8 @@
 // limitations under the License.

 //! # Multi-turn dialogue
-//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
+//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or
+//! [GODEL](https://github.com/microsoft/GODEL).
 //! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
 //! The DialoGPT's page states that
 //! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
@@ -55,6 +56,7 @@
 //! from the 3rd party utilization of the pretrained system.
 use crate::common::error::RustBertError;
 use crate::gpt2::GPT2Generator;
+use crate::t5::T5Generator;
 use crate::pipelines::common::{ModelType, TokenizerOption};
 use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
 use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
@@ -695,12 +697,14 @@ impl Default for ConversationManager {
 pub enum ConversationOption {
     /// Conversation based on GPT2 model
     GPT2(GPT2Generator),
+    T5(T5Generator),
 }

 impl ConversationOption {
     pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
         match config.model_type {
             ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
+            ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)),
             _ => Err(RustBertError::InvalidConfigurationError(
                 "GPT2 is currently the only supported model for conversation generation"
                     .to_string(),
@@ -717,6 +721,10 @@ impl ConversationOption {
                 config.into(),
                 tokenizer,
             )?)),
+            ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new_with_tokenizer(
+                config.into(),
+                tokenizer,
+            )?)),
             _ => Err(RustBertError::InvalidConfigurationError(
                 "GPT2 is currently the only supported model for conversation generation"
                     .to_string(),
@@ -729,27 +737,33 @@ impl ConversationOption {
             Self::GPT2(model_ref) => {
                 Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
             }
+            Self::T5(model_ref) => {
+                Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
+            }
         }
     }

     /// Get a reference to the model tokenizer.
     pub fn get_tokenizer(&self) -> &TokenizerOption {
         match self {
             Self::GPT2(model_ref) => model_ref._get_tokenizer(),
+            Self::T5(model_ref) => model_ref._get_tokenizer(),
         }
     }

     /// Get a mutable reference to the model tokenizer.
     pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption {
         match self {
             Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
+            Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
         }
     }

     /// Returns the `ModelType` for this ConversationOption
     pub fn model_type(&self) -> ModelType {
         match *self {
             Self::GPT2(_) => ModelType::GPT2,
+            Self::T5(_) => ModelType::T5,
         }
     }
@@ -765,6 +779,11 @@ impl ConversationOption {
                 .into_iter()
                 .map(|output| output.indices)
                 .collect(),
+            Self::T5(ref model) => model
+                .generate_from_ids_and_past(input_ids, attention_mask, None)
+                .into_iter()
+                .map(|output| output.indices)
+                .collect(),
         }
     }
 }
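
With the T5 arm wired into ConversationOption, the conversation pipeline can be pointed at a T5-based checkpoint such as GODEL. Below is a minimal, hypothetical usage sketch: it assumes the ConversationConfig fields and the RemoteResource/ResourceProvider wrappers of rust-bert around the time of this commit, plus anyhow for error handling in main. None of this scaffolding is part of the diff, and exact field names may differ between rust-bert versions.

use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::conversation::{
    ConversationConfig, ConversationManager, ConversationModel,
};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};

fn main() -> anyhow::Result<()> {
    // Assumed configuration: select the new T5 branch of ConversationOption and
    // point it at the GODEL resources registered in this commit.
    let config = ConversationConfig {
        model_type: ModelType::T5,
        model_resource: Box::new(RemoteResource::from_pretrained(
            T5ModelResources::GODEL_V1_1_BASE,
        )),
        config_resource: Box::new(RemoteResource::from_pretrained(
            T5ConfigResources::GODEL_V1_1_BASE,
        )),
        vocab_resource: Box::new(RemoteResource::from_pretrained(
            T5VocabResources::GODEL_V1_1_BASE,
        )),
        ..Default::default()
    };

    let conversation_model = ConversationModel::new(config)?;
    let mut conversation_manager = ConversationManager::new();

    // Register a single-turn conversation and generate a response with GODEL.
    let conversation_id =
        conversation_manager.create("Going to the movies tonight - any suggestions?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output.get(&conversation_id));
    Ok(())
}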

src/t5/t5_model.rs

Lines changed: 15 additions & 0 deletions
@@ -61,6 +61,11 @@ impl T5ModelResources {
         "sentence-t5-base/model",
         "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
     );
+    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
+    pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
+        "godel-v1-1-base/model",
+        "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot",
+    );
 }

 impl T5ConfigResources {
@@ -79,6 +84,11 @@ impl T5ConfigResources {
         "sentence-t5-base/config",
         "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
     );
+    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
+    pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
+        "godel-v1-1-base/config",
+        "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json",
+    );
 }

 impl T5VocabResources {
@@ -97,6 +107,11 @@ impl T5VocabResources {
         "sentence-t5-base/spiece",
         "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
     );
+    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq>. Modified with conversion to C-array format.
+    pub const GODEL_V1_1_BASE: (&'static str, &'static str) = (
+        "godel-v1-1-base/spiece",
+        "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/spiece.model",
+    );
 }

 const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
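
The three new GODEL_V1_1_BASE constants follow the existing resource convention: a (cache name, URL) tuple that a resource provider resolves to a locally cached file. The short sketch below illustrates that resolution, assuming the RemoteResource::from_pretrained constructor and the ResourceProvider::get_local_path method; it is illustrative only and not part of the diff.

use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};

fn main() -> anyhow::Result<()> {
    // Each constant is a ("cache name", "URL") pair; RemoteResource downloads
    // the file on first use and caches it under the given name.
    let weights = RemoteResource::from_pretrained(T5ModelResources::GODEL_V1_1_BASE);
    let config = RemoteResource::from_pretrained(T5ConfigResources::GODEL_V1_1_BASE);
    let vocab = RemoteResource::from_pretrained(T5VocabResources::GODEL_V1_1_BASE);

    println!("weights: {:?}", weights.get_local_path()?);
    println!("config:  {:?}", config.get_local_path()?);
    println!("spiece:  {:?}", vocab.get_local_path()?);
    Ok(())
}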
