|
8 | 8 | from ats.tests import AtsTest |
9 | 9 | from lxml import etree |
10 | 10 | import logging |
11 | | -from .test_steps import geos |
| 11 | +from .test_steps import geos, pygeos_test |
12 | 12 | from .test_case import TestCase |
13 | 13 |
|
14 | 14 | test_build_failures = [] |
15 | 15 | logger = logging.getLogger( 'geos-ats' ) |
16 | 16 |
|
| 17 | +has_pygeos = True |
| 18 | +try: |
| 19 | + import pygeos |
| 20 | +except ImportError: |
| 21 | + logger.warning( 'pygeos is not available on this system' ) |
| 22 | + has_pygeos = False |
| 23 | + |
17 | 24 |
|
18 | 25 | @dataclass( frozen=True ) |
19 | 26 | class RestartcheckParameters: |
@@ -43,6 +50,7 @@ class TestDeck: |
43 | 50 | partitions: Iterable[ Tuple[ int, int, int ] ] |
44 | 51 | restart_step: int |
45 | 52 | check_step: int |
| 53 | + pygeos_script: str = '' |
46 | 54 | restartcheck_params: RestartcheckParameters = None |
47 | 55 | curvecheck_params: CurveCheckParameters = None |
48 | 56 |
|
@@ -121,33 +129,38 @@ def generate_geos_tests( decks: Iterable[ TestDeck ], test_type='smoke' ): |
121 | 129 | if curvecheck_params: |
122 | 130 | checks.append( 'curve' ) |
123 | 131 |
|
124 | | - steps = [ |
125 | | - geos( deck=xml_file, |
126 | | - name=base_name, |
127 | | - np=N, |
128 | | - ngpu=N, |
129 | | - x_partitions=nx, |
130 | | - y_partitions=ny, |
131 | | - z_partitions=nz, |
132 | | - restartcheck_params=restartcheck_params, |
133 | | - curvecheck_params=curvecheck_params ) |
134 | | - ] |
| 132 | + # Setup model inputs |
| 133 | + model_type = geos |
| 134 | + model_kwargs = { |
| 135 | + 'deck': xml_file, |
| 136 | + 'name': base_name, |
| 137 | + 'np': N, |
| 138 | + 'ngpu': N, |
| 139 | + 'x_partitions': nx, |
| 140 | + 'y_partitions': ny, |
| 141 | + 'z_partitions': nz, |
| 142 | + 'restartcheck_params': restartcheck_params, |
| 143 | + 'curvecheck_params': curvecheck_params |
| 144 | + } |
| 145 | + |
| 146 | + if deck.pygeos_script: |
| 147 | + if has_pygeos: |
| 148 | + model_type = pygeos_test |
| 149 | + model_kwargs[ 'script' ] = deck.pygeos_script |
| 150 | + else: |
| 151 | + logger.warning( f'Skipping test that requires pygeos: {deck.name}' ) |
| 152 | + continue |
| 153 | + |
| 154 | + steps = [ model_type( **model_kwargs ) ] |
135 | 155 |
|
136 | 156 | if deck.restart_step > 0: |
137 | 157 | checks.append( 'restart' ) |
138 | | - steps.append( |
139 | | - geos( deck=xml_file, |
140 | | - name="{:d}to{:d}".format( deck.restart_step, deck.check_step ), |
141 | | - np=N, |
142 | | - ngpu=N, |
143 | | - x_partitions=nx, |
144 | | - y_partitions=ny, |
145 | | - z_partitions=nz, |
146 | | - restart_file=os.path.join( testcase_name, |
147 | | - "{}_restart_{:09d}".format( base_name, deck.restart_step ) ), |
148 | | - baseline_pattern=f"{base_name}_restart_[0-9]+\.root", |
149 | | - allow_rebaseline=False, |
150 | | - restartcheck_params=restartcheck_params ) ) |
| 158 | + model_kwargs[ 'name' ] = "{:d}to{:d}".format( deck.restart_step, deck.check_step ) |
| 159 | + model_kwargs[ 'restart_file' ] = os.path.join( |
| 160 | + testcase_name, "{}_restart_{:09d}".format( base_name, deck.restart_step ) ) |
| 161 | + model_kwargs[ 'baseline_pattern' ] = f"{base_name}_restart_[0-9]+\.root" |
| 162 | + model_kwargs[ 'allow_rebaseline' ] = False |
| 163 | + steps.append( model_type( **model_kwargs ) ) |
151 | 164 |
|
152 | 165 | AtsTest.stick( level=ii ) |
153 | 166 | AtsTest.stick( checks=','.join( checks ) ) |
|
0 commit comments