-
Notifications
You must be signed in to change notification settings - Fork 30.3k
Description
The error type is "TypeError: FakeTensor.__new__() got an unexpected keyword argument 'fake_device'"
The code:
import torch
from torch import nn
from torch._subclasses import FakeTensorMode
from typing import Optional, Callable, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
class BaseModelLoader:
    """Abstract base for model loaders.

    Holds a callable that produces an ``nn.Module``, the keyword arguments
    to pass to it, and a device string used when configuring fake tensors.
    Subclasses implement :meth:`load`.
    """

    def __init__(
        self,
        model_provider: Optional[Callable],
        # NOTE: default is None (not {}) to avoid the shared-mutable-default
        # pitfall; a fresh dict is created per instance below.
        model_kwargs: Optional[Dict[str, Any]] = None,
        fake_device_str: str = "cpu",
    ) -> None:
        self.model_provider = model_provider
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        self.device_str = fake_device_str

    def load(self) -> Optional[nn.Module]:
        """Build and return the model; must be overridden by subclasses."""
        raise NotImplementedError("")
class ShadowModelLoader(BaseModelLoader):
def __init__(
self,
model_provider: Optional[Callable],
model_kwargs: Dict[str, Any] = {},
fake_device_str: str = "cpu"
) -> None:
super().__init__(model_provider, model_kwargs, fake_device_str)
torch.__future__.set_swap_module_params_on_conversion(False)
def load(self) -> Optional[nn.Module]:
with FakeTensorMode() as fake_mode:
model = self.model_provider(**self.model_kwargs) if self.model_provider else None
if model is None:
return None
self._configure_fake_tensors(model, self.device_str)
return model
def _configure_fake_tensors(self, model: nn.Module, device_str: str):
device = torch.device(device_str)
for name, param in model.named_parameters():
if hasattr(param, 'fake_device'):
param.fake_device = device
def get_huggingface_model(model_name: str) -> nn.Module:
    """Download/instantiate a pretrained causal-LM by hub name."""
    load_options = {
        "torch_dtype": torch.float32,
        "low_cpu_mem_usage": True,
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_options)
# Script entry point: the mangled `if name == "main":` is restored to the
# standard guard — `name` is otherwise undefined and raises NameError.
if __name__ == "__main__":
    model_config = {
        "model_name": "gpt2"
    }
    loader = ShadowModelLoader(
        model_provider=get_huggingface_model,
        model_kwargs=model_config,
        fake_device_str="cpu",
    )
    model = loader.load()