Commit b4115a4
[Ernie 4.5] Add ernie text models (#39228)
* init
* copied from remote
* add proper structure and llama like structure
* fixup
* revert to state that works
* get closer to llama
* slow and steady
* some removal
* masks work
* it is indeed the rope implementation, how dafuq does it mesh with the cache now hmm
* nice
* getting closer
* closer to transformers style
* let's simplify this, batching works now
* simplified
* working version with modular
* it is indeed the rotation per weights, make it complete llama style
* cleanup conversion, next to look at -> tokenizer
* remove llama artefacts
* fix modeling tests (common ones)
* style
* integration test + first look into tokenization (will need more work, focussing on modeling other models first)
* style
* working moe version, based on remote
* lets keep it simple and go step by step - transformers annotations for modular and transformers style rope (complex view)
* more cleanup
* refactor namings and remove addition forXXX classes
* our moe won't cut it it seems, correction bias seems to be missing in remote code version
* tokenization change (remote)
* our moe version works when adding normalization :D
* cleanup moe
* nits
* cleanup modeling -> let's get to modular next
* style
* modular v1
* minor things + attempt at conversion (which doesn't work)
* no conversion follow glm, fixup modular and other nits
* modular cleanup
* fixes
* tests, tests, tests + some moe dtype forcing
* simplify modular, fix fatal fa2 bug, remaining tests
* fix import issue?
* some initial docs, fix bnb faulty behavior --> needs to fix some tests because of gate needing to be float
* fix sdpa test, load on init dtype only
* fixup post merge
* style
* fix doc links
* tokenization cleanup beginnings
* simplify tokenizer by a lot as its basically llama
* tokenizer is full llama with different defaults + extra special tokens
* sync og special tokens of ernie
* fix decoding with numbers (also in remote done what a timing), begin of tok tests
* align with remote and preserve special tokens, adjust tests to ernie legacy behavior, warning for questionable behavior (also in llama)
* nits
* docs
* my daily post merge it is
* check
* tokenization update with explanations and conversion script
* review on modular (til), revert some tokenizer things i did prior, remove mtp comment (low prio)
* post merge fixes
* fixup tokenization, llama fast is the way to go
* more fixups
* check
* import fixes
* correction bias following the paddle code
* fix
* fix TP plan, fix correction bias sharding during forward
* style
* whoops
* fix tied weights
* docs and last nit
* license
* flasky tests
* move repo id, update when merged on the hub
1 parent 69b1582 commit b4115a4

23 files changed: +2956 −2 lines

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -441,6 +441,10 @@
       title: Encoder Decoder Models
     - local: model_doc/ernie
       title: ERNIE
+    - local: model_doc/ernie4_5
+      title: Ernie4_5
+    - local: model_doc/ernie4_5_moe
+      title: Ernie4_5_MoE
     - local: model_doc/ernie_m
       title: ErnieM
     - local: model_doc/esm

docs/source/en/model_doc/ernie4_5.md

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
    </div>
</div>

# Ernie 4.5

## Overview

The Ernie 4.5 model was released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
The family spans multiple architectures and model sizes; this page covers the dense base text model (no mixture of experts)
with 0.3B parameters in total, which uses the standard [Llama](./llama.md) architecture at its core.

The mixture-of-experts variants of the family are documented at [Ernie 4.5 MoE](./ernie4_5_moe.md).

<div class="flex justify-center">
    <img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>


## Usage Tips

### Generate text

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-0.3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
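
For quick experiments, the same checkpoint can also be driven through the high-level `pipeline` API. This is a minimal sketch, not part of the original doc; it assumes only the standard `text-generation` pipeline:

```python
import torch
from transformers import pipeline

# sketch: the generic text-generation pipeline wraps the same
# tokenizer + model loading shown above in a single call
generator = pipeline(
    "text-generation",
    model="baidu/ERNIE-4.5-0.3B-PT",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(generator("Hey, are you conscious?", max_new_tokens=32)[0]["generated_text"])
```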
This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).


## Ernie4_5Config

[[autodoc]] Ernie4_5Config

## Ernie4_5Model

[[autodoc]] Ernie4_5Model
    - forward

## Ernie4_5ForCausalLM

[[autodoc]] Ernie4_5ForCausalLM
    - forward
docs/source/en/model_doc/ernie4_5_moe.md

Lines changed: 183 additions & 0 deletions

@@ -0,0 +1,183 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
    </div>
</div>

# Ernie 4.5 MoE

## Overview

The Ernie 4.5 MoE models were released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
The family spans multiple architectures and model sizes; this page covers the mixture-of-experts (MoE) base text models,
one with 21B total (3B active) parameters and one with 300B total (47B active) parameters. They use the standard
[Llama](./llama.md) architecture at their core, combined with a specialized MoE block based on [Mixtral](./mixtral.md)
that adds shared experts.

The dense variant of the family is documented at [Ernie 4.5](./ernie4_5.md).

<div class="flex justify-center">
    <img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>
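
For intuition, here is a minimal, self-contained sketch of the routing structure described above: top-k expert selection with renormalized gate weights (computed in float32, since the commit history notes the gate must stay in float) plus always-active shared experts. This is illustrative only, not the actual `Ernie4_5_MoE` implementation; it omits details such as the routing correction bias mentioned in the commit history, and all names and sizes are made up:

```python
import torch
import torch.nn as nn


class ToySharedExpertMoE(nn.Module):
    """Sketch of a top-k routed MoE with an always-active shared expert."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

        def make_ffn():
            return nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size),
            )

        self.experts = nn.ModuleList(make_ffn() for _ in range(num_experts))
        self.shared_expert = make_ffn()

    def forward(self, x):  # x: (num_tokens, hidden_size)
        # route in float32 for numerical stability, keep the top-k experts
        # per token, and renormalize their weights so they sum to 1
        scores = self.gate(x.float()).softmax(dim=-1)
        weights, indices = scores.topk(self.top_k, dim=-1)
        weights = (weights / weights.sum(dim=-1, keepdim=True)).to(x.dtype)

        out = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            # find which tokens routed to this expert and at which top-k slot
            token_idx, k_idx = (indices == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() > 0:
                out[token_idx] += weights[token_idx, k_idx, None] * expert(x[token_idx])
        # shared expert always contributes, regardless of routing
        return out + self.shared_expert(x)


moe = ToySharedExpertMoE()
print(moe(torch.randn(4, 64)).shape)  # torch.Size([4, 64])
```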

## Usage Tips

### Generate text

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

### Distributed Generation with Tensor Parallelism

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model; `tp_plan="auto"` shards the weights
# across processes, so launch this script with e.g.
# `torchrun --nproc-per-node=<num_gpus> script.py`
# (do not combine it with `device_map="auto"` -- the two placement
# strategies are mutually exclusive)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

### Quantization with Bitsandbytes

```python
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
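
If 4-bit quantization is too aggressive for your quality requirements, the same loading flow should also work with 8-bit weights; this one-line variation on the config above is a sketch, not from the original doc:

```python
from transformers import BitsAndBytesConfig

# drop-in replacement for the 4-bit config in the example above
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
```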

This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).


## Ernie4_5_MoEConfig

[[autodoc]] Ernie4_5_MoEConfig

## Ernie4_5_MoEModel

[[autodoc]] Ernie4_5_MoEModel
    - forward

## Ernie4_5_MoEForCausalLM

[[autodoc]] Ernie4_5_MoEForCausalLM
    - forward
    - generate

src/transformers/modeling_utils.py

Lines changed: 11 additions & 0 deletions

@@ -3129,6 +3129,17 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         else:
             output_embeddings.weight = input_embeddings.weight
 
+        # Passing hooks over to the embeddings if needed
+        # (currently limited to tensor parallel hooks and flags only)
+        if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None):
+            output_embeddings._is_hooked = input_embeddings._is_hooked
+            output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan
+            output_embeddings._forward_hooks = input_embeddings._forward_hooks
+            output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks
+            output_embeddings.__repr__ = (
+                lambda: f"{output_embeddings.__repr__()}\nTP Plan: {output_embeddings._hf_tp_plan}"
+            )
+
         if getattr(output_embeddings, "bias", None) is not None:
             output_embeddings.bias.data = nn.functional.pad(
                 output_embeddings.bias.data,
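
To see why this copying is needed at all, here is a minimal sketch of weight tying (toy sizes, not the actual Transformers code): tying shares the `Parameter` object between the two modules, but module-level state such as forward hooks is not shared, so tensor parallel hooks registered on the input embedding would otherwise stay invisible on the tied output layer:

```python
import torch.nn as nn

vocab_size, hidden_size = 32, 8
input_embeddings = nn.Embedding(vocab_size, hidden_size)
output_embeddings = nn.Linear(hidden_size, vocab_size, bias=False)

# tie: both modules now hold the very same Parameter object...
output_embeddings.weight = input_embeddings.weight
assert output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()

# ...but hooks are per-module state and are NOT shared automatically,
# which is why the change above mirrors them onto the tied module
input_embeddings.register_forward_hook(lambda module, args, out: out)
assert len(output_embeddings._forward_hooks) == 0  # hook lives only on the embedding
```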

src/transformers/models/auto/configuration_auto.py

Lines changed: 4 additions & 0 deletions

@@ -128,6 +128,8 @@
         ("encoder-decoder", "EncoderDecoderConfig"),
         ("eomt", "EomtConfig"),
         ("ernie", "ErnieConfig"),
+        ("ernie4_5", "Ernie4_5Config"),
+        ("ernie4_5_moe", "Ernie4_5_MoEConfig"),
         ("ernie_m", "ErnieMConfig"),
         ("esm", "EsmConfig"),
         ("falcon", "FalconConfig"),
@@ -520,6 +522,8 @@
         ("encoder-decoder", "Encoder decoder"),
         ("eomt", "EoMT"),
         ("ernie", "ERNIE"),
+        ("ernie4_5", "Ernie4_5"),
+        ("ernie4_5_moe", "Ernie4_5_MoE"),
         ("ernie_m", "ErnieM"),
         ("esm", "ESM"),
         ("falcon", "Falcon"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 4 additions & 0 deletions

@@ -119,6 +119,8 @@
         ("emu3", "Emu3Model"),
         ("encodec", "EncodecModel"),
         ("ernie", "ErnieModel"),
+        ("ernie4_5", "Ernie4_5Model"),
+        ("ernie4_5_moe", "Ernie4_5_MoEModel"),
         ("ernie_m", "ErnieMModel"),
         ("esm", "EsmModel"),
         ("falcon", "FalconModel"),
@@ -594,6 +596,8 @@
         ("electra", "ElectraForCausalLM"),
         ("emu3", "Emu3ForCausalLM"),
         ("ernie", "ErnieForCausalLM"),
+        ("ernie4_5", "Ernie4_5ForCausalLM"),
+        ("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"),
         ("falcon", "FalconForCausalLM"),
         ("falcon_h1", "FalconH1ForCausalLM"),
         ("falcon_mamba", "FalconMambaForCausalLM"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 2 additions & 0 deletions

@@ -212,6 +212,8 @@
         ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
         ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
         ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+        ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+        ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
         ("esm", ("EsmTokenizer", None)),
         ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
src/transformers/models/ernie4_5/__init__.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_ernie4_5 import *
    from .modeling_ernie4_5 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
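
The `_LazyModule` indirection keeps `import transformers` cheap: the configuration and modeling submodules are only materialized on first attribute access. A quick illustration (a sketch, assuming the class is re-exported at the top level as for other models):

```python
import transformers

# nothing heavy has been imported for ernie4_5 yet; the first attribute
# access below is what triggers the lazy module load
config_cls = transformers.Ernie4_5Config
print(config_cls.__module__)  # transformers.models.ernie4_5.configuration_ernie4_5
```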
