1010# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- import sys
1413import ast
14+ import sys
1515
1616from sagemaker_algorithm_toolkit import exceptions as exc
1717
1818
1919class Hyperparameter (object ):
2020 """Represents a single SageMaker training job hyperparameter."""
21- def __init__ (self ,
22- name ,
23- range = None ,
24- dependencies = None ,
25- required = None , default = None ,
26- tunable = False , tunable_recommended_range = None ):
21+
22+ def __init__ (
23+ self ,
24+ name ,
25+ range = None ,
26+ dependencies = None ,
27+ required = None ,
28+ default = None ,
29+ tunable = False ,
30+ tunable_recommended_range = None ,
31+ ):
2732 if required is None and default is None :
2833 raise exc .AlgorithmError ("At least one of 'required' or 'default' must be specified." )
2934
@@ -98,12 +103,9 @@ def format_tunable_range(self):
98103
99104 min_ , max_ = self .tunable_recommended_range .format_as_integer ()
100105 scale = self .tunable_recommended_range .scale
101- return {"IntegerParameterRanges" : [{
102- "MinValue" : min_ ,
103- "MaxValue" : max_ ,
104- "Name" : self .name ,
105- "ScalingType" : scale
106- }]}
106+ return {
107+ "IntegerParameterRanges" : [{"MinValue" : min_ , "MaxValue" : max_ , "Name" : self .name , "ScalingType" : scale }]
108+ }
107109
108110
109111class CategoricalHyperparameter (Hyperparameter ):
@@ -122,15 +124,16 @@ def _format_range_helper(self, range_):
122124 return range_ .format ()
123125
124126 def format_range (self ):
125- return {"CategoricalParameterRangeSpecification" : {
126- "Values" : self ._format_range_helper (self .range )}}
127+ return {"CategoricalParameterRangeSpecification" : {"Values" : self ._format_range_helper (self .range )}}
127128
128129 def format_tunable_range (self ):
129130 if not self .tunable or self .tunable_recommended_range is None :
130131 return None
131- return {"CategoricalParameterRanges" : [
132- {"Name" : self .name ,
133- "Values" : self ._format_range_helper (self .tunable_recommended_range )}]}
132+ return {
133+ "CategoricalParameterRanges" : [
134+ {"Name" : self .name , "Values" : self ._format_range_helper (self .tunable_recommended_range )}
135+ ]
136+ }
134137
135138
136139class ContinuousHyperparameter (Hyperparameter ):
@@ -156,11 +159,9 @@ def format_tunable_range(self):
156159
157160 min_ , max_ = self .tunable_recommended_range .format_as_continuous ()
158161 scale = self .tunable_recommended_range .scale
159- return {"ContinuousParameterRanges" : [
160- {"Name" : self .name ,
161- "MinValue" : min_ ,
162- "MaxValue" : max_ ,
163- "ScalingType" : scale }]}
162+ return {
163+ "ContinuousParameterRanges" : [{"Name" : self .name , "MinValue" : min_ , "MaxValue" : max_ , "ScalingType" : scale }]
164+ }
164165
165166
166167class CommaSeparatedListHyperparameter (Hyperparameter ):
@@ -191,9 +192,7 @@ def parse(self, value):
191192
192193 def format_range (self ):
193194 min_ , max_ = self .range .format_as_integer ()
194- return {"NestedParameterRangeSpecification" : {
195- "MinValue" : min_ ,
196- "MaxValue" : max_ }}
195+ return {"NestedParameterRangeSpecification" : {"MinValue" : min_ , "MaxValue" : max_ }}
197196
198197 def validate_range (self , value ):
199198 if any ([element not in self .range for outer in value for element in outer ]):
@@ -213,8 +212,7 @@ def parse(self, value):
213212 return value
214213
215214 def format_range (self ):
216- return {"TupleParameterRangeSpecification" : {
217- "Values" : self .range }}
215+ return {"TupleParameterRangeSpecification" : {"Values" : self .range }}
218216
219217 def validate_range (self , value ):
220218 if any ([element not in self .range for element in value ]):
@@ -290,8 +288,9 @@ def validate(self, user_hyperparameters):
290288 except exc .UserError :
291289 raise
292290 except Exception as e :
293- raise exc .AlgorithmError ("Hyperparameter {}: unexpected failure when validating {}" .format (hp , value ),
294- caused_by = e )
291+ raise exc .AlgorithmError (
292+ "Hyperparameter {}: unexpected failure when validating {}" .format (hp , value ), caused_by = e
293+ )
295294
296295 # NOTE: 4. Validate dependencies.
297296 sorted_deps = self ._sort_dependencies (converted_hyperparameters .keys ())
@@ -300,9 +299,11 @@ def validate(self, user_hyperparameters):
300299 hp = sorted_deps .pop ()
301300 value = converted_hyperparameters [hp ]
302301 if self .hyperparameters [hp ].dependencies :
303- dependencies = {hp_d : new_validated_hyperparameters [hp_d ]
304- for hp_d in self .hyperparameters [hp ].dependencies
305- if hp_d in new_validated_hyperparameters }
302+ dependencies = {
303+ hp_d : new_validated_hyperparameters [hp_d ]
304+ for hp_d in self .hyperparameters [hp ].dependencies
305+ if hp_d in new_validated_hyperparameters
306+ }
306307 self .hyperparameters [hp ].validate_dependencies (value , dependencies )
307308 new_validated_hyperparameters [hp ] = value
308309
@@ -312,12 +313,12 @@ def __getitem__(self, name):
312313 return self .hyperparameters [name ]
313314
314315 def format (self ):
315- return [hyperparameter .format ()
316- for name , hyperparameter in self .hyperparameters .items ()]
316+ return [hyperparameter .format () for name , hyperparameter in self .hyperparameters .items ()]
317317
318318
319319class Range :
320320 """Abstract interface for Hyperparameter.range objects."""
321+
321322 def __contains__ (self , value ):
322323 raise NotImplementedError
323324
@@ -363,30 +364,34 @@ def __str__(self):
363364
364365 def __contains__ (self , value ):
365366 return not (
366- (self .min_open is not None and value <= self .min_open ) or
367- (self .min_closed is not None and value < self .min_closed ) or
368- (self .max_open is not None and value >= self .max_open ) or
369- (self .max_closed is not None and value > self .max_closed ))
367+ (self .min_open is not None and value <= self .min_open )
368+ or (self .min_closed is not None and value < self .min_closed )
369+ or (self .max_open is not None and value >= self .max_open )
370+ or (self .max_closed is not None and value > self .max_closed )
371+ )
370372
371373 def _format_range_value (self , open_ , closed , default ):
372- return str (open_ if open_ is not None
373- else closed if closed is not None
374- else default )
374+ return str (open_ if open_ is not None else closed if closed is not None else default )
375375
376376 def format_as_integer (self ):
377- max_neg_signed_int = - 2 ** 31
377+ max_neg_signed_int = - ( 2 ** 31 )
378378 max_signed_int = 2 ** 31 - 1
379- return (self ._format_range_value (self .min_open , self .min_closed , max_neg_signed_int ),
380- self ._format_range_value (self .max_open , self .max_closed , max_signed_int ))
379+ return (
380+ self ._format_range_value (self .min_open , self .min_closed , max_neg_signed_int ),
381+ self ._format_range_value (self .max_open , self .max_closed , max_signed_int ),
382+ )
381383
382384 def format_as_continuous (self ):
383385 max_float = sys .float_info .max
384- return (self ._format_range_value (self .min_open , self .min_closed , - max_float ),
385- self ._format_range_value (self .max_open , self .max_closed , max_float ))
386+ return (
387+ self ._format_range_value (self .min_open , self .min_closed , - max_float ),
388+ self ._format_range_value (self .max_open , self .max_closed , max_float ),
389+ )
386390
387391
388392class range_validator :
389393 """Function decorator helper to override hyperparameter's range validation."""
394+
390395 def __init__ (self , range ):
391396 self .range = range
392397
@@ -400,11 +405,13 @@ def __str__(self_):
400405
401406 def __contains__ (self_ , value ):
402407 return f (self .range , value )
408+
403409 return inner ()
404410
405411
406412class dependencies_validator :
407413 """Function decorator helper to override hyperparameter's dependency validation."""
414+
408415 def __init__ (self , dependencies ):
409416 self .dependencies = dependencies
410417
@@ -421,4 +428,5 @@ def __next__(self):
421428
422429 def __call__ (self , value , dependencies ):
423430 return f (value , dependencies )
431+
424432 return inner ()
0 commit comments