Skip to content

Commit 1697443

Browse files
sfc-gh-anavalosSnowflake Authors
andauthored
Project import generated by Copybara. (#178)
GitOrigin-RevId: f10ddb99e98b5c15b4b34ebcd31a289218c43a3c Co-authored-by: Snowflake Authors <[email protected]>
1 parent d03c59c commit 1697443

34 files changed

+834
-224
lines changed

.bazelrc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12
3131
build:build --config=_build
3232

3333
# Config to sync files
34-
run:pre_build --config=_build --config=py3.9
34+
run:pre_build --config=_build --config=py3.10
3535

3636
# Config to run type check
37-
build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9
37+
build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10
3838

3939
# Config to build the doc
40-
build:docs --config=_all --config=py3.9
40+
build:docs --config=_all --config=py3.10
4141

4242
# Public the extended setting
4343

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# Release History
22

3+
## 1.14.0
4+
5+
### Bug Fixes
6+
7+
### Behavior Changes
8+
9+
### New Features
10+
11+
* ML Job: The `additional_payloads` argument is now **deprecated** in favor of `imports`.
12+
313
## 1.13.0
414

515
### Bug Fixes

bazel/environments/fetch_conda_env_config.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ load("//bazel/platforms:optional_dependency_groups.bzl", "OPTIONAL_DEPENDENCY_GR
33
def _fetch_conda_env_config_impl(rctx):
44
# read the particular environment variable we are interested in
55
env_name = rctx.os.environ.get("BAZEL_CONDA_ENV_NAME", "core").lower()
6-
python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.9").lower()
6+
python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.10").lower()
77

88
# necessary to create empty BUILD file for this rule
99
# which will be located somewhere in the Bazel build files

bazel/requirements/templates/bazelrc.tpl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12
2828
build:build --config=_build
2929

3030
# Config to sync files
31-
run:pre_build --config=_build --config=py3.9
31+
run:pre_build --config=_build --config=py3.10
3232

3333
# Config to run type check
34-
build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9
34+
build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10
3535

3636
# Config to build the doc
37-
build:docs --config=_all --config=py3.9
37+
build:docs --config=_all --config=py3.10
3838

3939
# Public the extended setting
4040

ci/build_and_run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ WITH_SNOWPARK=false
4242
WITH_SPCS_IMAGE=false
4343
RUN_GRYPE=false
4444
MODE="continuous_run"
45-
PYTHON_VERSION=3.9
45+
PYTHON_VERSION=3.10
4646
PYTHON_ENABLE_SCRIPT="bin/activate"
4747
SNOWML_DIR="snowml"
4848
SNOWPARK_DIR="snowpark-python"

ci/conda_recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ build:
1717
noarch: python
1818
package:
1919
name: snowflake-ml-python
20-
version: 1.13.0
20+
version: 1.14.0
2121
requirements:
2222
build:
2323
- python

snowflake/ml/jobs/_utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
2626
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
2727
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
28-
DEFAULT_IMAGE_TAG = "1.6.2"
28+
DEFAULT_IMAGE_TAG = "1.8.0"
2929
DEFAULT_ENTRYPOINT_PATH = "func.py"
3030

3131
# Percent of container memory to allocate for /dev/shm volume

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
234234
if payload_dir and payload_dir not in sys.path:
235235
sys.path.insert(0, payload_dir)
236236

237-
# Create a Snowpark session before running the script
238-
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
239-
config = SnowflakeLoginOptions()
240-
config["client_session_keep_alive"] = "True"
241-
session = Session.builder.configs(config).create() # noqa: F841
242-
243237
try:
244238

245239
if main_func:
@@ -266,7 +260,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
266260
finally:
267261
# Restore original sys.argv
268262
sys.argv = original_argv
269-
session.close()
270263

271264

272265
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
@@ -297,6 +290,12 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
297290
except ModuleNotFoundError:
298291
warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)
299292

293+
# Create a Snowpark session before starting
294+
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
295+
config = SnowflakeLoginOptions()
296+
config["client_session_keep_alive"] = "True"
297+
session = Session.builder.configs(config).create() # noqa: F841
298+
300299
try:
301300
# Wait for minimum required instances if specified
302301
min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
@@ -352,6 +351,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
352351
f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
353352
)
354353

354+
# Close the session after serializing the result
355+
session.close()
356+
355357

356358
if __name__ == "__main__":
357359
# Parse command line arguments

snowflake/ml/jobs/job.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def _service_spec(self) -> dict[str, Any]:
8383
def _container_spec(self) -> dict[str, Any]:
8484
"""Get the job's main container spec."""
8585
containers = self._service_spec["spec"]["containers"]
86+
if len(containers) == 1:
87+
return cast(dict[str, Any], containers[0])
8688
try:
8789
container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
8890
except StopIteration:
@@ -163,7 +165,7 @@ def get_logs(
163165
Returns:
164166
The job's execution logs.
165167
"""
166-
logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
168+
logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
167169
assert isinstance(logs, str) # mypy
168170
if as_list:
169171
return logs.splitlines()
@@ -281,7 +283,12 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
281283

282284
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
283285
def _get_logs(
284-
session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
286+
session: snowpark.Session,
287+
job_id: str,
288+
limit: int = -1,
289+
instance_id: Optional[int] = None,
290+
container_name: str = constants.DEFAULT_CONTAINER_NAME,
291+
verbose: bool = True,
285292
) -> str:
286293
"""
287294
Retrieve the job's execution logs.
@@ -291,6 +298,7 @@ def _get_logs(
291298
limit: The maximum number of lines to return. Negative values are treated as no limit.
292299
session: The Snowpark session to use. If none specified, uses active session.
293300
instance_id: Optional instance ID to get logs from a specific instance.
301+
container_name: The container name to get logs from a specific container.
294302
verbose: Whether to return the full log or just the portion between START and END messages.
295303
296304
Returns:
@@ -311,7 +319,7 @@ def _get_logs(
311319
params: list[Any] = [
312320
job_id,
313321
0 if instance_id is None else instance_id,
314-
constants.DEFAULT_CONTAINER_NAME,
322+
container_name,
315323
]
316324
if limit > 0:
317325
params.append(limit)
@@ -337,7 +345,7 @@ def _get_logs(
337345
job_id,
338346
limit=limit,
339347
instance_id=instance_id if instance_id else 0,
340-
container_name=constants.DEFAULT_CONTAINER_NAME,
348+
container_name=container_name,
341349
)
342350
full_log = os.linesep.join(row[0] for row in logs)
343351

snowflake/ml/jobs/jobs_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from snowflake.snowpark import exceptions as sp_exceptions
1111
from snowflake.snowpark.row import Row
1212

13+
SERVICE_SPEC = """
14+
spec:
15+
containers:
16+
- name: main
17+
image: test-image
18+
"""
19+
1320

1421
class JobTest(parameterized.TestCase):
1522
@parameterized.named_parameters( # type: ignore[misc]
@@ -83,7 +90,7 @@ def test_get_logs_negative(self) -> None:
8390

8491
def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any:
8592
if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"):
86-
return [Row(target_instances=2)]
93+
return [Row(target_instances=2, spec=SERVICE_SPEC)]
8794
else:
8895
raise sp_exceptions.SnowparkSQLException("Waiting to start, Container Status: PENDING")
8996

@@ -97,7 +104,7 @@ def test_get_logs_from_event_table(self) -> None:
97104
def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any:
98105
if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"):
99106
return [
100-
Row(target_instances=2),
107+
Row(target_instances=2, spec=SERVICE_SPEC),
101108
]
102109
elif query_str.startswith("SELECT VALUE FROM "):
103110
return [

0 commit comments

Comments
 (0)