Skip to content

Commit 57285ae

Browse files
committed
add model and update model type to worker
1 parent 071fcfd commit 57285ae

File tree

2 files changed

+741
-319
lines changed

2 files changed

+741
-319
lines changed

xinference/core/supervisor.py

Lines changed: 34 additions & 319 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,7 @@ async def register_model(
11901190
@log_async(logger=logger)
11911191
async def add_model(self, model_type: str, model_json: Dict[str, Any]):
11921192
"""
1193-
Add a new model by parsing the provided JSON and registering it.
1193+
Add a new model by forwarding the request to all workers.
11941194
11951195
Args:
11961196
model_type: Type of model (LLM, embedding, image, etc.)
@@ -1199,204 +1199,30 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
11991199
logger.info(
12001200
f"[DEBUG SUPERVISOR] add_model called with model_type: {model_type}"
12011201
)
1202-
logger.info(f"[DEBUG SUPERVISOR] model_json type: {type(model_json)}")
1203-
logger.info(
1204-
f"[DEBUG SUPERVISOR] model_json keys: {list(model_json.keys()) if isinstance(model_json, dict) else 'Not a dict'}"
1205-
)
1206-
if isinstance(model_json, dict):
1207-
logger.info(f"[DEBUG SUPERVISOR] model_json content: {model_json}")
1208-
1209-
# Validate model type (with case normalization)
1210-
supported_types = list(self._custom_register_type_to_cls.keys())
1211-
logger.info(f"[DEBUG SUPERVISOR] Supported model types: {supported_types}")
1212-
logger.info(f"[DEBUG SUPERVISOR] Received model_type: '{model_type}'")
1213-
1214-
normalized_model_type = model_type
1215-
1216-
if model_type.lower() == "llm" and "LLM" in supported_types:
1217-
normalized_model_type = "LLM"
1218-
elif model_type.lower() == "llm" and "llm" in supported_types:
1219-
normalized_model_type = "llm"
1220-
1221-
logger.info(
1222-
f"[DEBUG SUPERVISOR] Normalized model_type: '{normalized_model_type}'"
1223-
)
1224-
1225-
if normalized_model_type not in self._custom_register_type_to_cls:
1226-
logger.error(
1227-
f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_model_type} (original: {model_type})"
1228-
)
1229-
raise ValueError(
1230-
f"Unsupported model type '{model_type}'. "
1231-
f"Supported types are: {', '.join(supported_types)}"
1232-
)
1233-
1234-
# Use normalized model type for the rest of the function
1235-
model_type = normalized_model_type
1236-
logger.info(
1237-
f"[DEBUG SUPERVISOR] Using model_type: '{model_type}' for registration"
1238-
)
1239-
1240-
# Get the appropriate model class and register function
1241-
(
1242-
model_spec_cls,
1243-
register_fn,
1244-
unregister_fn,
1245-
generate_fn,
1246-
) = self._custom_register_type_to_cls[model_type]
1247-
logger.info(f"[DEBUG SUPERVISOR] Model spec class: {model_spec_cls}")
1248-
logger.info(f"[DEBUG SUPERVISOR] Register function: {register_fn}")
1249-
logger.info(f"[DEBUG SUPERVISOR] Unregister function: {unregister_fn}")
1250-
logger.info(f"[DEBUG SUPERVISOR] Generate function: {generate_fn}")
1251-
1252-
# Validate required fields (only model_name is required)
1253-
required_fields = ["model_name"]
1254-
logger.info(f"[DEBUG SUPERVISOR] Checking required fields: {required_fields}")
1255-
for field in required_fields:
1256-
if field not in model_json:
1257-
logger.error(f"[DEBUG SUPERVISOR] Missing required field: {field}")
1258-
raise ValueError(f"Missing required field: {field}")
1259-
1260-
# Validate model name format
1261-
from ..model.utils import is_valid_model_name
1262-
1263-
model_name = model_json["model_name"]
1264-
logger.info(f"[DEBUG SUPERVISOR] Extracted model_name: {model_name}")
1265-
1266-
if not is_valid_model_name(model_name):
1267-
logger.error(f"[DEBUG SUPERVISOR] Invalid model name format: {model_name}")
1268-
raise ValueError(f"Invalid model name format: {model_name}")
1269-
1270-
logger.info(f"[DEBUG SUPERVISOR] Model name validation passed")
1271-
1272-
# Convert model hub JSON format to Xinference expected format
1273-
logger.info(f"[DEBUG SUPERVISOR] Converting model JSON format...")
1274-
try:
1275-
converted_model_json = self._convert_model_json_format(model_json)
1276-
logger.info(
1277-
f"[DEBUG SUPERVISOR] Converted model JSON: {converted_model_json}"
1278-
)
1279-
except Exception as e:
1280-
logger.error(
1281-
f"[DEBUG SUPERVISOR] Format conversion failed: {str(e)}", exc_info=True
1282-
)
1283-
raise ValueError(f"Failed to convert model JSON format: {str(e)}")
1202+
logger.info(f"[DEBUG SUPERVISOR] Forwarding add_model request to all workers")
12841203

1285-
# Parse the JSON into the appropriate model spec
1286-
logger.info(f"[DEBUG SUPERVISOR] Parsing model spec...")
12871204
try:
1288-
model_spec = model_spec_cls.parse_obj(converted_model_json)
1289-
logger.info(f"[DEBUG SUPERVISOR] Parsed model spec: {model_spec}")
1290-
except Exception as e:
1291-
logger.error(
1292-
f"[DEBUG SUPERVISOR] Model spec parsing failed: {str(e)}", exc_info=True
1293-
)
1294-
raise ValueError(f"Invalid model JSON format: {str(e)}")
1295-
1296-
# Check if model already exists
1297-
logger.info(f"[DEBUG SUPERVISOR] Checking if model already exists...")
1298-
try:
1299-
existing_model = await self.get_model_registration(
1300-
model_type, model_spec.model_name
1301-
)
1302-
logger.info(
1303-
f"[DEBUG SUPERVISOR] Existing model check result: {existing_model}"
1304-
)
1305-
1306-
if existing_model is not None:
1307-
logger.error(
1308-
f"[DEBUG SUPERVISOR] Model already exists: {model_spec.model_name}"
1309-
)
1310-
raise ValueError(
1311-
f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
1312-
f"Please choose a different model name or remove the existing model first."
1313-
)
1314-
1315-
except ValueError as e:
1316-
if "not found" in str(e):
1317-
# Model doesn't exist, we can proceed
1318-
logger.info(
1319-
f"[DEBUG SUPERVISOR] Model doesn't exist yet, proceeding with registration"
1320-
)
1321-
pass
1205+
# Forward the add_model request to all workers
1206+
tasks = []
1207+
for worker_address, worker_ref in self._worker_address_to_worker.items():
1208+
logger.info(f"[DEBUG SUPERVISOR] Forwarding add_model to worker: {worker_address}")
1209+
tasks.append(worker_ref.add_model(model_type, model_json))
1210+
1211+
# Wait for all workers to complete the operation
1212+
if tasks:
1213+
await asyncio.gather(*tasks, return_exceptions=True)
1214+
logger.info(f"[DEBUG SUPERVISOR] All workers completed add_model operation")
13221215
else:
1323-
# Re-raise validation errors
1324-
logger.error(
1325-
f"[DEBUG SUPERVISOR] Validation error during model check: {str(e)}"
1326-
)
1327-
raise e
1328-
except Exception as ex:
1329-
logger.error(
1330-
f"[DEBUG SUPERVISOR] Unexpected error during model check: {str(ex)}",
1331-
exc_info=True,
1332-
)
1333-
raise ValueError(f"Failed to validate model registration: {str(ex)}")
1334-
1335-
logger.info(f"[DEBUG SUPERVISOR] Storing single model as built-in...")
1336-
try:
1337-
# Create CacheManager and store as built-in model
1338-
from ..model.cache_manager import CacheManager
1216+
logger.warning(f"[DEBUG SUPERVISOR] No workers available to forward add_model request")
13391217

1340-
cache_manager = CacheManager(model_spec)
1341-
cache_manager.register_builtin_model(model_type.lower())
1342-
logger.info(f"[DEBUG SUPERVISOR] Built-in model stored successfully")
1218+
logger.info(f"[DEBUG SUPERVISOR] add_model completed successfully")
13431219

1344-
# Register in the model registry without persisting to avoid duplicate storage
1345-
register_fn(model_spec, persist=False)
1346-
logger.info(
1347-
f"[DEBUG SUPERVISOR] Model registry registration completed successfully"
1348-
)
1349-
1350-
# Record model version
1351-
logger.info(f"[DEBUG SUPERVISOR] Generating version info...")
1352-
version_info = generate_fn(model_spec)
1353-
logger.info(f"[DEBUG SUPERVISOR] Generated version_info: {version_info}")
1354-
1355-
logger.info(
1356-
f"[DEBUG SUPERVISOR] Recording model version in cache tracker..."
1357-
)
1358-
await self._cache_tracker_ref.record_model_version(
1359-
version_info, self.address
1360-
)
1361-
logger.info(f"[DEBUG SUPERVISOR] Cache tracker recording completed")
1362-
1363-
# Sync to workers if not local deployment
1364-
is_local = self.is_local_deployment()
1365-
logger.info(f"[DEBUG SUPERVISOR] Is local deployment: {is_local}")
1366-
if not is_local:
1367-
# Convert back to JSON string for sync compatibility
1368-
model_json_str = json.dumps(converted_model_json)
1369-
logger.info(f"[DEBUG SUPERVISOR] Syncing model to workers...")
1370-
await self._sync_register_model(
1371-
model_type, model_json_str, True, model_spec.model_name
1372-
)
1373-
logger.info(f"[DEBUG SUPERVISOR] Model sync to workers completed")
1374-
1375-
logger.info(
1376-
f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
1377-
)
1378-
1379-
except ValueError as e:
1380-
# Validation errors - don't need cleanup as model wasn't registered
1381-
logger.error(f"[DEBUG SUPERVISOR] ValueError during registration: {str(e)}")
1382-
raise e
13831220
except Exception as e:
1384-
# Unexpected errors - attempt cleanup
13851221
logger.error(
1386-
f"[DEBUG SUPERVISOR] Unexpected error during registration: {str(e)}",
1222+
f"[DEBUG SUPERVISOR] Error during add_model forwarding: {str(e)}",
13871223
exc_info=True,
13881224
)
1389-
try:
1390-
logger.info(f"[DEBUG SUPERVISOR] Attempting cleanup...")
1391-
unregister_fn(model_spec.model_name, raise_error=False)
1392-
logger.info(f"[DEBUG SUPERVISOR] Cleanup completed successfully")
1393-
except Exception as cleanup_error:
1394-
logger.warning(f"[DEBUG SUPERVISOR] Cleanup failed: {cleanup_error}")
1395-
raise ValueError(
1396-
f"Failed to register model '{model_spec.model_name}': {str(e)}"
1397-
)
1398-
1399-
logger.info(f"[DEBUG SUPERVISOR] add_model completed successfully")
1225+
raise ValueError(f"Failed to add model: {str(e)}")
14001226

14011227
def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
14021228
"""
@@ -1622,150 +1448,39 @@ async def _sync_register_model(
16221448
@log_async(logger=logger)
16231449
async def update_model_type(self, model_type: str):
16241450
"""
1625-
Update model configurations for a specific model type by downloading
1626-
the latest JSON from the remote API and storing it locally.
1451+
Update model configurations for a specific model type by forwarding
1452+
the request to all workers.
16271453
16281454
Args:
16291455
model_type: Type of model (LLM, embedding, image, etc.)
16301456
"""
1631-
import json
1632-
1633-
import requests
1634-
16351457
logger.info(
16361458
f"[DEBUG SUPERVISOR] update_model_type called with model_type: {model_type}"
16371459
)
1638-
1639-
supported_types = list(self._custom_register_type_to_cls.keys())
1640-
1641-
normalized_for_validation = model_type
1642-
if model_type.lower() == "llm" and "LLM" in supported_types:
1643-
normalized_for_validation = "LLM"
1644-
elif model_type.lower() == "llm" and "llm" in supported_types:
1645-
normalized_for_validation = "llm"
1646-
1647-
if normalized_for_validation not in supported_types:
1648-
logger.error(
1649-
f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_for_validation}"
1650-
)
1651-
raise ValueError(
1652-
f"Unsupported model type '{model_type}'. "
1653-
f"Supported types are: {', '.join(supported_types)}"
1654-
)
1655-
1656-
model_type_for_operations = normalized_for_validation
1657-
logger.info(
1658-
f"[DEBUG SUPERVISOR] Using model_type: '{model_type_for_operations}' for operations"
1659-
)
1660-
1661-
# Construct the URL to download JSON
1662-
url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}"
1663-
logger.info(f"[DEBUG SUPERVISOR] Downloading model configurations from: {url}")
1460+
logger.info(f"[DEBUG SUPERVISOR] Forwarding update_model_type request to all workers")
16641461

16651462
try:
1666-
# Download JSON from remote API
1667-
response = requests.get(url, timeout=30)
1668-
response.raise_for_status()
1669-
1670-
# Parse JSON response
1671-
model_data = response.json()
1672-
logger.info(
1673-
f"[DEBUG SUPERVISOR] Successfully downloaded JSON for model type: {model_type}"
1674-
)
1675-
logger.info(f"[DEBUG SUPERVISOR] JSON data type: {type(model_data)}")
1676-
1677-
if isinstance(model_data, dict):
1678-
logger.info(
1679-
f"[DEBUG SUPERVISOR] JSON data keys: {list(model_data.keys())}"
1680-
)
1681-
elif isinstance(model_data, list):
1682-
logger.info(
1683-
f"[DEBUG SUPERVISOR] JSON data contains {len(model_data)} items"
1684-
)
1685-
if model_data:
1686-
logger.info(
1687-
f"[DEBUG SUPERVISOR] First item keys: {list(model_data[0].keys()) if isinstance(model_data[0], dict) else 'Not a dict'}"
1688-
)
1689-
1690-
# Store the JSON data using CacheManager as built-in models
1691-
logger.info(
1692-
f"[DEBUG SUPERVISOR] Storing model configurations as built-in models..."
1693-
)
1694-
await self._store_model_configurations(model_type, model_data)
1695-
logger.info(
1696-
f"[DEBUG SUPERVISOR] Built-in model configurations stored successfully"
1697-
)
1698-
1699-
# Dynamically reload built-in models to make them immediately available
1700-
logger.info(
1701-
f"[DEBUG SUPERVISOR] Reloading built-in models for immediate availability..."
1702-
)
1703-
try:
1704-
if model_type.lower() == "llm":
1705-
from ..model.llm import register_builtin_model
1706-
1707-
register_builtin_model()
1708-
logger.info(f"[DEBUG SUPERVISOR] LLM models reloaded successfully")
1709-
elif model_type.lower() == "embedding":
1710-
from ..model.embedding import register_builtin_model
1711-
1712-
register_builtin_model()
1713-
logger.info(
1714-
f"[DEBUG SUPERVISOR] Embedding models reloaded successfully"
1715-
)
1716-
elif model_type.lower() == "audio":
1717-
from ..model.audio import register_builtin_model
1718-
1719-
register_builtin_model()
1720-
logger.info(
1721-
f"[DEBUG SUPERVISOR] Audio models reloaded successfully"
1722-
)
1723-
elif model_type.lower() == "image":
1724-
from ..model.image import register_builtin_model
1725-
1726-
register_builtin_model()
1727-
logger.info(
1728-
f"[DEBUG SUPERVISOR] Image models reloaded successfully"
1729-
)
1730-
elif model_type.lower() == "rerank":
1731-
from ..model.rerank import register_builtin_model
1732-
1733-
register_builtin_model()
1734-
logger.info(
1735-
f"[DEBUG SUPERVISOR] Rerank models reloaded successfully"
1736-
)
1737-
elif model_type.lower() == "video":
1738-
from ..model.video import register_builtin_model
1463+
# Forward the update_model_type request to all workers
1464+
tasks = []
1465+
for worker_address, worker_ref in self._worker_address_to_worker.items():
1466+
logger.info(f"[DEBUG SUPERVISOR] Forwarding update_model_type to worker: {worker_address}")
1467+
tasks.append(worker_ref.update_model_type(model_type))
1468+
1469+
# Wait for all workers to complete the operation
1470+
if tasks:
1471+
await asyncio.gather(*tasks, return_exceptions=True)
1472+
logger.info(f"[DEBUG SUPERVISOR] All workers completed update_model_type operation")
1473+
else:
1474+
logger.warning(f"[DEBUG SUPERVISOR] No workers available to forward update_model_type request")
17391475

1740-
register_builtin_model()
1741-
logger.info(
1742-
f"[DEBUG SUPERVISOR] Video models reloaded successfully"
1743-
)
1744-
else:
1745-
logger.warning(
1746-
f"[DEBUG SUPERVISOR] No dynamic loading available for model type: {model_type}"
1747-
)
1748-
except Exception as reload_error:
1749-
logger.error(
1750-
f"[DEBUG SUPERVISOR] Error reloading built-in models: {reload_error}",
1751-
exc_info=True,
1752-
)
1753-
# Don't fail the update if reload fails, just log the error
1476+
logger.info(f"[DEBUG SUPERVISOR] update_model_type completed successfully")
17541477

1755-
except requests.exceptions.RequestException as e:
1756-
logger.error(
1757-
f"[DEBUG SUPERVISOR] Network error downloading model configurations: {e}"
1758-
)
1759-
raise ValueError(f"Failed to download model configurations: {str(e)}")
1760-
except json.JSONDecodeError as e:
1761-
logger.error(f"[DEBUG SUPERVISOR] JSON decode error: {e}")
1762-
raise ValueError(f"Invalid JSON response from remote API: {str(e)}")
17631478
except Exception as e:
17641479
logger.error(
1765-
f"[DEBUG SUPERVISOR] Unexpected error during model update: {e}",
1480+
f"[DEBUG SUPERVISOR] Error during update_model_type forwarding: {str(e)}",
17661481
exc_info=True,
17671482
)
1768-
raise ValueError(f"Failed to update model configurations: {str(e)}")
1483+
raise ValueError(f"Failed to update model type: {str(e)}")
17691484

17701485
async def _store_model_configurations(self, model_type: str, model_data):
17711486
"""

0 commit comments

Comments
 (0)