diff --git a/pyproject.toml b/pyproject.toml
index 7a170943..9517a6ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
 dependencies = [
     "accelerate==0.34.2", "sentencepiece", "tokenizers>=0.12.1",
     "torch==2.3.0", "torchvision==0.18.0",
-    "transformers==4.46.0",
+    "transformers==4.48.3",
     "lm_eval==0.3.0", "texttable", "toml", "attributedict", "protobuf",
diff --git a/tinychat/models/llama.py b/tinychat/models/llama.py
index e64a4f49..5303d20a 100644
--- a/tinychat/models/llama.py
+++ b/tinychat/models/llama.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import awq_inference_engine
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+from transformers import LlamaConfig

 # from flash_attn.flash_attn_interface import flash_attn_unpadded_func

@@ -156,9 +157,8 @@ def __init__(self, args):
             .half()
         )  # added to half
         # dummy
-        self.rotary_emb = LlamaRotaryEmbedding(
-            self.head_dim, max_position_embeddings=2048, device="cuda:0"
-        )
+        cfg = LlamaConfig(head_dim=self.head_dim, max_position_embeddings=2048)
+        self.rotary_emb = LlamaRotaryEmbedding(cfg, device="cuda:0")

     def forward(
         self,
diff --git a/tinychat/models/mpt.py b/tinychat/models/mpt.py
index 8e24164b..1a28f749 100644
--- a/tinychat/models/mpt.py
+++ b/tinychat/models/mpt.py
@@ -9,7 +9,6 @@ from torch import nn
 import torch.nn.functional as F
 import awq_inference_engine
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

 # from flash_attn.flash_attn_interface import flash_attn_unpadded_func

diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py
index 9856feb8..9e979a3a 100644
--- a/tinychat/modules/fused_attn.py
+++ b/tinychat/modules/fused_attn.py
@@ -4,7 +4,6 @@ from torch.nn import functional as F
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
-    LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
 )
 from typing import Optional
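For context on the `tinychat/models/llama.py` hunk: in the pinned `transformers==4.48.3`, `LlamaRotaryEmbedding` is constructed from a `LlamaConfig` rather than positional `dim`/`max_position_embeddings` arguments. Below is a minimal standalone sketch of the config-driven constructor, not part of the patch; the concrete `head_dim`, sequence length, and CPU device are illustrative assumptions, and the other config fields fall back to `LlamaConfig` defaults.

```python
# Minimal sketch assuming the transformers==4.48.3 API; values are illustrative.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Mirror the patch: pass head_dim and max_position_embeddings through a config object.
cfg = LlamaConfig(head_dim=128, max_position_embeddings=2048)
rotary_emb = LlamaRotaryEmbedding(cfg, device="cpu")

# forward(x, position_ids) returns per-position cos/sin tables;
# x is only used for its dtype/device.
x = torch.zeros(1, 16, cfg.head_dim)
position_ids = torch.arange(16).unsqueeze(0)  # shape (batch, seq_len)
cos, sin = rotary_emb(x, position_ids)
print(cos.shape, sin.shape)  # (1, 16, head_dim) each
```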