Skip to content

Commit e7e8a54

Browse files
committed
Move all custom Exceptions to exceptions.py
1 parent 4856e22 commit e7e8a54

29 files changed

+101
-98
lines changed

pymc/backends/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@
4545
logger = logging.getLogger(__name__)
4646

4747

48-
class BackendError(Exception):
49-
pass
50-
51-
5248
class IBaseTrace(ABC, Sized):
5349
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""
5450

pymc/data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import pymc as pm
3838

39+
from pymc.exceptions import ShapeError
3940
from pymc.pytensorf import convert_observed_data
4041

4142
__all__ = [
@@ -237,7 +238,7 @@ def determine_coords(
237238

238239
if isinstance(value, np.ndarray) and dims is not None:
239240
if len(dims) != value.ndim:
240-
raise pm.exceptions.ShapeError(
241+
raise ShapeError(
241242
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
242243
actual=value.shape,
243244
expected=value.ndim,
@@ -445,7 +446,7 @@ def Data(
445446
if isinstance(dims, str):
446447
dims = (dims,)
447448
if not (dims is None or len(dims) == x.ndim):
448-
raise pm.exceptions.ShapeError(
449+
raise ShapeError(
449450
"Length of `dims` must match the dimensions of the dataset.",
450451
actual=len(dims),
451452
expected=x.ndim,

pymc/exceptions.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__all__ = [
16-
"SamplingError",
17-
"ImputationWarning",
18-
"ShapeWarning",
19-
"ShapeError",
20-
]
21-
2215

2316
class SamplingError(RuntimeError):
2417
pass
@@ -74,3 +67,55 @@ class NotConstantValueError(ValueError):
7467

7568
class BlockModelAccessError(RuntimeError):
7669
pass
70+
71+
72+
class ParallelSamplingError(Exception):
73+
def __init__(self, message, chain):
74+
super().__init__(message)
75+
self._chain = chain
76+
77+
78+
class RemoteTraceback(Exception):
79+
def __init__(self, tb):
80+
self.tb = tb
81+
82+
def __str__(self):
83+
return self.tb
84+
85+
86+
class VariationalInferenceError(Exception):
87+
"""Exception for VI specific cases"""
88+
89+
90+
class NotImplementedInference(VariationalInferenceError, NotImplementedError):
91+
"""Marking non functional parts of code"""
92+
93+
94+
class ExplicitInferenceError(VariationalInferenceError, TypeError):
95+
"""Exception for bad explicit inference"""
96+
97+
98+
class ParametrizationError(VariationalInferenceError, ValueError):
99+
"""Error raised in case of bad parametrization"""
100+
101+
102+
class GroupError(VariationalInferenceError, TypeError):
103+
"""Error related to VI groups"""
104+
105+
106+
class IntegrationError(RuntimeError):
107+
pass
108+
109+
110+
class PositiveDefiniteError(ValueError):
111+
def __init__(self, msg, idx):
112+
super().__init__(msg)
113+
self.idx = idx
114+
self.msg = msg
115+
116+
def __str__(self):
117+
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."
118+
119+
120+
class ParameterValueError(ValueError):
121+
"""Exception for invalid parameters values in logprob graphs"""

pymc/logprob/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from pytensor.tensor.random.op import RandomVariable
6464
from pytensor.tensor.variable import TensorVariable
6565

66+
from pymc.exceptions import ParameterValueError
6667
from pymc.logprob.abstract import MeasurableVariable, _logprob
6768
from pymc.util import makeiter
6869

@@ -231,10 +232,6 @@ def check_potential_measurability(
231232
return False
232233

233234

234-
class ParameterValueError(ValueError):
235-
"""Exception for invalid parameters values in logprob graphs"""
236-
237-
238235
class CheckParameterValue(CheckAndRaise):
239236
"""Implements a parameter value check in a logprob graph.
240237

pymc/model/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@
5959
from pymc.exceptions import (
6060
BlockModelAccessError,
6161
ImputationWarning,
62+
ParameterValueError,
6263
SamplingError,
6364
ShapeError,
6465
ShapeWarning,
6566
)
6667
from pymc.initial_point import make_initial_point_fn
6768
from pymc.logprob.basic import transformed_conditional_logp
68-
from pymc.logprob.utils import ParameterValueError
6969
from pymc.model_graph import model_to_graphviz
7070
from pymc.pytensorf import (
7171
PointFunc,

pymc/sampling/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
)
5757
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
5858
from pymc.blocking import DictToArrayBijection
59-
from pymc.exceptions import SamplingError
59+
from pymc.exceptions import ParallelSamplingError, SamplingError
6060
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
6161
from pymc.model import Model, modelcontext
6262
from pymc.sampling.parallel import Draw, _cpu_count
@@ -1199,7 +1199,7 @@ def _mp_sample(
11991199
if callback is not None:
12001200
callback(trace=strace, draw=draw)
12011201

1202-
except ps.ParallelSamplingError as error:
1202+
except ParallelSamplingError as error:
12031203
strace = traces[error._chain]
12041204
for strace in traces:
12051205
strace.close()

pymc/sampling/parallel.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,13 @@
2929
from fastprogress.fastprogress import progress_bar
3030

3131
from pymc.blocking import DictToArrayBijection
32-
from pymc.exceptions import SamplingError
32+
from pymc.exceptions import ParallelSamplingError, RemoteTraceback, SamplingError
3333
from pymc.util import RandomSeed
3434

3535
logger = logging.getLogger(__name__)
3636

3737

38-
class ParallelSamplingError(Exception):
39-
def __init__(self, message, chain):
40-
super().__init__(message)
41-
self._chain = chain
42-
43-
4438
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
45-
class RemoteTraceback(Exception):
46-
def __init__(self, tb):
47-
self.tb = tb
48-
49-
def __str__(self):
50-
return self.tb
5139

5240

5341
class ExceptionWithTraceback:

pymc/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
import numpy as np
2424

2525
from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
26-
from pymc.exceptions import SamplingError
26+
from pymc.exceptions import IntegrationError, SamplingError
2727
from pymc.model import Point, modelcontext
2828
from pymc.pytensorf import floatX
2929
from pymc.stats.convergence import SamplerWarning, WarningType
3030
from pymc.step_methods import step_sizes
3131
from pymc.step_methods.arraystep import GradientSharedStep
3232
from pymc.step_methods.hmc import integration
33-
from pymc.step_methods.hmc.integration import IntegrationError, State
33+
from pymc.step_methods.hmc.integration import State
3434
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3535
from pymc.tuning import guess_scaling
3636
from pymc.util import get_value_vars_from_user_vars

pymc/step_methods/hmc/hmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818

1919
import numpy as np
2020

21+
from pymc.exceptions import IntegrationError
2122
from pymc.stats.convergence import SamplerWarning
2223
from pymc.step_methods.compound import Competence
2324
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
24-
from pymc.step_methods.hmc.integration import IntegrationError, State
25+
from pymc.step_methods.hmc.integration import State
2526
from pymc.vartypes import discrete_types
2627

2728
__all__ = ["HamiltonianMC"]

pymc/step_methods/hmc/integration.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy import linalg
2020

2121
from pymc.blocking import RaveledVars
22+
from pymc.exceptions import IntegrationError
2223
from pymc.step_methods.hmc.quadpotential import QuadPotential
2324

2425

@@ -32,10 +33,6 @@ class State(NamedTuple):
3233
index_in_trajectory: int
3334

3435

35-
class IntegrationError(RuntimeError):
36-
pass
37-
38-
3936
class CpuLeapfrogIntegrator:
4037
def __init__(self, potential: QuadPotential, logp_dlogp_func):
4138
"""Leapfrog integrator using CPU."""

0 commit comments

Comments
 (0)