Skip to content

Commit 13a194f

Browse files
mabundayMark Bunday
andauthored
Apply and enable black + isort through tox (#324)
Co-authored-by: Mark Bunday <[email protected]>
1 parent ba6cd4c commit 13a194f

File tree

71 files changed

+2639
-2094
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+2639
-2094
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[tool.isort]
2+
profile = "black"

setup.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,42 @@
11
from __future__ import absolute_import
2+
23
import os
34
from glob import glob
45
from os.path import basename, splitext
5-
from setuptools import setup, find_packages
6+
7+
from setuptools import find_packages, setup
68

79

810
def read(fname):
911
return open(os.path.join(os.path.dirname(__file__), fname)).read()
1012

1113

1214
setup(
13-
name='sagemaker_xgboost_container',
14-
version='2.0',
15-
description='Open source library for creating XGBoost containers to run on Amazon SageMaker.',
16-
17-
packages=find_packages(where='src', exclude=('test',)),
18-
package_dir={'': 'src'},
19-
py_modules=[splitext(basename(path))[0] for path in glob('src/*.py')],
20-
21-
long_description=read('README.rst'),
22-
author='Amazon Web Services',
23-
license='Apache License 2.0',
24-
15+
name="sagemaker_xgboost_container",
16+
version="2.0",
17+
description="Open source library for creating XGBoost containers to run on Amazon SageMaker.",
18+
packages=find_packages(where="src", exclude=("test",)),
19+
package_dir={"": "src"},
20+
py_modules=[splitext(basename(path))[0] for path in glob("src/*.py")],
21+
long_description=read("README.rst"),
22+
author="Amazon Web Services",
23+
license="Apache License 2.0",
2524
classifiers=[
2625
"Development Status :: 5 - Production/Stable",
2726
"Intended Audience :: Developers",
2827
"Natural Language :: English",
2928
"License :: OSI Approved :: Apache Software License",
3029
"Programming Language :: Python",
31-
'Programming Language :: Python :: 3.6',
32-
'Programming Language :: Python :: 3.7',
33-
'Programming Language :: Python :: 3.8',
30+
"Programming Language :: Python :: 3.6",
31+
"Programming Language :: Python :: 3.7",
32+
"Programming Language :: Python :: 3.8",
3433
],
35-
3634
install_requires=read("requirements.txt"),
37-
38-
extras_require={
39-
'test': read("test-requirements.txt")
40-
},
41-
35+
extras_require={"test": read("test-requirements.txt")},
4236
entry_points={
4337
"console_scripts": [
4438
"serve=sagemaker_xgboost_container.serving:serving_entrypoint",
4539
]
4640
},
47-
48-
python_requires='>=3.6',
41+
python_requires=">=3.6",
4942
)

src/sagemaker_algorithm_toolkit/channel_validation.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
from sagemaker_algorithm_toolkit import exceptions as exc
1414

15-
1615
CONTENT_TYPE = "ContentType"
1716
TRAINING_INPUT_MODE = "TrainingInputMode"
1817
S3_DIST_TYPE = "S3DistributionType"
1918

2019

2120
class Channel(object):
2221
"""Represents a single SageMaker training job channel."""
22+
2323
FILE_MODE = "File"
2424
PIPE_MODE = "Pipe"
2525
AUGMENTED_MODE = "Augmented"
@@ -36,12 +36,13 @@ def format(self):
3636
"""Format channel for SageMaker's CreateAlgorithm API."""
3737
supported_content_types = list(set(c[0] for c in self.supported))
3838
supported_input_modes = list(set(c[1] for c in self.supported))
39-
return {"Name": self.name,
40-
"Description": self.name,
41-
"IsRequired": self.required,
42-
"SupportedContentTypes": supported_content_types,
43-
"SupportedInputModes": supported_input_modes,
44-
}
39+
return {
40+
"Name": self.name,
41+
"Description": self.name,
42+
"IsRequired": self.required,
43+
"SupportedContentTypes": supported_content_types,
44+
"SupportedInputModes": supported_input_modes,
45+
}
4546

4647
def add(self, content_type, supported_input_mode, supported_s3_data_distribution_type):
4748
"""Add relevant configuration as a supported configuration for the channel."""

src/sagemaker_algorithm_toolkit/exceptions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ class BaseToolkitError(Exception):
3939
non-BaseToolkitError.
4040
"""
4141

42-
def __init__(self,
43-
message=None,
44-
caused_by=None):
42+
def __init__(self, message=None, caused_by=None):
4543
formatted_message = BaseToolkitError._format_exception_message(message, caused_by)
4644
super(BaseToolkitError, self).__init__(formatted_message)
4745
self.message = formatted_message
@@ -63,7 +61,7 @@ def _format_exception_message(message, caused_by):
6361
elif caused_by:
6462
with warnings.catch_warnings():
6563
warnings.simplefilter("ignore") # Suppress deprecation warning
66-
formatted_message = getattr(caused_by, 'message', str(caused_by))
64+
formatted_message = getattr(caused_by, "message", str(caused_by))
6765
else:
6866
formatted_message = "unknown error occurred"
6967

src/sagemaker_algorithm_toolkit/hyperparameter_validation.py

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,25 @@
1010
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
import sys
1413
import ast
14+
import sys
1515

1616
from sagemaker_algorithm_toolkit import exceptions as exc
1717

1818

1919
class Hyperparameter(object):
2020
"""Represents a single SageMaker training job hyperparameter."""
21-
def __init__(self,
22-
name,
23-
range=None,
24-
dependencies=None,
25-
required=None, default=None,
26-
tunable=False, tunable_recommended_range=None):
21+
22+
def __init__(
23+
self,
24+
name,
25+
range=None,
26+
dependencies=None,
27+
required=None,
28+
default=None,
29+
tunable=False,
30+
tunable_recommended_range=None,
31+
):
2732
if required is None and default is None:
2833
raise exc.AlgorithmError("At least one of 'required' or 'default' must be specified.")
2934

@@ -98,12 +103,9 @@ def format_tunable_range(self):
98103

99104
min_, max_ = self.tunable_recommended_range.format_as_integer()
100105
scale = self.tunable_recommended_range.scale
101-
return {"IntegerParameterRanges": [{
102-
"MinValue": min_,
103-
"MaxValue": max_,
104-
"Name": self.name,
105-
"ScalingType": scale
106-
}]}
106+
return {
107+
"IntegerParameterRanges": [{"MinValue": min_, "MaxValue": max_, "Name": self.name, "ScalingType": scale}]
108+
}
107109

108110

109111
class CategoricalHyperparameter(Hyperparameter):
@@ -122,15 +124,16 @@ def _format_range_helper(self, range_):
122124
return range_.format()
123125

124126
def format_range(self):
125-
return {"CategoricalParameterRangeSpecification": {
126-
"Values": self._format_range_helper(self.range)}}
127+
return {"CategoricalParameterRangeSpecification": {"Values": self._format_range_helper(self.range)}}
127128

128129
def format_tunable_range(self):
129130
if not self.tunable or self.tunable_recommended_range is None:
130131
return None
131-
return {"CategoricalParameterRanges": [
132-
{"Name": self.name,
133-
"Values": self._format_range_helper(self.tunable_recommended_range)}]}
132+
return {
133+
"CategoricalParameterRanges": [
134+
{"Name": self.name, "Values": self._format_range_helper(self.tunable_recommended_range)}
135+
]
136+
}
134137

135138

136139
class ContinuousHyperparameter(Hyperparameter):
@@ -156,11 +159,9 @@ def format_tunable_range(self):
156159

157160
min_, max_ = self.tunable_recommended_range.format_as_continuous()
158161
scale = self.tunable_recommended_range.scale
159-
return {"ContinuousParameterRanges": [
160-
{"Name": self.name,
161-
"MinValue": min_,
162-
"MaxValue": max_,
163-
"ScalingType": scale}]}
162+
return {
163+
"ContinuousParameterRanges": [{"Name": self.name, "MinValue": min_, "MaxValue": max_, "ScalingType": scale}]
164+
}
164165

165166

166167
class CommaSeparatedListHyperparameter(Hyperparameter):
@@ -191,9 +192,7 @@ def parse(self, value):
191192

192193
def format_range(self):
193194
min_, max_ = self.range.format_as_integer()
194-
return {"NestedParameterRangeSpecification": {
195-
"MinValue": min_,
196-
"MaxValue": max_}}
195+
return {"NestedParameterRangeSpecification": {"MinValue": min_, "MaxValue": max_}}
197196

198197
def validate_range(self, value):
199198
if any([element not in self.range for outer in value for element in outer]):
@@ -213,8 +212,7 @@ def parse(self, value):
213212
return value
214213

215214
def format_range(self):
216-
return {"TupleParameterRangeSpecification": {
217-
"Values": self.range}}
215+
return {"TupleParameterRangeSpecification": {"Values": self.range}}
218216

219217
def validate_range(self, value):
220218
if any([element not in self.range for element in value]):
@@ -290,8 +288,9 @@ def validate(self, user_hyperparameters):
290288
except exc.UserError:
291289
raise
292290
except Exception as e:
293-
raise exc.AlgorithmError("Hyperparameter {}: unexpected failure when validating {}".format(hp, value),
294-
caused_by=e)
291+
raise exc.AlgorithmError(
292+
"Hyperparameter {}: unexpected failure when validating {}".format(hp, value), caused_by=e
293+
)
295294

296295
# NOTE: 4. Validate dependencies.
297296
sorted_deps = self._sort_dependencies(converted_hyperparameters.keys())
@@ -300,9 +299,11 @@ def validate(self, user_hyperparameters):
300299
hp = sorted_deps.pop()
301300
value = converted_hyperparameters[hp]
302301
if self.hyperparameters[hp].dependencies:
303-
dependencies = {hp_d: new_validated_hyperparameters[hp_d]
304-
for hp_d in self.hyperparameters[hp].dependencies
305-
if hp_d in new_validated_hyperparameters}
302+
dependencies = {
303+
hp_d: new_validated_hyperparameters[hp_d]
304+
for hp_d in self.hyperparameters[hp].dependencies
305+
if hp_d in new_validated_hyperparameters
306+
}
306307
self.hyperparameters[hp].validate_dependencies(value, dependencies)
307308
new_validated_hyperparameters[hp] = value
308309

@@ -312,12 +313,12 @@ def __getitem__(self, name):
312313
return self.hyperparameters[name]
313314

314315
def format(self):
315-
return [hyperparameter.format()
316-
for name, hyperparameter in self.hyperparameters.items()]
316+
return [hyperparameter.format() for name, hyperparameter in self.hyperparameters.items()]
317317

318318

319319
class Range:
320320
"""Abstract interface for Hyperparameter.range objects."""
321+
321322
def __contains__(self, value):
322323
raise NotImplementedError
323324

@@ -363,30 +364,34 @@ def __str__(self):
363364

364365
def __contains__(self, value):
365366
return not (
366-
(self.min_open is not None and value <= self.min_open) or
367-
(self.min_closed is not None and value < self.min_closed) or
368-
(self.max_open is not None and value >= self.max_open) or
369-
(self.max_closed is not None and value > self.max_closed))
367+
(self.min_open is not None and value <= self.min_open)
368+
or (self.min_closed is not None and value < self.min_closed)
369+
or (self.max_open is not None and value >= self.max_open)
370+
or (self.max_closed is not None and value > self.max_closed)
371+
)
370372

371373
def _format_range_value(self, open_, closed, default):
372-
return str(open_ if open_ is not None
373-
else closed if closed is not None
374-
else default)
374+
return str(open_ if open_ is not None else closed if closed is not None else default)
375375

376376
def format_as_integer(self):
377-
max_neg_signed_int = -2 ** 31
377+
max_neg_signed_int = -(2 ** 31)
378378
max_signed_int = 2 ** 31 - 1
379-
return (self._format_range_value(self.min_open, self.min_closed, max_neg_signed_int),
380-
self._format_range_value(self.max_open, self.max_closed, max_signed_int))
379+
return (
380+
self._format_range_value(self.min_open, self.min_closed, max_neg_signed_int),
381+
self._format_range_value(self.max_open, self.max_closed, max_signed_int),
382+
)
381383

382384
def format_as_continuous(self):
383385
max_float = sys.float_info.max
384-
return (self._format_range_value(self.min_open, self.min_closed, -max_float),
385-
self._format_range_value(self.max_open, self.max_closed, max_float))
386+
return (
387+
self._format_range_value(self.min_open, self.min_closed, -max_float),
388+
self._format_range_value(self.max_open, self.max_closed, max_float),
389+
)
386390

387391

388392
class range_validator:
389393
"""Function decorator helper to override hyperparameter's range validation."""
394+
390395
def __init__(self, range):
391396
self.range = range
392397

@@ -400,11 +405,13 @@ def __str__(self_):
400405

401406
def __contains__(self_, value):
402407
return f(self.range, value)
408+
403409
return inner()
404410

405411

406412
class dependencies_validator:
407413
"""Function decorator helper to override hyperparameter's dependency validation."""
414+
408415
def __init__(self, dependencies):
409416
self.dependencies = dependencies
410417

@@ -421,4 +428,5 @@ def __next__(self):
421428

422429
def __call__(self, value, dependencies):
423430
return f(value, dependencies)
431+
424432
return inner()

0 commit comments

Comments
 (0)