@@ -1190,7 +1190,7 @@ async def register_model(
     @log_async(logger=logger)
     async def add_model(self, model_type: str, model_json: Dict[str, Any]):
         """
-        Add a new model by parsing the provided JSON and registering it.
+        Add a new model by forwarding the request to all workers.

         Args:
             model_type: Type of model (LLM, embedding, image, etc.)
@@ -1199,204 +1199,30 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
         logger.info(
             f"[DEBUG SUPERVISOR] add_model called with model_type: {model_type}"
         )
-        logger.info(f"[DEBUG SUPERVISOR] model_json type: {type(model_json)}")
-        logger.info(
-            f"[DEBUG SUPERVISOR] model_json keys: {list(model_json.keys()) if isinstance(model_json, dict) else 'Not a dict'}"
-        )
-        if isinstance(model_json, dict):
-            logger.info(f"[DEBUG SUPERVISOR] model_json content: {model_json}")
-
-        # Validate model type (with case normalization)
-        supported_types = list(self._custom_register_type_to_cls.keys())
-        logger.info(f"[DEBUG SUPERVISOR] Supported model types: {supported_types}")
-        logger.info(f"[DEBUG SUPERVISOR] Received model_type: '{model_type}'")
-
-        normalized_model_type = model_type
-
-        if model_type.lower() == "llm" and "LLM" in supported_types:
-            normalized_model_type = "LLM"
-        elif model_type.lower() == "llm" and "llm" in supported_types:
-            normalized_model_type = "llm"
-
-        logger.info(
-            f"[DEBUG SUPERVISOR] Normalized model_type: '{normalized_model_type}'"
-        )
-
-        if normalized_model_type not in self._custom_register_type_to_cls:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_model_type} (original: {model_type})"
-            )
-            raise ValueError(
-                f"Unsupported model type '{model_type}'. "
-                f"Supported types are: {', '.join(supported_types)}"
-            )
-
-        # Use normalized model type for the rest of the function
-        model_type = normalized_model_type
-        logger.info(
-            f"[DEBUG SUPERVISOR] Using model_type: '{model_type}' for registration"
-        )
-
-        # Get the appropriate model class and register function
-        (
-            model_spec_cls,
-            register_fn,
-            unregister_fn,
-            generate_fn,
-        ) = self._custom_register_type_to_cls[model_type]
-        logger.info(f"[DEBUG SUPERVISOR] Model spec class: {model_spec_cls}")
-        logger.info(f"[DEBUG SUPERVISOR] Register function: {register_fn}")
-        logger.info(f"[DEBUG SUPERVISOR] Unregister function: {unregister_fn}")
-        logger.info(f"[DEBUG SUPERVISOR] Generate function: {generate_fn}")
-
-        # Validate required fields (only model_name is required)
-        required_fields = ["model_name"]
-        logger.info(f"[DEBUG SUPERVISOR] Checking required fields: {required_fields}")
-        for field in required_fields:
-            if field not in model_json:
-                logger.error(f"[DEBUG SUPERVISOR] Missing required field: {field}")
-                raise ValueError(f"Missing required field: {field}")
-
-        # Validate model name format
-        from ..model.utils import is_valid_model_name
-
-        model_name = model_json["model_name"]
-        logger.info(f"[DEBUG SUPERVISOR] Extracted model_name: {model_name}")
-
-        if not is_valid_model_name(model_name):
-            logger.error(f"[DEBUG SUPERVISOR] Invalid model name format: {model_name}")
-            raise ValueError(f"Invalid model name format: {model_name}")
-
-        logger.info(f"[DEBUG SUPERVISOR] Model name validation passed")
-
-        # Convert model hub JSON format to Xinference expected format
-        logger.info(f"[DEBUG SUPERVISOR] Converting model JSON format...")
-        try:
-            converted_model_json = self._convert_model_json_format(model_json)
-            logger.info(
-                f"[DEBUG SUPERVISOR] Converted model JSON: {converted_model_json}"
-            )
-        except Exception as e:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Format conversion failed: {str(e)}", exc_info=True
-            )
-            raise ValueError(f"Failed to convert model JSON format: {str(e)}")
+        logger.info(f"[DEBUG SUPERVISOR] Forwarding add_model request to all workers")

-        # Parse the JSON into the appropriate model spec
-        logger.info(f"[DEBUG SUPERVISOR] Parsing model spec...")
         try:
-            model_spec = model_spec_cls.parse_obj(converted_model_json)
-            logger.info(f"[DEBUG SUPERVISOR] Parsed model spec: {model_spec}")
-        except Exception as e:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Model spec parsing failed: {str(e)}", exc_info=True
-            )
-            raise ValueError(f"Invalid model JSON format: {str(e)}")
-
-        # Check if model already exists
-        logger.info(f"[DEBUG SUPERVISOR] Checking if model already exists...")
-        try:
-            existing_model = await self.get_model_registration(
-                model_type, model_spec.model_name
-            )
-            logger.info(
-                f"[DEBUG SUPERVISOR] Existing model check result: {existing_model}"
-            )
-
-            if existing_model is not None:
-                logger.error(
-                    f"[DEBUG SUPERVISOR] Model already exists: {model_spec.model_name}"
-                )
-                raise ValueError(
-                    f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
-                    f"Please choose a different model name or remove the existing model first."
-                )
-
-        except ValueError as e:
-            if "not found" in str(e):
-                # Model doesn't exist, we can proceed
-                logger.info(
-                    f"[DEBUG SUPERVISOR] Model doesn't exist yet, proceeding with registration"
-                )
-                pass
+            # Forward the add_model request to all workers
+            tasks = []
+            for worker_address, worker_ref in self._worker_address_to_worker.items():
+                logger.info(f"[DEBUG SUPERVISOR] Forwarding add_model to worker: {worker_address}")
+                tasks.append(worker_ref.add_model(model_type, model_json))
+
+            # Wait for all workers to complete the operation
+            if tasks:
+                await asyncio.gather(*tasks, return_exceptions=True)
+                logger.info(f"[DEBUG SUPERVISOR] All workers completed add_model operation")
             else:
-                # Re-raise validation errors
-                logger.error(
-                    f"[DEBUG SUPERVISOR] Validation error during model check: {str(e)}"
-                )
-                raise e
-        except Exception as ex:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Unexpected error during model check: {str(ex)}",
-                exc_info=True,
-            )
-            raise ValueError(f"Failed to validate model registration: {str(ex)}")
-
-        logger.info(f"[DEBUG SUPERVISOR] Storing single model as built-in...")
-        try:
-            # Create CacheManager and store as built-in model
-            from ..model.cache_manager import CacheManager
+                logger.warning(f"[DEBUG SUPERVISOR] No workers available to forward add_model request")

-            cache_manager = CacheManager(model_spec)
-            cache_manager.register_builtin_model(model_type.lower())
-            logger.info(f"[DEBUG SUPERVISOR] Built-in model stored successfully")
+            logger.info(f"[DEBUG SUPERVISOR] add_model completed successfully")

-            # Register in the model registry without persisting to avoid duplicate storage
-            register_fn(model_spec, persist=False)
-            logger.info(
-                f"[DEBUG SUPERVISOR] Model registry registration completed successfully"
-            )
-
-            # Record model version
-            logger.info(f"[DEBUG SUPERVISOR] Generating version info...")
-            version_info = generate_fn(model_spec)
-            logger.info(f"[DEBUG SUPERVISOR] Generated version_info: {version_info}")
-
-            logger.info(
-                f"[DEBUG SUPERVISOR] Recording model version in cache tracker..."
-            )
-            await self._cache_tracker_ref.record_model_version(
-                version_info, self.address
-            )
-            logger.info(f"[DEBUG SUPERVISOR] Cache tracker recording completed")
-
-            # Sync to workers if not local deployment
-            is_local = self.is_local_deployment()
-            logger.info(f"[DEBUG SUPERVISOR] Is local deployment: {is_local}")
-            if not is_local:
-                # Convert back to JSON string for sync compatibility
-                model_json_str = json.dumps(converted_model_json)
-                logger.info(f"[DEBUG SUPERVISOR] Syncing model to workers...")
-                await self._sync_register_model(
-                    model_type, model_json_str, True, model_spec.model_name
-                )
-                logger.info(f"[DEBUG SUPERVISOR] Model sync to workers completed")
-
-            logger.info(
-                f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
-            )
-
-        except ValueError as e:
-            # Validation errors - don't need cleanup as model wasn't registered
-            logger.error(f"[DEBUG SUPERVISOR] ValueError during registration: {str(e)}")
-            raise e
         except Exception as e:
-            # Unexpected errors - attempt cleanup
             logger.error(
-                f"[DEBUG SUPERVISOR] Unexpected error during registration: {str(e)}",
+                f"[DEBUG SUPERVISOR] Error during add_model forwarding: {str(e)}",
                 exc_info=True,
             )
-            try:
-                logger.info(f"[DEBUG SUPERVISOR] Attempting cleanup...")
-                unregister_fn(model_spec.model_name, raise_error=False)
-                logger.info(f"[DEBUG SUPERVISOR] Cleanup completed successfully")
-            except Exception as cleanup_error:
-                logger.warning(f"[DEBUG SUPERVISOR] Cleanup failed: {cleanup_error}")
-            raise ValueError(
-                f"Failed to register model '{model_spec.model_name}': {str(e)}"
-            )
-
-        logger.info(f"[DEBUG SUPERVISOR] add_model completed successfully")
+            raise ValueError(f"Failed to add model: {str(e)}")

     def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
         """
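The new `add_model` path above fans the request out to every registered worker and waits on `asyncio.gather(*tasks, return_exceptions=True)`, so a failing worker comes back as an exception object in the result list rather than raising. The sketch below is a minimal, self-contained illustration of that fan-out pattern only; `MockWorker` and `broadcast_add_model` are hypothetical names standing in for the xoscar worker actor refs held in `self._worker_address_to_worker`, and are not part of the patch.

```python
import asyncio
from typing import Any, Dict


class MockWorker:
    """Stand-in for a worker actor ref; only the gather logic mirrors the diff."""

    def __init__(self, address: str, healthy: bool = True) -> None:
        self.address = address
        self.healthy = healthy

    async def add_model(self, model_type: str, model_json: Dict[str, Any]) -> str:
        if not self.healthy:
            raise RuntimeError(f"worker {self.address} unavailable")
        return f"{self.address} registered {model_json['model_name']} ({model_type})"


async def broadcast_add_model(
    workers: Dict[str, MockWorker], model_type: str, model_json: Dict[str, Any]
) -> None:
    # Build one coroutine per worker, then run them concurrently.
    tasks = [w.add_model(model_type, model_json) for w in workers.values()]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    # With return_exceptions=True a failed worker yields an exception object
    # instead of aborting the gather, so each result has to be inspected.
    for address, result in zip(workers, results):
        if isinstance(result, Exception):
            print(f"worker {address} failed: {result}")
        else:
            print(result)


asyncio.run(
    broadcast_add_model(
        {"w1:9999": MockWorker("w1:9999"), "w2:9999": MockWorker("w2:9999", healthy=False)},
        "LLM",
        {"model_name": "my-custom-llm"},
    )
)
```

One consequence of this choice worth noting: `return_exceptions=True` keeps a single dead worker from failing the whole broadcast, but since the patch discards the gathered results, per-worker failures surface only in the worker logs.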
@@ -1622,150 +1448,39 @@ async def _sync_register_model(
     @log_async(logger=logger)
     async def update_model_type(self, model_type: str):
         """
-        Update model configurations for a specific model type by downloading
-        the latest JSON from the remote API and storing it locally.
+        Update model configurations for a specific model type by forwarding
+        the request to all workers.

         Args:
             model_type: Type of model (LLM, embedding, image, etc.)
         """
-        import json
-
-        import requests
-
         logger.info(
             f"[DEBUG SUPERVISOR] update_model_type called with model_type: {model_type}"
         )
-
-        supported_types = list(self._custom_register_type_to_cls.keys())
-
-        normalized_for_validation = model_type
-        if model_type.lower() == "llm" and "LLM" in supported_types:
-            normalized_for_validation = "LLM"
-        elif model_type.lower() == "llm" and "llm" in supported_types:
-            normalized_for_validation = "llm"
-
-        if normalized_for_validation not in supported_types:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_for_validation}"
-            )
-            raise ValueError(
-                f"Unsupported model type '{model_type}'. "
-                f"Supported types are: {', '.join(supported_types)}"
-            )
-
-        model_type_for_operations = normalized_for_validation
-        logger.info(
-            f"[DEBUG SUPERVISOR] Using model_type: '{model_type_for_operations}' for operations"
-        )
-
-        # Construct the URL to download JSON
-        url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}"
-        logger.info(f"[DEBUG SUPERVISOR] Downloading model configurations from: {url}")
+        logger.info(f"[DEBUG SUPERVISOR] Forwarding update_model_type request to all workers")

         try:
-            # Download JSON from remote API
-            response = requests.get(url, timeout=30)
-            response.raise_for_status()
-
-            # Parse JSON response
-            model_data = response.json()
-            logger.info(
-                f"[DEBUG SUPERVISOR] Successfully downloaded JSON for model type: {model_type}"
-            )
-            logger.info(f"[DEBUG SUPERVISOR] JSON data type: {type(model_data)}")
-
-            if isinstance(model_data, dict):
-                logger.info(
-                    f"[DEBUG SUPERVISOR] JSON data keys: {list(model_data.keys())}"
-                )
-            elif isinstance(model_data, list):
-                logger.info(
-                    f"[DEBUG SUPERVISOR] JSON data contains {len(model_data)} items"
-                )
-                if model_data:
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] First item keys: {list(model_data[0].keys()) if isinstance(model_data[0], dict) else 'Not a dict'}"
-                    )
-
-            # Store the JSON data using CacheManager as built-in models
-            logger.info(
-                f"[DEBUG SUPERVISOR] Storing model configurations as built-in models..."
-            )
-            await self._store_model_configurations(model_type, model_data)
-            logger.info(
-                f"[DEBUG SUPERVISOR] Built-in model configurations stored successfully"
-            )
-
-            # Dynamically reload built-in models to make them immediately available
-            logger.info(
-                f"[DEBUG SUPERVISOR] Reloading built-in models for immediate availability..."
-            )
-            try:
-                if model_type.lower() == "llm":
-                    from ..model.llm import register_builtin_model
-
-                    register_builtin_model()
-                    logger.info(f"[DEBUG SUPERVISOR] LLM models reloaded successfully")
-                elif model_type.lower() == "embedding":
-                    from ..model.embedding import register_builtin_model
-
-                    register_builtin_model()
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] Embedding models reloaded successfully"
-                    )
-                elif model_type.lower() == "audio":
-                    from ..model.audio import register_builtin_model
-
-                    register_builtin_model()
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] Audio models reloaded successfully"
-                    )
-                elif model_type.lower() == "image":
-                    from ..model.image import register_builtin_model
-
-                    register_builtin_model()
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] Image models reloaded successfully"
-                    )
-                elif model_type.lower() == "rerank":
-                    from ..model.rerank import register_builtin_model
-
-                    register_builtin_model()
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] Rerank models reloaded successfully"
-                    )
-                elif model_type.lower() == "video":
-                    from ..model.video import register_builtin_model
+            # Forward the update_model_type request to all workers
+            tasks = []
+            for worker_address, worker_ref in self._worker_address_to_worker.items():
+                logger.info(f"[DEBUG SUPERVISOR] Forwarding update_model_type to worker: {worker_address}")
+                tasks.append(worker_ref.update_model_type(model_type))
+
+            # Wait for all workers to complete the operation
+            if tasks:
+                await asyncio.gather(*tasks, return_exceptions=True)
+                logger.info(f"[DEBUG SUPERVISOR] All workers completed update_model_type operation")
+            else:
+                logger.warning(f"[DEBUG SUPERVISOR] No workers available to forward update_model_type request")

-                    register_builtin_model()
-                    logger.info(
-                        f"[DEBUG SUPERVISOR] Video models reloaded successfully"
-                    )
-                else:
-                    logger.warning(
-                        f"[DEBUG SUPERVISOR] No dynamic loading available for model type: {model_type}"
-                    )
-            except Exception as reload_error:
-                logger.error(
-                    f"[DEBUG SUPERVISOR] Error reloading built-in models: {reload_error}",
-                    exc_info=True,
-                )
-                # Don't fail the update if reload fails, just log the error
+            logger.info(f"[DEBUG SUPERVISOR] update_model_type completed successfully")

-        except requests.exceptions.RequestException as e:
-            logger.error(
-                f"[DEBUG SUPERVISOR] Network error downloading model configurations: {e}"
-            )
-            raise ValueError(f"Failed to download model configurations: {str(e)}")
-        except json.JSONDecodeError as e:
-            logger.error(f"[DEBUG SUPERVISOR] JSON decode error: {e}")
-            raise ValueError(f"Invalid JSON response from remote API: {str(e)}")
         except Exception as e:
             logger.error(
-                f"[DEBUG SUPERVISOR] Unexpected error during model update: {e}",
+                f"[DEBUG SUPERVISOR] Error during update_model_type forwarding: {str(e)}",
                 exc_info=True,
             )
-            raise ValueError(f"Failed to update model configurations: {str(e)}")
+            raise ValueError(f"Failed to update model type: {str(e)}")

     async def _store_model_configurations(self, model_type: str, model_data):
         """
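This hunk removes the download-and-store logic from the supervisor, so the actual fetching now has to happen on each worker's side of the forwarded call. Assuming the worker-side handler mirrors the logic removed above, a rough sketch could look like the following; the endpoint URL and timeout come from the removed supervisor code, while `store_model_configurations` and `reload_builtin_models` are placeholder helpers (in xinference this would presumably go through `CacheManager` and `register_builtin_model`), not the project's real API.

```python
import json
from pathlib import Path

import requests


def store_model_configurations(model_type: str, model_data, root: Path) -> Path:
    # Persist the downloaded specs locally so they survive restarts.
    target = root / f"{model_type.lower()}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(model_data, indent=2))
    return target


def reload_builtin_models(model_type: str) -> None:
    # Stand-in for re-registering built-in models after an update.
    print(f"reloading built-in {model_type} models")


def update_model_type(model_type: str, root: Path = Path.home() / ".cache" / "models") -> None:
    # Endpoint taken from the supervisor code removed in this hunk.
    url = (
        "https://model.xinference.io/api/models/download"
        f"?model_type={model_type.lower()}"
    )
    response = requests.get(url, timeout=30)
    response.raise_for_status()   # surface HTTP errors early
    model_data = response.json()  # dict or list of model specs
    store_model_configurations(model_type, model_data, root)
    reload_builtin_models(model_type)
```

A worker exposing a handler like this would pair with the supervisor's broadcast: the supervisor only fans out and gathers, and each worker refreshes its own local copy of the built-in model registry.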