Skip to content

Commit d8091a9

Browse files
snowflake-provisionerSnowflake Authors
andauthored
Project import generated by Copybara. (#20)
GitOrigin-RevId: 7cece61f8ed84deeaabc0bf1cd91fc803117f627 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 8d57915 commit d8091a9

File tree

15 files changed

+554
-243
lines changed

15 files changed

+554
-243
lines changed

bazel/requirements/BUILD.bazel

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ _GENERATE_TOOL = ":parse_and_generate_requirements"
2121

2222
_GENERATE_COMMAND = "$(location " + _GENERATE_TOOL + ") $(location " + _SRC_REQUIREMENT_FILE + ") --schema $(location " + _SCHEMA_FILE + ") {options} > $@"
2323

24-
_TEMPLATE_FOLDER_PATH = "//bazel/requirements/templates"
25-
2624
_AUTOGEN_HEADERS = """# DO NOT EDIT!
2725
# Generate by running 'bazel run //bazel/requirements:sync_requirements'
2826
"""
2927

28+
# "---" is a document start marker, which is legit but optional (https://yaml.org/spec/1.1/#c-document-start). This
29+
# is needed for conda meta.yaml to bypass some bug from conda side.
30+
_YAML_START_DOCUMENT_MARKER = "---"
31+
3032
_GENERATED_REQUIREMENTS_FILES = {
3133
"requirements_txt": {
3234
"cmd": "--mode dev_version --format text",
@@ -77,7 +79,7 @@ _GENERATED_REQUIREMENTS_FILES = {
7779
"{generated}.body".format(generated = value["generated"]),
7880
],
7981
outs = [value["generated"]],
80-
cmd = "(echo -e \""+ _AUTOGEN_HEADERS +"\" ; cat $(location :{generated}.body) ) > $@".format(
82+
cmd = "(echo -e \"" + _AUTOGEN_HEADERS + "\" ; cat $(location :{generated}.body) ) > $@".format(
8183
generated = value["generated"],
8284
),
8385
tools = [_GENERATE_TOOL],
@@ -99,15 +101,24 @@ genrule(
99101
)
100102

101103
yq(
102-
name = "gen_conda_meta",
104+
name = "gen_conda_meta_body_format",
103105
srcs = [
104106
":meta.body.yaml",
105-
"{template_folder}:meta.tpl.yaml".format(template_folder = _TEMPLATE_FOLDER_PATH),
107+
"//bazel/requirements/templates:meta.tpl.yaml",
106108
],
107-
outs = ["meta.yaml"],
109+
outs = ["meta.body.formatted.yaml"],
108110
expression = ". as $item ireduce ({}; . * $item ) | sort_keys(..)",
109111
)
110112

113+
genrule(
114+
name = "gen_conda_meta",
115+
srcs = [
116+
":meta.body.formatted.yaml",
117+
],
118+
outs = ["meta.yaml"],
119+
cmd = "(echo -e \"" + _AUTOGEN_HEADERS + "\" ; echo \"" + _YAML_START_DOCUMENT_MARKER + "\"; cat $(location :meta.body.formatted.yaml) ) > $@",
120+
)
121+
111122
# Create a test target for each file that Bazel should
112123
# write to the source tree.
113124
[

bazel/requirements/templates/meta.tpl.yaml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
# DO NOT EDIT!
2-
# Generated by //bazel/requirements:gen_conda_meta
3-
# To update, run:
4-
# bazel run //bazel/requirements:sync_requirements
5-
#
6-
71
package:
82
name: snowflake-ml-python
93

ci/conda_recipe/meta.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# DO NOT EDIT!
2-
# Generated by //bazel/requirements:gen_conda_meta
3-
# To update, run:
4-
# bazel run //bazel/requirements:sync_requirements
5-
#
2+
# Generate by running 'bazel run //bazel/requirements:sync_requirements'
3+
4+
---
65
about:
76
description: |
87
Snowflake ML client Library is used for interacting with Snowflake to build machine learning solutions.

ci/get_excluded_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# The missing dependency could happen when a new operator is being developed, but not yet released.
1313

1414
set -o pipefail
15-
set -eu
15+
set -u
1616

1717
echo "Running "$0
1818

snowflake/ml/modeling/impute/simple_imputer.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from snowflake.snowpark import functions as F, types as T
1616
from snowflake.snowpark._internal import utils as snowpark_utils
1717

18+
_SUBPROJECT = "Impute"
19+
1820
STRATEGY_TO_STATE_DICT = {
1921
"constant": None,
2022
"mean": _utils.NumericStatistics.MEAN,
@@ -194,10 +196,7 @@ def check_type_consistency(col_types: Dict[str, T.DataType]) -> None:
194196

195197
return input_col_datatypes
196198

197-
@telemetry.send_api_usage_telemetry(
198-
project=base.PROJECT,
199-
subproject=base.SUBPROJECT,
200-
)
199+
@telemetry.send_api_usage_telemetry(project=base.PROJECT, subproject=_SUBPROJECT)
201200
def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
202201
"""
203202
Compute values to impute for the dataset according to the strategy.
@@ -214,7 +213,7 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
214213
input_col_datatypes = self._get_dataset_input_col_datatypes(dataset)
215214

216215
self.statistics_: Dict[str, Any] = {}
217-
statement_params = telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__)
216+
statement_params = telemetry.get_statement_params(base.PROJECT, _SUBPROJECT, self.__class__.__name__)
218217

219218
if self.strategy == "constant":
220219
if self.fill_value is None:
@@ -274,14 +273,8 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
274273
self._is_fitted = True
275274
return self
276275

277-
@telemetry.send_api_usage_telemetry(
278-
project=base.PROJECT,
279-
subproject=base.SUBPROJECT,
280-
)
281-
@telemetry.add_stmt_params_to_df(
282-
project=base.PROJECT,
283-
subproject=base.SUBPROJECT,
284-
)
276+
@telemetry.send_api_usage_telemetry(project=base.PROJECT, subproject=_SUBPROJECT)
277+
@telemetry.add_stmt_params_to_df(project=base.PROJECT, subproject=_SUBPROJECT)
285278
def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]:
286279
"""
287280
Transform the input dataset by imputing the computed statistics in the input columns.

snowflake/ml/modeling/metrics/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ py_library(
1313
"precision_recall_fscore_support.py",
1414
"precision_score.py",
1515
"regression.py",
16+
"roc_curve.py",
1617
],
1718
deps = [
1819
":init",

snowflake/ml/modeling/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .covariance import covariance
55
from .precision_recall_fscore_support import precision_recall_fscore_support
66
from .precision_score import precision_score
7+
from .roc_curve import roc_curve
78

89
__all__ = [
910
"accuracy_score",
@@ -12,4 +13,5 @@
1213
"covariance",
1314
"precision_recall_fscore_support",
1415
"precision_score",
16+
"roc_curve",
1517
]

snowflake/ml/modeling/metrics/precision_recall_fscore_support.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,18 @@ def precision_recall_fscore_support(
115115

116116
session = df._session
117117
assert session is not None
118-
query = df.queries["queries"][-1]
119118
sproc_name = f"precision_recall_fscore_support_{snowpark_utils.generate_random_alphanumeric()}"
120119
statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
121120

121+
cols = []
122+
if isinstance(y_true_col_names, str):
123+
cols = [y_true_col_names, y_pred_col_names]
124+
elif isinstance(y_true_col_names, list):
125+
cols = y_true_col_names + y_pred_col_names # type:ignore[assignment, operator]
126+
if sample_weight_col_name:
127+
cols.append(sample_weight_col_name)
128+
query = df[cols].queries["queries"][-1]
129+
122130
@F.sproc( # type: ignore[misc]
123131
session=session,
124132
name=sproc_name,
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import cloudpickle
4+
import numpy.typing as npt
5+
from sklearn import metrics
6+
7+
from snowflake import snowpark
8+
from snowflake.ml._internal import telemetry
9+
from snowflake.snowpark import functions as F
10+
from snowflake.snowpark._internal import utils as snowpark_utils
11+
12+
_PROJECT = "ModelDevelopment"
13+
_SUBPROJECT = "Metrics"
14+
15+
16+
@telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
17+
def roc_curve(
18+
*,
19+
df: snowpark.DataFrame,
20+
y_true_col_name: str,
21+
y_score_col_name: str,
22+
pos_label: Optional[Union[str, int]] = None,
23+
sample_weight_col_name: Optional[str] = None,
24+
drop_intermediate: bool = True,
25+
) -> Tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]:
26+
"""
27+
Compute Receiver operating characteristic (ROC).
28+
29+
Note: this implementation is restricted to the binary classification task.
30+
31+
Args:
32+
df: Input dataframe.
33+
y_true_col_name: Column name representing true binary labels.
34+
If labels are not either {-1, 1} or {0, 1}, then pos_label should be
35+
explicitly given.
36+
y_score_col_name: Column name representing target scores, can either
37+
be probability estimates of the positive class, confidence values,
38+
or non-thresholded measure of decisions (as returned by
39+
"decision_function" on some classifiers).
40+
pos_label: The label of the positive class.
41+
When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
42+
``pos_label`` is set to 1, otherwise an error will be raised.
43+
sample_weight_col_name: Column name representing sample weights.
44+
drop_intermediate: Whether to drop some suboptimal thresholds which would
45+
not appear on a plotted ROC curve. This is useful in order to create
46+
lighter ROC curves.
47+
48+
Returns:
49+
fpr: ndarray of shape (>2,)
50+
Increasing false positive rates such that element i is the false
51+
positive rate of predictions with score >= `thresholds[i]`.
52+
tpr : ndarray of shape (>2,)
53+
Increasing true positive rates such that element `i` is the true
54+
positive rate of predictions with score >= `thresholds[i]`.
55+
thresholds : ndarray of shape = (n_thresholds,)
56+
Decreasing thresholds on the decision function used to compute
57+
fpr and tpr. `thresholds[0]` represents no instances being predicted
58+
and is arbitrarily set to `max(y_score) + 1`.
59+
"""
60+
session = df._session
61+
assert session is not None
62+
sproc_name = f"roc_curve_{snowpark_utils.generate_random_alphanumeric()}"
63+
statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
64+
65+
cols = [y_true_col_name, y_score_col_name]
66+
if sample_weight_col_name:
67+
cols.append(sample_weight_col_name)
68+
query = df[cols].queries["queries"][-1]
69+
70+
@F.sproc( # type: ignore[misc]
71+
session=session,
72+
name=sproc_name,
73+
replace=True,
74+
packages=["cloudpickle", "scikit-learn", "snowflake-snowpark-python"],
75+
statement_params=statement_params,
76+
)
77+
def roc_curve_sproc(session: snowpark.Session) -> bytes:
78+
df = session.sql(query).to_pandas(statement_params=statement_params)
79+
y_true = df[y_true_col_name]
80+
y_score = df[y_score_col_name]
81+
sample_weight = df[sample_weight_col_name] if sample_weight_col_name else None
82+
fpr, tpr, thresholds = metrics.roc_curve(
83+
y_true,
84+
y_score,
85+
pos_label=pos_label,
86+
sample_weight=sample_weight,
87+
drop_intermediate=drop_intermediate,
88+
)
89+
90+
return cloudpickle.dumps((fpr, tpr, thresholds)) # type: ignore[no-any-return]
91+
92+
loaded_data = cloudpickle.loads(session.call(sproc_name))
93+
res: Tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike] = loaded_data
94+
return res

tests/integ/snowflake/ml/modeling/framework/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class DataType(Enum):
144144

145145

146146
def gen_fuzz_data(
147-
rows: int, types: List[DataType], low: int = MIN_INT, high: int = MAX_INT
147+
rows: int, types: List[DataType], low: Union[int, List[int]] = MIN_INT, high: Union[int, List[int]] = MAX_INT
148148
) -> Tuple[List[Any], List[str]]:
149149
"""
150150
Generate random data based on input column types and row count.
@@ -153,8 +153,8 @@ def gen_fuzz_data(
153153
Args:
154154
rows: num of rows to generate
155155
types: type per column
156-
low: lower bound of the output interval (inclusive)
157-
high: upper bound of the output interval (exclusive)
156+
low: lower bound(s) of the output interval (inclusive)
157+
high: upper bound(s) of the output interval (exclusive)
158158
159159
Returns:
160160
A tuple of generated data and column names
@@ -166,10 +166,12 @@ def gen_fuzz_data(
166166
names = ["ID"]
167167

168168
for idx, t in enumerate(types):
169+
_low = low if isinstance(low, int) else low[idx]
170+
_high = high if isinstance(high, int) else high[idx]
169171
if t == DataType.INTEGER:
170-
data.append(np.random.randint(low, high, rows))
172+
data.append(np.random.randint(_low, _high, rows))
171173
elif t == DataType.FLOAT:
172-
data.append(np.random.uniform(low, high, rows))
174+
data.append(np.random.uniform(_low, _high, rows))
173175
else:
174176
raise ValueError(f"Unsupported data type {t}")
175177
names.append(f"COL_{idx}")

0 commit comments

Comments
 (0)