6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -157,6 +157,12 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"If False, auto-detect and use column+row (all_reduce) sharding when possible.",
)

use_sharding_from_factory: bool = Field(
default=False,
description="If True, use sharding from the model config (if present). "
"If False, run heuristics to detect sharding.",
)

compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(
default="torch-compile",
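For context, a minimal sketch of how the new flag would be toggled. The field name and semantics come from this diff; the direct construction (and any other required fields, omitted here) are assumptions:

# Sketch only, not part of this PR: AutoDeployConfig is a pydantic
# BaseSettings, so the flag can be set at construction time. Other
# required fields are omitted and may be needed in practice.
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig

ad_config = AutoDeployConfig(use_sharding_from_factory=True)

if ad_config.use_sharding_from_factory:
    ...  # use the sharding (tp_plan) provided by the model factory
else:
    ...  # fall back to heuristic sharding detection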
21 changes: 21 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -2,6 +2,7 @@

import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type

import torch
@@ -12,6 +13,13 @@
from ..utils.logger import ad_logger


class ShardingConfigSource(Enum):
"""Enum for factory source."""

HUGGINGFACE = "huggingface"
UNKNOWN = "unknown"


class ModelFactory(ABC):
"""An interface to return and correctly initialize a model from a desired source.

@@ -38,6 +46,7 @@ def __init__(
self.max_seq_len = max_seq_len
self._prefetched_model_path: Optional[str] = None
self._prefetched_tokenizer_path: Optional[str] = None
self._sharding_config: Dict[str, Any] = {}

@property
def model(self) -> Optional[str]:
@@ -96,6 +105,10 @@ def get_quant_config(self) -> Dict:
"""Returns the quantization config for this model or None if not quantized."""
return {}

def get_sharding_config(self) -> Dict:
"""Returns the sharding config for this model."""
return self._sharding_config

def get_cache_config(self) -> CacheConfig:
"""Return the cache configuration for the model.

@@ -104,6 +117,14 @@ def get_cache_config(self) -> CacheConfig:
"""
return CacheConfig()

def get_sharding_config_source(self) -> ShardingConfigSource:
"""Return the source of the model factory.

Returns:
The source identifier for this model factory.
"""
return ShardingConfigSource.UNKNOWN

def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer for the model.

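A hypothetical consumer of the two new accessors, to show the intended interplay: a sharding transform could trust a factory-provided plan only when its provenance is known, and otherwise fall back to heuristics. The resolve_tp_plan helper below is illustrative, not part of this diff:

# Illustrative only: uses just the accessors added above,
# get_sharding_config() and get_sharding_config_source().
from typing import Dict

from tensorrt_llm._torch.auto_deploy.models.factory import (
    ModelFactory,
    ShardingConfigSource,
)


def resolve_tp_plan(factory: ModelFactory) -> Dict:
    """Return the factory-provided tp_plan, or {} to request heuristic sharding."""
    if factory.get_sharding_config_source() is ShardingConfigSource.UNKNOWN:
        return {}  # unknown provenance: do not trust the config
    return factory.get_sharding_config().get("tp_plan") or {}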
36 changes: 35 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -30,7 +30,7 @@
from ..custom_ops.attention_interface import CacheConfig
from ..utils._config import deep_merge_dicts
from ..utils.logger import ad_logger
from .factory import ModelFactory, ModelFactoryRegistry
from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource


@contextmanager
@@ -174,12 +174,25 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
if hasattr(model, "post_init"):
model.post_init()

# If present, initialize the sharding config; head_dim is needed for colwise sharding.
self._set_sharding_config(model.config)

# patch forward method
model.forward = types.MethodType(self._simple_forward, model)

model.eval()
return model

def _set_sharding_config(self, model_config: PretrainedConfig):
"""Set the sharding config for the model."""
self._sharding_config["head_dim"] = 1
if hasattr(model_config, "base_model_tp_plan"):
self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
if hasattr(model_config, "head_dim"):
self._sharding_config["head_dim"] = model_config.head_dim
if hasattr(model_config, "num_hidden_layers"):
self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers

def get_quant_config(self) -> Dict:
return self._quant_config or {}

Expand All @@ -196,6 +209,14 @@ def get_cache_config(self):
kv_cache_dtype = None
return CacheConfig(dtype=kv_cache_dtype)

def get_sharding_config_source(self) -> ShardingConfigSource:
"""Return the source of the model factory.

Returns:
The source identifier for this model factory.
"""
return ShardingConfigSource.HUGGINGFACE

def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer—either a custom name or the model's default."""
if self.tokenizer is None:
@@ -363,6 +384,19 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
},
}

def _set_sharding_config(self, model_config: PretrainedConfig):
"""Override the sharding config for the model with text_config."""
super()._set_sharding_config(model_config)

if hasattr(model_config, "text_config"):
text_config = model_config.text_config
if hasattr(text_config, "base_model_tp_plan"):
self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
if hasattr(text_config, "head_dim"):
self._sharding_config["head_dim"] = text_config.head_dim
if hasattr(text_config, "num_hidden_layers"):
self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers

@property
def automodel_from_config(self):
return AutoModelForImageTextToText.from_config
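To make the extraction concrete, a small sketch of what _set_sharding_config gathers from a typical HF config object; the attribute values are fabricated, and SimpleNamespace stands in for a real transformers PretrainedConfig:

# Illustration only: mirrors the hasattr-guarded logic above.
from types import SimpleNamespace

model_config = SimpleNamespace(
    base_model_tp_plan={"layers.*.self_attn.q_proj": "colwise"},  # example plan
    head_dim=128,
    num_hidden_layers=32,
)

sharding_config = {"head_dim": 1}  # default when the config lacks head_dim
if hasattr(model_config, "base_model_tp_plan"):
    sharding_config["tp_plan"] = model_config.base_model_tp_plan
if hasattr(model_config, "head_dim"):
    sharding_config["head_dim"] = model_config.head_dim  # overrides the default
if hasattr(model_config, "num_hidden_layers"):
    sharding_config["num_hidden_layers"] = model_config.num_hidden_layers

# sharding_config is now:
# {"head_dim": 128,
#  "tp_plan": {"layers.*.self_attn.q_proj": "colwise"},
#  "num_hidden_layers": 32}

For image-text-to-text models, the override above runs the same extraction a second time on model_config.text_config, so the text backbone's values take precedence over the top-level ones.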