From 78874b5a2ee74024b22809af8a5e6fde3b078962 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Jul 2024 13:21:58 +0000 Subject: [PATCH 1/9] chore: telemetry for deployment configs --- src/sagemaker/jumpstart/exceptions.py | 2 + src/sagemaker/jumpstart/factory/estimator.py | 28 ++--- src/sagemaker/jumpstart/factory/model.py | 117 +++++++++++------- src/sagemaker/jumpstart/utils.py | 56 ++++++++- .../jumpstart/model/test_jumpstart_model.py | 17 +++ .../jumpstart/estimator/test_estimator.py | 2 +- tests/unit/sagemaker/jumpstart/test_utils.py | 14 ++- 7 files changed, 171 insertions(+), 65 deletions(-) diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 742a6b8d3f..5941cbad2e 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -150,6 +150,7 @@ def __init__( model. (Default: None). """ + version = version or "*" if message: self.message = message else: @@ -198,6 +199,7 @@ def __init__( version: Optional[str] = None, message: Optional[str] = None, ): + version = version or "*" if message: self.message = message else: diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 0d156c415f..4da190ac1e 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -69,6 +69,7 @@ add_jumpstart_model_info_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, + get_top_ranked_config_name, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, verify_model_region_and_return_specs, @@ -204,7 +205,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_sagemaker_session_with_user_agent_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) @@ -438,12 +439,15 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: return kwargs -def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: +def _add_sagemaker_session_with_user_agent_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version, kwargs.hub_arn + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=None, + is_hub_content=kwargs.hub_arn is not None, ) ) return kwargs @@ -903,20 +907,16 @@ def _add_config_name_to_kwargs( ) -> JumpStartEstimatorInitKwargs: """Sets tags in kwargs based on default or override, returns full kwargs.""" - specs = verify_model_region_and_return_specs( + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( + region=kwargs.region, model_id=kwargs.model_id, - version=kwargs.model_version, + model_version=kwargs.model_version, + sagemaker_session=kwargs.sagemaker_session, scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + model_type=kwargs.model_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + hub_arn=kwargs.hub_arn, ) - if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): - kwargs.config_name = ( - kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name - ) - return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f4e13de6d7..27087ecba1 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -29,6 +29,7 @@ ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, @@ -54,6 +55,7 @@ add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, + get_top_ranked_config_name, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -155,7 +157,7 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni return kwargs -def _add_sagemaker_session_to_kwargs( +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" @@ -163,7 +165,7 @@ def _add_sagemaker_session_to_kwargs( kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version, kwargs.hub_arn + kwargs.model_id, kwargs.model_version, kwargs.config_name, kwargs.hub_arn ) ) @@ -662,6 +664,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta ValueError: If the instance_type is not supported with the current config. """ + # we need to create a default JS session (without custom user agent) + # in order to retrieve config name info + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( + region=kwargs.region, + model_id=kwargs.model_id, + model_version=kwargs.model_version, + sagemaker_session=temp_session, + scope=JumpStartScriptScope.INFERENCE, + model_type=kwargs.model_type, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + hub_arn=kwargs.hub_arn, + ) + + if kwargs.config_name is None: + return kwargs + specs = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, @@ -669,26 +690,21 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, + sagemaker_session=temp_session, model_type=kwargs.model_type, config_name=kwargs.config_name, ) - if specs.inference_configs: - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name - kwargs.config_name = kwargs.config_name or default_config_name - - if not kwargs.config_name: - return kwargs - - if kwargs.config_name not in set(specs.inference_configs.configs.keys()): - raise ValueError( - f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." - ) - resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) - if kwargs.instance_type not in supported_instance_types: - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) + resolved_config = ( + specs.inference_configs.configs[kwargs.config_name].resolved_config + if specs.inference_configs + else None + ) + if resolved_config is None: + return kwargs + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) return kwargs @@ -740,27 +756,41 @@ def _add_config_name_to_deploy_kwargs( ValueError: If the instance_type is not supported with the current config. """ - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ) + # we need to create a default JS session (without custom user agent) + # in order to retrieve config name info + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION if training_config_name: - kwargs.config_name = _select_inference_config_from_training_config( + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=temp_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + default_config_name = _select_inference_config_from_training_config( specs=specs, training_config_name=training_config_name ) - return kwargs - if specs.inference_configs: - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name - kwargs.config_name = kwargs.config_name or default_config_name + else: + default_config_name = get_top_ranked_config_name( + region=kwargs.region, + model_id=kwargs.model_id, + model_version=kwargs.model_version, + sagemaker_session=temp_session, + scope=JumpStartScriptScope.INFERENCE, + model_type=kwargs.model_type, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + hub_arn=kwargs.hub_arn, + ) + + kwargs.config_name = kwargs.config_name or default_config_name return kwargs @@ -839,16 +869,16 @@ def get_deploy_kwargs( routing_config=routing_config, ) - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) + + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_config_name_to_deploy_kwargs( - kwargs=deploy_kwargs, training_config_name=training_config_name - ) - deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -1030,11 +1060,14 @@ def get_init_kwargs( ) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs=model_init_kwargs + ) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( @@ -1062,8 +1095,6 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) return model_init_kwargs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f521dbcc5a..3b31be0d2d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1109,11 +1109,16 @@ def get_jumpstart_configs( def get_jumpstart_user_agent_extra_suffix( - model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool] + model_id: Optional[str], + model_version: Optional[str], + config_name: Optional[str], + is_hub_content: Optional[bool], ) -> str: """Returns the model-specific user agent string to be added to requests.""" sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" + config_specific_suffix = f"md/js_config#{config_name}" + print(config_name) hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): @@ -1128,19 +1133,66 @@ def get_jumpstart_user_agent_extra_suffix( else: headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" + if config_name: + headers = f"{headers} {config_specific_suffix}" + return headers +def get_top_ranked_config_name( + region: str, + model_id: str, + model_version: str, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + tolerate_deprecated_model: bool = False, + tolerate_vulnerable_model: bool = False, + hub_arn: Optional[str] = None, +) -> Optional[str]: + """Returns the top ranked config name for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=scope, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + model_type=model_type, + ) + + if scope == enums.JumpStartScriptScope.INFERENCE: + return ( + model_specs.inference_configs.get_top_config_from_ranking().config_name + if model_specs.inference_configs + else None + ) + if scope == enums.JumpStartScriptScope.TRAINING: + return ( + model_specs.training_configs.get_top_config_from_ranking().config_name + if model_specs.training_configs + else None + ) + raise ValueError(f"Unsupported script scope: {scope}.") + + def get_default_jumpstart_session_with_user_agent_suffix( model_id: Optional[str] = None, model_version: Optional[str] = None, + config_name: Optional[str] = None, is_hub_content: Optional[bool] = False, ) -> Session: """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( user_agent_extra=get_jumpstart_user_agent_extra_suffix( - model_id, model_version, is_hub_content + model_id, model_version, config_name, is_hub_content ), ) botocore_session.set_default_client_config(botocore_config) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5ee0abd41f..032052ea0e 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -396,3 +396,20 @@ def test_jumpstart_model_with_deployment_configs(setup): response = predictor.predict(payload, custom_attributes="accept_eula=true") assert response is not None + + +def test_jumpstart_session_with_config_name(): + model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b", model_version="*") + assert model.config_name != None + session = model.sagemaker_session + + with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request: + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 3678685db5..fbf76d1c98 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1637,7 +1637,7 @@ def test_training_passes_role_to_deploy( @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix", - sagemaker_session, + lambda *largs, **kwargs: sagemaker_session, ) @mock.patch( "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix", diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 07c49a308c..231efbbbcf 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1890,20 +1890,24 @@ class TestUserAgent: def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): mock_getenv.return_value = False assert utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "False" + "some-id", "some-version", None, "False" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = None assert utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "False" + "some-id", "some-version", None, "False" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = "True" assert not utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "True" + "some-id", "some-version", None, "True" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") mock_getenv.return_value = True assert not utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "True" + "some-id", "some-version", None, "True" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + mock_getenv.return_value = False + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "some-config", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_config#some-config") @patch("sagemaker.jumpstart.utils.botocore.session") @patch("sagemaker.jumpstart.utils.botocore.config.Config") @@ -1923,7 +1927,7 @@ def test_get_default_jumpstart_session_with_user_agent_suffix( utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") mock_boto3_session.get_session.assert_called_once_with() mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version", False + "model_id", "model_version", None, False ) mock_botocore_config.assert_called_once_with( user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value From 49a1c82fb3077e73523e120b83e1747e679d11bf Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Jul 2024 13:29:34 +0000 Subject: [PATCH 2/9] chore: minor fixes --- src/sagemaker/jumpstart/factory/estimator.py | 8 ++++++-- src/sagemaker/jumpstart/utils.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 4da190ac1e..d6c26b0429 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -205,7 +205,9 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_sagemaker_session_with_user_agent_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + estimator_init_kwargs + ) estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) @@ -439,7 +441,9 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: return kwargs -def _add_sagemaker_session_with_user_agent_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs: JumpStartKwargs, +) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" kwargs.sagemaker_session = ( kwargs.sagemaker_session diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3b31be0d2d..e298609857 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1118,7 +1118,6 @@ def get_jumpstart_user_agent_extra_suffix( sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" config_specific_suffix = f"md/js_config#{config_name}" - print(config_name) hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): From 3fc60f28d3a422d195fb6e6982992e760395aee1 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Jul 2024 13:39:50 +0000 Subject: [PATCH 3/9] chore: address minor issues --- src/sagemaker/jumpstart/factory/model.py | 54 ++++++++++--------- .../jumpstart/model/test_jumpstart_model.py | 2 +- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 27087ecba1..08a4adbf92 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -246,6 +246,32 @@ def _add_instance_type_to_kwargs( kwargs.instance_type, ) + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + + if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: + return kwargs + + resolved_config = ( + specs.inference_configs.configs[kwargs.config_name].resolved_config + if specs.inference_configs + else None + ) + if resolved_config is None: + return kwargs + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) + return kwargs @@ -683,28 +709,6 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta if kwargs.config_name is None: return kwargs - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=temp_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ) - - resolved_config = ( - specs.inference_configs.configs[kwargs.config_name].resolved_config - if specs.inference_configs - else None - ) - if resolved_config is None: - return kwargs - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) - if kwargs.instance_type not in supported_instance_types: - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) return kwargs @@ -873,10 +877,10 @@ def get_deploy_kwargs( kwargs=deploy_kwargs, training_config_name=training_config_name ) - deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) @@ -1060,8 +1064,8 @@ def get_init_kwargs( ) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( kwargs=model_init_kwargs diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 032052ea0e..610ac80e3b 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -399,7 +399,7 @@ def test_jumpstart_model_with_deployment_configs(setup): def test_jumpstart_session_with_config_name(): - model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b", model_version="*") + model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b") assert model.config_name != None session = model.sagemaker_session From df711e3d0970a7ea798d2d5fad2f18dcbde6012f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Jul 2024 14:10:12 +0000 Subject: [PATCH 4/9] fix: flake8 --- tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 610ac80e3b..a7693709dd 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -400,7 +400,7 @@ def test_jumpstart_model_with_deployment_configs(setup): def test_jumpstart_session_with_config_name(): model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b") - assert model.config_name != None + assert model.config_name is not None session = model.sagemaker_session with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request: From 8f886fe6e3f415d7dd1807a0800c61e1760c7962 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Jul 2024 15:28:14 +0000 Subject: [PATCH 5/9] fix: model type for estimator --- src/sagemaker/jumpstart/factory/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d6c26b0429..31f62aefb1 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -917,7 +917,7 @@ def _add_config_name_to_kwargs( model_version=kwargs.model_version, sagemaker_session=kwargs.sagemaker_session, scope=JumpStartScriptScope.TRAINING, - model_type=kwargs.model_type, + model_type=kwargs.model_type[0], tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, hub_arn=kwargs.hub_arn, From e4fa7b0f1c2c3b5bf6d1bc667a82f1b757372aa7 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 30 Jul 2024 14:14:37 +0000 Subject: [PATCH 6/9] chore: add ranking name argument to get_top_ranked_config_name --- src/sagemaker/jumpstart/exceptions.py | 2 -- src/sagemaker/jumpstart/utils.py | 9 +++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 5941cbad2e..742a6b8d3f 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -150,7 +150,6 @@ def __init__( model. (Default: None). """ - version = version or "*" if message: self.message = message else: @@ -199,7 +198,6 @@ def __init__( version: Optional[str] = None, message: Optional[str] = None, ): - version = version or "*" if message: self.message = message else: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index e298609857..bc67649c87 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1148,6 +1148,7 @@ def get_top_ranked_config_name( tolerate_deprecated_model: bool = False, tolerate_vulnerable_model: bool = False, hub_arn: Optional[str] = None, + ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT, ) -> Optional[str]: """Returns the top ranked config name for the given model ID and region. @@ -1168,13 +1169,17 @@ def get_top_ranked_config_name( if scope == enums.JumpStartScriptScope.INFERENCE: return ( - model_specs.inference_configs.get_top_config_from_ranking().config_name + model_specs.inference_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name if model_specs.inference_configs else None ) if scope == enums.JumpStartScriptScope.TRAINING: return ( - model_specs.training_configs.get_top_config_from_ranking().config_name + model_specs.training_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name if model_specs.training_configs else None ) From 12f8f95acfff54e18334a22931bfaf770a062c68 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 31 Jul 2024 16:26:30 +0000 Subject: [PATCH 7/9] chore: use named args --- src/sagemaker/jumpstart/factory/estimator.py | 2 +- src/sagemaker/jumpstart/factory/model.py | 5 ++++- src/sagemaker/jumpstart/utils.py | 5 ++++- tests/unit/sagemaker/jumpstart/test_utils.py | 5 ++++- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 31f62aefb1..d6c26b0429 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -917,7 +917,7 @@ def _add_config_name_to_kwargs( model_version=kwargs.model_version, sagemaker_session=kwargs.sagemaker_session, scope=JumpStartScriptScope.TRAINING, - model_type=kwargs.model_type[0], + model_type=kwargs.model_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, hub_arn=kwargs.hub_arn, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 08a4adbf92..117b9e1854 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -165,7 +165,10 @@ def _add_sagemaker_session_with_custom_user_agent_to_kwargs( kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version, kwargs.config_name, kwargs.hub_arn + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=kwargs.config_name, + is_hub_content=kwargs.hub_arn is not None, ) ) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index bc67649c87..3d36aabb3a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1196,7 +1196,10 @@ def get_default_jumpstart_session_with_user_agent_suffix( botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( user_agent_extra=get_jumpstart_user_agent_extra_suffix( - model_id, model_version, config_name, is_hub_content + model_id=model_id, + model_version=model_version, + config_name=config_name, + is_hub_content=is_hub_content, ), ) botocore_session.set_default_client_config(botocore_config) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 231efbbbcf..cbf918dee8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1927,7 +1927,10 @@ def test_get_default_jumpstart_session_with_user_agent_suffix( utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") mock_boto3_session.get_session.assert_called_once_with() mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version", None, False + model_id="model_id", + model_version="model_version", + config_name=None, + is_hub_content=False, ) mock_botocore_config.assert_called_once_with( user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value From 3534b7956d95dd87ff91c41670426ca8c971693f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 5 Aug 2024 13:27:51 +0000 Subject: [PATCH 8/9] fix: remove tuple from model type --- src/sagemaker/jumpstart/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index ae54bc72b8..68d9c282d7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2417,7 +2417,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn - self.model_type = (model_type,) + self.model_type = model_type self.instance_type = instance_type self.instance_count = instance_count self.region = region From fdaa7fd55acd23c47f57dbe88056df1f3dbde929 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 5 Aug 2024 18:11:20 +0000 Subject: [PATCH 9/9] chore: add comment explaining test --- tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index a7693709dd..7733041579 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -403,6 +403,9 @@ def test_jumpstart_session_with_config_name(): assert model.config_name is not None session = model.sagemaker_session + # we're mocking the http request, so it's expected to raise an Exception. + # we're interested that the low-level request attaches the correct + # jumpstart-related tags. with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request: try: session.sagemaker_client.list_endpoints()