Skip to content

Commit 367c2ff

Browse files
authored
Merge pull request #3570 from radhakrishnatg/custom-block
Add a default rule for custom blocks
2 parents 8dfd55b + de281d4 commit 367c2ff

File tree

12 files changed

+373
-71
lines changed

12 files changed

+373
-71
lines changed

pyomo/core/base/block.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313
import copy
14+
import functools
1415
import logging
1516
import sys
1617
import weakref
@@ -2120,7 +2121,6 @@ def __init__(self, *args, **kwargs):
21202121
# initializer
21212122
self._dense = kwargs.pop('dense', True)
21222123
kwargs.setdefault('ctype', Block)
2123-
ActiveIndexedComponent.__init__(self, *args, **kwargs)
21242124
if _options is not None:
21252125
deprecation_warning(
21262126
"The Block 'options=' keyword is deprecated. "
@@ -2129,19 +2129,10 @@ def __init__(self, *args, **kwargs):
21292129
"the function arguments",
21302130
version='5.7.2',
21312131
)
2132-
if self.is_indexed():
2133-
2134-
def rule_wrapper(model, *_idx):
2135-
return _rule(model, *_idx, **_options)
2136-
2137-
else:
2138-
2139-
def rule_wrapper(model):
2140-
return _rule(model, **_options)
2141-
2142-
self._rule = Initializer(rule_wrapper)
2132+
self._rule = Initializer(functools.partial(_rule, **_options))
21432133
else:
21442134
self._rule = Initializer(_rule)
2135+
ActiveIndexedComponent.__init__(self, *args, **kwargs)
21452136
if _concrete:
21462137
# Call self.construct() as opposed to just setting the _constructed
21472138
# flag so that the base class construction procedure fires (this
@@ -2426,6 +2417,7 @@ class CustomBlock(Block):
24262417
def __init__(self, *args, **kwargs):
24272418
if self._default_ctype is not None:
24282419
kwargs.setdefault('ctype', self._default_ctype)
2420+
kwargs.setdefault("rule", getattr(self, '_default_rule', None))
24292421
Block.__init__(self, *args, **kwargs)
24302422

24312423
def __new__(cls, *args, **kwargs):
@@ -2446,13 +2438,56 @@ def __new__(cls, *args, **kwargs):
24462438
return super().__new__(cls._indexed_custom_block, *args, **kwargs)
24472439

24482440

2449-
def declare_custom_block(name, new_ctype=None):
2441+
class _custom_block_rule_redirect(object):
2442+
"""Functor to redirect the default rule to a BlockData method"""
2443+
2444+
def __init__(self, cls, name):
2445+
self.cls = cls
2446+
self.name = name
2447+
2448+
def __call__(self, block, *args, **kwargs):
2449+
return getattr(self.cls, self.name)(block, *args, **kwargs)
2450+
2451+
2452+
def declare_custom_block(name, new_ctype=None, rule=None):
24502453
"""Decorator to declare components for a custom block data class
24512454
2455+
This decorator simplifies the definition of custom derived Block
2456+
classes. With this decorator, developers must only implement the
2457+
derived "Data" class. The decorator automatically creates the
2458+
derived containers using the provided name, and adds them to the
2459+
current module:
2460+
24522461
>>> @declare_custom_block(name="FooBlock")
24532462
... class FooBlockData(BlockData):
2454-
... # custom block data class
24552463
... pass
2464+
2465+
>>> s = FooBlock()
2466+
>>> type(s)
2467+
<class 'ScalarFooBlock'>
2468+
2469+
>>> s = FooBlock([1,2])
2470+
>>> type(s)
2471+
<class 'IndexedFooBlock'>
2472+
2473+
It is frequently desirable for the custom class to have a default
2474+
``rule`` for constructing and populating new instances. The default
2475+
rule can be provided either as an explicit function or a string. If
2476+
a string, the rule is obtained by attribute lookup on the derived
2477+
Data class:
2478+
2479+
>>> @declare_custom_block(name="BarBlock", rule="build")
2480+
... class BarBlockData(BlockData):
2481+
... def build(self, *args):
2482+
... self.x = Var(initialize=5)
2483+
2484+
>>> m = pyo.ConcreteModel()
2485+
>>> m.b = BarBlock([1,2])
2486+
>>> print(m.b[1].x.value)
2487+
5
2488+
>>> print(m.b[2].x.value)
2489+
5
2490+
24562491
"""
24572492

24582493
def block_data_decorator(block_data):
@@ -2476,9 +2511,16 @@ def block_data_decorator(block_data):
24762511
"_ComponentDataClass": block_data,
24772512
# By default this new block does not declare a new ctype
24782513
"_default_ctype": None,
2514+
# Define the default rule (may be None)
2515+
"_default_rule": rule,
24792516
},
24802517
)
24812518

2519+
# If the default rule is a string, then replace it with a
2520+
# function that will look up the attribute on the data class.
2521+
if type(rule) is str:
2522+
comp._default_rule = _custom_block_rule_redirect(block_data, rule)
2523+
24822524
if new_ctype is not None:
24832525
if new_ctype is True:
24842526
comp._default_ctype = comp

pyomo/core/base/component.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pyomo.core.pyomoobject import PyomoObject
3333
from pyomo.core.base.component_namer import name_repr, index_repr
3434
from pyomo.core.base.global_set import UnindexedComponent_index
35+
from pyomo.core.base.initializer import PartialInitializer
3536

3637
logger = logging.getLogger('pyomo.core')
3738

@@ -451,10 +452,15 @@ def __init__(self, **kwds):
451452
self.doc = kwds.pop('doc', None)
452453
self._name = kwds.pop('name', None)
453454
if kwds:
454-
raise ValueError(
455-
"Unexpected keyword options found while constructing '%s':\n\t%s"
456-
% (type(self).__name__, ','.join(sorted(kwds.keys())))
457-
)
455+
# If there are leftover keywords, and the component has a
456+
# rule, pass those keywords on to the rule
457+
if getattr(self, '_rule', None) is not None:
458+
self._rule = PartialInitializer(self._rule, **kwds)
459+
else:
460+
raise ValueError(
461+
"Unexpected keyword options found while constructing '%s':\n\t%s"
462+
% (type(self).__name__, ','.join(sorted(kwds.keys())))
463+
)
458464
#
459465
# Verify that ctype has been specified.
460466
#

pyomo/core/base/constraint.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pyomo.common.pyomo_typing import overload
1717
from typing import Union, Type
1818

19-
from pyomo.common.deprecation import RenamedClass
19+
from pyomo.common.deprecation import RenamedClass, deprecated
2020
from pyomo.common.errors import DeveloperError, TemplateExpressionError
2121
from pyomo.common.formatting import tabular_writer
2222
from pyomo.common.log import is_debug_set
@@ -545,7 +545,7 @@ def __init__(self, template_info, component, index):
545545
def expr(self):
546546
# Note that it is faster to just generate the expression from
547547
# scratch than it is to clone it and replace the IndexTemplate objects
548-
self.set_value(self.parent_component().rule(self.parent_block(), self.index()))
548+
self.set_value(self.parent_component()._rule(self.parent_block(), self.index()))
549549
return self.expr
550550

551551
def template_expr(self):
@@ -640,9 +640,9 @@ def __init__(self, *args, **kwargs):
640640
_init = self._pop_from_kwargs('Constraint', kwargs, ('rule', 'expr'), None)
641641
# Special case: we accept 2- and 3-tuples as constraints
642642
if type(_init) is tuple:
643-
self.rule = Initializer(_init, treat_sequences_as_mappings=False)
643+
self._rule = Initializer(_init, treat_sequences_as_mappings=False)
644644
else:
645-
self.rule = Initializer(_init)
645+
self._rule = Initializer(_init)
646646

647647
kwargs.setdefault('ctype', Constraint)
648648
ActiveIndexedComponent.__init__(self, *args, **kwargs)
@@ -663,7 +663,7 @@ def construct(self, data=None):
663663
for _set in self._anonymous_sets:
664664
_set.construct()
665665

666-
rule = self.rule
666+
rule = self._rule
667667
try:
668668
# We do not (currently) accept data for constructing Constraints
669669
index = None
@@ -719,9 +719,9 @@ def construct(self, data=None):
719719
timer.report()
720720

721721
def _getitem_when_not_present(self, idx):
722-
if self.rule is None:
722+
if self._rule is None:
723723
raise KeyError(idx)
724-
con = self._setitem_when_not_present(idx, self.rule(self.parent_block(), idx))
724+
con = self._setitem_when_not_present(idx, self._rule(self.parent_block(), idx))
725725
if con is None:
726726
raise KeyError(idx)
727727
return con
@@ -746,6 +746,20 @@ def _pprint(self):
746746
],
747747
)
748748

749+
@property
750+
def rule(self):
751+
return self._rule
752+
753+
@rule.setter
754+
@deprecated(
755+
f"The 'Constraint.rule' attribute will be made "
756+
"read-only in a future Pyomo release.",
757+
version='6.9.3.dev0',
758+
remove_in='6.11',
759+
)
760+
def rule(self, rule):
761+
self._rule = rule
762+
749763
def display(self, prefix="", ostream=None):
750764
"""
751765
Print component state information
@@ -971,14 +985,14 @@ def __init__(self, **kwargs):
971985

972986
super().__init__(Set(dimen=1), **kwargs)
973987

974-
self.rule = Initializer(
988+
self._rule = Initializer(
975989
_rule, treat_sequences_as_mappings=False, allow_generators=True
976990
)
977991
# HACK to make the "counted call" syntax work. We wait until
978992
# after the base class is set up so that is_indexed() is
979993
# reliable.
980-
if self.rule is not None and type(self.rule) is IndexedCallInitializer:
981-
self.rule = CountedCallInitializer(self, self.rule, self._starting_index)
994+
if self._rule is not None and type(self._rule) is IndexedCallInitializer:
995+
self._rule = CountedCallInitializer(self, self._rule, self._starting_index)
982996

983997
def construct(self, data=None):
984998
"""
@@ -995,8 +1009,8 @@ def construct(self, data=None):
9951009
for _set in self._anonymous_sets:
9961010
_set.construct()
9971011

998-
if self.rule is not None:
999-
_rule = self.rule(self.parent_block(), ())
1012+
if self._rule is not None:
1013+
_rule = self._rule(self.parent_block(), ())
10001014
for cc in iter(_rule):
10011015
if cc is ConstraintList.End:
10021016
break

pyomo/core/base/initializer.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -340,27 +340,27 @@ class IndexedCallInitializer(InitializerBase):
340340
def __init__(self, _fcn):
341341
self._fcn = _fcn
342342

343-
def __call__(self, parent, idx):
343+
def __call__(self, parent, idx, **kwargs):
344344
# Note: this is called by a component using data from a Set (so
345345
# any tuple-like type should have already been checked and
346346
# converted to a tuple; or flattening is turned off and it is
347347
# the user's responsibility to sort things out.
348348
if idx.__class__ is tuple:
349-
return self._fcn(parent, *idx)
349+
return self._fcn(parent, *idx, **kwargs)
350350
else:
351-
return self._fcn(parent, idx)
351+
return self._fcn(parent, idx, **kwargs)
352352

353353

354354
class ParameterizedIndexedCallInitializer(IndexedCallInitializer):
355355
"""IndexedCallInitializer that accepts additional arguments"""
356356

357357
__slots__ = ()
358358

359-
def __call__(self, parent, idx, *args):
359+
def __call__(self, parent, idx, *args, **kwargs):
360360
if idx.__class__ is tuple:
361-
return self._fcn(parent, *args, *idx)
361+
return self._fcn(parent, *args, *idx, **kwargs)
362362
else:
363-
return self._fcn(parent, *args, idx)
363+
return self._fcn(parent, *args, idx, **kwargs)
364364

365365

366366
class CountedCallGenerator(object):
@@ -481,8 +481,8 @@ def __init__(self, _fcn, constant=True):
481481
self._fcn = _fcn
482482
self._constant = constant
483483

484-
def __call__(self, parent, idx):
485-
return self._fcn(parent)
484+
def __call__(self, parent, idx, **kwargs):
485+
return self._fcn(parent, **kwargs)
486486

487487
def constant(self):
488488
"""Return True if this initializer is constant across all indices"""
@@ -494,8 +494,8 @@ class ParameterizedScalarCallInitializer(ScalarCallInitializer):
494494

495495
__slots__ = ()
496496

497-
def __call__(self, parent, idx, *args):
498-
return self._fcn(parent, *args)
497+
def __call__(self, parent, idx, *args, **kwargs):
498+
return self._fcn(parent, *args, **kwargs)
499499

500500

501501
class DefaultInitializer(InitializerBase):
@@ -523,9 +523,9 @@ def __init__(self, initializer, default, exceptions):
523523
self._default = default
524524
self._exceptions = exceptions
525525

526-
def __call__(self, parent, index):
526+
def __call__(self, parent, index, **kwargs):
527527
try:
528-
return self._initializer(parent, index)
528+
return self._initializer(parent, index, **kwargs)
529529
except self._exceptions:
530530
return self._default
531531

@@ -542,7 +542,7 @@ def indices(self):
542542

543543

544544
class ParameterizedInitializer(InitializerBase):
545-
"""Base class for all Initializer objects"""
545+
"""Wrapper to provide additional positional arguments to Initializer objects"""
546546

547547
__slots__ = ('_base_initializer',)
548548

@@ -565,8 +565,33 @@ def indices(self):
565565
"""
566566
return self._base_initializer.indices()
567567

568-
def __call__(self, parent, idx, *args):
569-
return self._base_initializer(parent, idx)(parent, *args)
568+
def __call__(self, parent, idx, *args, **kwargs):
569+
return self._base_initializer(parent, idx)(parent, *args, **kwargs)
570+
571+
572+
class PartialInitializer(InitializerBase):
573+
"""Partial wrapper of an InitializerBase that supplies additional arguments"""
574+
575+
__slots__ = ('_fcn',)
576+
577+
def __init__(self, _fcn, *args, **kwargs):
578+
self._fcn = functools.partial(_fcn, *args, **kwargs)
579+
580+
def constant(self):
581+
return self._fcn.func.constant()
582+
583+
def contains_indices(self):
584+
return self._fcn.func.contains_indices()
585+
586+
def indices(self):
587+
return self._fcn.func.indices()
588+
589+
def __call__(self, parent, idx, *args, **kwargs):
590+
# Note that the Initializer.__call__ API is different from the
591+
# rule API. As a result, we cannot just inherit from
592+
# IndexedCallInitializer and must instead implement our own
593+
# __call__ here.
594+
return self._fcn(parent, idx, *args, **kwargs)
570595

571596

572597
_bound_sequence_types = collections.defaultdict(None.__class__)
@@ -618,8 +643,8 @@ def __init__(self, arg, obj=NOTSET):
618643
arg, treat_sequences_as_mappings=treat_sequences_as_mappings
619644
)
620645

621-
def __call__(self, parent, index):
622-
val = self._initializer(parent, index)
646+
def __call__(self, parent, index, **kwargs):
647+
val = self._initializer(parent, index, **kwargs)
623648
if _bound_sequence_types[val.__class__]:
624649
return val
625650
if _bound_sequence_types[val.__class__] is None:

0 commit comments

Comments
 (0)