|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import json |
| 8 | +import logging |
| 9 | +import os |
| 10 | +import re |
7 | 11 | from abc import ABC, abstractmethod
|
8 | 12 | from typing import Any
|
9 | 13 |
|
| 14 | +logger = logging.getLogger() |
| 15 | + |
10 | 16 | from .model import BaseModelArgs
|
11 | 17 |
|
12 | 18 |
|
13 |
| -class StateDictAdapter(ABC): |
| 19 | +class BaseStateDictAdapter(ABC): |
14 | 20 | """Abstract base class for state dict transformations.
|
15 | 21 |
|
16 | 22 | This class defines the interface for converting between native model
|
@@ -47,3 +53,28 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
|
47 | 53 | The converted native model state dict
|
48 | 54 | """
|
49 | 55 | pass
|
| 56 | + |
| 57 | + |
class StateDictAdapter(BaseStateDictAdapter):
    """State dict adapter base class that builds ``fqn_to_index_mapping`` by default.

    On construction, ``model.safetensors.index.json`` is looked up under
    ``hf_assets_path``. When found, its ``"weight_map"`` entries are turned
    into ``self.fqn_to_index_mapping``; otherwise the attribute is ``None``.
    """

    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
        """Populate ``self.fqn_to_index_mapping`` from the safetensors index file.

        Args:
            model_args: Model arguments. Not used by this constructor itself.
            hf_assets_path: Directory expected to contain
                ``model.safetensors.index.json``. If falsy, no file is read.

        Raises:
            ValueError: If a ``"weight_map"`` value contains no digits, so no
                shard index can be extracted from it.
        """
        # Always define the attribute so that it exists (as None) even when no
        # assets path is given or the index file is missing.
        self.fqn_to_index_mapping = None
        if not hf_assets_path:
            return

        mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
        try:
            with open(mapping_path, "r", encoding="utf-8") as f:
                hf_safetensors_indx = json.load(f)
        except FileNotFoundError:
            # Best-effort: a missing index is not an error, only a warning.
            logger.warning(
                "model.safetensors.index.json not found at hf_assets_path: %s. "
                "Defaulting to saving a single safetensors file if checkpoint "
                "is saved in HF format.",
                mapping_path,
            )
            return

        if not hf_safetensors_indx:
            return

        fqn_to_index_mapping = {}
        for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
            # The shard index is the first run of digits in the mapped value,
            # kept as the matched string (e.g. "model-00001-of-00004.safetensors"
            # yields "00001").
            match = re.search(r"\d+", raw_indx)
            if match is None:
                raise ValueError(
                    f"Cannot extract a shard index from {raw_indx!r} "
                    f"(weight_map entry for {hf_key!r}) in {mapping_path}"
                )
            fqn_to_index_mapping[hf_key] = match.group(0)
        self.fqn_to_index_mapping = fqn_to_index_mapping
0 commit comments