Skip to content

torch fake_tensor load hf model failed #39217

@SandyWang85

Description

@SandyWang85

The error type is "TypeError: FakeTensor.new() got an unexpected keyword argument 'fake_device'"

The code:
import torch
from torch import nn
from torch._subclasses import FakeTensorMode
from typing import Optional, Callable, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer

class BaseModelLoader:
def init(
self,
model_provider: Optional[Callable],
model_kwargs: Dict[str, Any] = {},
fake_device_str: str = "cpu"
) -> None:
self.model_provider = model_provider
self.model_kwargs = model_kwargs
self.device_str = fake_device_str

def load(self) -> Optional[nn.Module]:
    raise NotImplementedError("")

class ShadowModelLoader(BaseModelLoader):

def __init__(
    self, 
    model_provider: Optional[Callable],
    model_kwargs: Dict[str, Any] = {},
    fake_device_str: str = "cpu"
) -> None:
    super().__init__(model_provider, model_kwargs, fake_device_str)
    torch.__future__.set_swap_module_params_on_conversion(False)

def load(self) -> Optional[nn.Module]:
   
    with FakeTensorMode() as fake_mode:
        model = self.model_provider(**self.model_kwargs) if self.model_provider else None
        if model is None:
            return None

        self._configure_fake_tensors(model, self.device_str)            
        return model

def _configure_fake_tensors(self, model: nn.Module, device_str: str):
    device = torch.device(device_str)
    for name, param in model.named_parameters():
        if hasattr(param, 'fake_device'):
            param.fake_device = device  

def get_huggingface_model(model_name: str) -> nn.Module:
return AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
low_cpu_mem_usage=True
)

if name == "main":
model_config = {
"model_name": "gpt2"
}

loader = ShadowModelLoader(
    model_provider=get_huggingface_model,
    model_kwargs=model_config,
    fake_device_str="cpu"
)

model = loader.load()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions