[WIP] Test suite to detect changes that break loading of models (Issue #458) #519

Draft · wants to merge 13 commits into base: dev

2 changes: 1 addition & 1 deletion .gitignore
@@ -13,7 +13,7 @@ docsrc/source/contributing.md
examples/checkpoints/
build
docs/
-
+_compatibility_data/

# mypy
.mypy_cache
88 changes: 88 additions & 0 deletions noxfile.py
@@ -0,0 +1,88 @@
import nox
import argparse
from pathlib import Path
import os
import tempfile
import shutil


def git_rev_parse(session, commit):
print(f"Converting provided commit '{commit}' to Git revision...")
rev = session.run("git", "rev-parse", commit, external=True, silent=True).strip()
return rev


@nox.session
def save_and_load(session: nox.Session):
"""Save models and outputs to disk and compare outputs between versions.

This session installs the bayesflow version specified by the `commit` argument, and runs the test suite either in
"save" or in "load" mode. In save mode, results are stored to disk and a within-version load test is performed.
In load mode, the stored models and outputs are loaded from disk, and old and new outputs are compared.
    This helps to detect changes that break serialization between versions.

Important: The test code from the current checkout, not from `commit`, is used.
"""
# parse the arguments
parser = argparse.ArgumentParser()
# add subparsers for the two different commands
    subparsers = parser.add_subparsers(help="subcommand help", dest="mode", required=True)
# save command
parser_save = subparsers.add_parser("save")
parser_save.add_argument("commit", type=str)
# load command, additional "from" argument
parser_load = subparsers.add_parser("load")
parser_load.add_argument("--from", type=str, required=True, dest="from_commit")
parser_load.add_argument("commit", type=str)

# keep unknown arguments, they will be forwarded to pytest below
args, unknownargs = parser.parse_known_args(session.posargs)

if args.mode == "load":
if args.from_commit == ".":
from_commit = "local"
else:
from_commit = git_rev_parse(session, args.from_commit)

from_path = Path("_compatibility_data").absolute() / from_commit
if not from_path.exists():
raise FileNotFoundError(
f"The directory {from_path} does not exist, cannot load data.\n"
f"Please run 'nox -- save {args.from_commit}' to create it, and then rerun this command."
)

print(f"Data will be loaded from path {from_path}.")

    # Install dependencies. Currently the jax backend is used, but we could add a configuration option for this.
repo_path = Path(os.curdir).absolute()
if args.commit == ".":
print("'.' provided, installing local state...")
if args.mode == "save":
print("Output will be saved to the alias 'local'")
commit = "local"
session.install(".[test]")
else:
commit = git_rev_parse(session, args.commit)
print("Installing specified revision...")
session.install(f"bayesflow[test] @ git+file://{str(repo_path)}@{commit}")
session.install("jax")

with tempfile.TemporaryDirectory() as tmpdirname:
        # launch in the temporary directory, as the local bayesflow package would shadow the installed one
tmpdirname = Path(tmpdirname)
# pass mode and data path to pytest, required for correct save and load behavior
if args.mode == "load":
data_path = from_path
else:
data_path = Path("_compatibility_data").absolute() / commit
if data_path.exists():
print(f"Removing existing data directory {data_path}...")
shutil.rmtree(data_path)

cmd = ["pytest", "tests/test_compatibility", f"--mode={args.mode}", f"--data-path={data_path}"]
cmd += unknownargs

print(f"Copying tests from working directory to temporary directory: {tmpdirname}")
shutil.copytree("tests", tmpdirname / "tests")
with session.chdir(tmpdirname):
session.run(*cmd, env={"KERAS_BACKEND": "jax"})
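
For reference, the session would typically be invoked through nox from the repository root. A possible workflow, with <ref> as a placeholder for any Git revision (an invocation sketch based on the argument parsing above, not taken from the diff):

    nox -s save_and_load -- save <ref>
    nox -s save_and_load -- load --from <ref> .

The first command installs <ref>, runs the suite in save mode, and stores reference outputs under _compatibility_data/. The second installs the local checkout ('.') and compares its outputs against those stored for <ref>.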
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ all = [
"sphinxcontrib-bibtex ~= 2.6",
"snowballstemmer ~= 2.2.0",
# test
"nox",
"pytest",
"pytest-cov",
"pytest-rerunfailures",
@@ -82,6 +83,7 @@ test = [
"nbconvert",
"ipython",
"ipykernel",
"nox",
"pytest",
"pytest-cov",
"pytest-rerunfailures",
35 changes: 20 additions & 15 deletions tests/conftest.py
@@ -6,6 +6,11 @@
BACKENDS = ["jax", "numpy", "tensorflow", "torch"]


+def pytest_addoption(parser):
+    parser.addoption("--mode", choices=["save", "load"])
+    parser.addoption("--data-path", type=str)


def pytest_runtest_setup(item):
"""Skips backends by test markers. Unmarked tests are treated as backend-agnostic"""
backend = keras.backend.backend()
@@ -41,42 +46,42 @@ def pytest_make_parametrize_id(config, val, argname):
return f"{argname}={repr(val)}"


-@pytest.fixture(params=[2], scope="session")
+@pytest.fixture(params=[2])
def batch_size(request):
return request.param


-@pytest.fixture(params=[None, 2, 3], scope="session")
+@pytest.fixture(params=[None, 2, 3])
def conditions_size(request):
return request.param


-@pytest.fixture(params=[1, 4], scope="session")
+@pytest.fixture(params=[1, 4])
def summary_dim(request):
return request.param


@pytest.fixture(params=["two_moons"], scope="session")
@pytest.fixture(params=["two_moons"])
def dataset(request):
return request.getfixturevalue(request.param)


-@pytest.fixture(params=[2, 3], scope="session")
+@pytest.fixture(params=[2, 3])
def feature_size(request):
return request.param


@pytest.fixture(scope="session")
def random_conditions(batch_size, conditions_size):
@pytest.fixture()
def random_conditions(random_seed, batch_size, conditions_size):
if conditions_size is None:
return None

-    return keras.random.normal((batch_size, conditions_size))
+    return keras.random.normal((batch_size, conditions_size), seed=10)


@pytest.fixture(scope="session")
def random_samples(batch_size, feature_size):
return keras.random.normal((batch_size, feature_size))
@pytest.fixture()
def random_samples(random_seed, batch_size, feature_size):
return keras.random.normal((batch_size, feature_size), seed=20)


@pytest.fixture(scope="function", autouse=True)
@@ -86,11 +91,11 @@ def random_seed():
return seed


@pytest.fixture(scope="session")
def random_set(batch_size, set_size, feature_size):
return keras.random.normal((batch_size, set_size, feature_size))
@pytest.fixture()
def random_set(random_seed, batch_size, set_size, feature_size):
return keras.random.normal((batch_size, set_size, feature_size), seed=30)


-@pytest.fixture(params=[2, 3], scope="session")
+@pytest.fixture(params=[2, 3])
def set_size(request):
return request.param
200 changes: 200 additions & 0 deletions tests/test_compatibility/conftest.py
@@ -0,0 +1,200 @@
import pytest
from pathlib import Path


@pytest.fixture(autouse=True, scope="session")
def mode(request):
mode = request.config.getoption("--mode")
if not mode:
return "save"
return mode


@pytest.fixture(autouse=True, scope="session")
def data_dir(request, tmp_path_factory):
# read config option to detect "unset" scenario
mode = request.config.getoption("--mode")
path = request.config.getoption("--data-path")
if not mode:
# if mode is unset, save and load from a temporary directory
return Path(tmp_path_factory.mktemp("_compatibility_data"))
elif not path:
pytest.exit(reason="Please provide the --data-path argument for model saving/loading.")
elif mode == "load":
path = Path(path)
if not path.exists():
pytest.exit(reason=f"Load path '{path}' does not exist. Please specify a valid load path", returncode=1)
return path
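
These options mirror the pytest command assembled in noxfile.py. For illustration, a direct invocation (outside of nox) might look like this, with a hypothetical data path:

    pytest tests/test_compatibility --mode=save --data-path=/tmp/compat_data
    pytest tests/test_compatibility --mode=load --data-path=/tmp/compat_data

If neither option is given, the fixtures above fall back to a pytest-managed temporary directory, so a plain pytest run performs the within-version save-and-load check.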


# reduce number of test configurations
@pytest.fixture(params=[None, 3])
def conditions_size(request):
return request.param


@pytest.fixture(params=[1, 2])
def summary_dim(request):
return request.param


@pytest.fixture(params=[4])
def feature_size(request):
return request.param


# Generic fixtures for use as input to the tested classes.
# The classes to test are constructed in the respective subdirectories, to allow for more thorough configuration.
@pytest.fixture(params=[None, "all"])
def standardize(request):
return request.param


@pytest.fixture()
def adapter(request):
import bayesflow as bf

match request.param:
case "summary":
return bf.Adapter.create_default("parameters").rename("observables", "summary_variables")
case "direct":
return bf.Adapter.create_default("parameters").rename("observables", "inference_conditions")
case "default":
return bf.Adapter.create_default("parameters")
case "empty":
return bf.Adapter()
case None:
return None
case _:
raise ValueError(f"Invalid request parameter for adapter: {request.param}")


@pytest.fixture(params=["coupling_flow", "flow_matching"])
def inference_network(request):
match request.param:
case "coupling_flow":
from bayesflow.networks import CouplingFlow

return CouplingFlow(depth=2)

case "flow_matching":
from bayesflow.networks import FlowMatching

return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False)

case None:
return None

case _:
raise ValueError(f"Invalid request parameter for inference_network: {request.param}")


@pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"])
def summary_network(request):
match request.param:
case "time_series_transformer":
from bayesflow.networks import TimeSeriesTransformer

return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(16, 8), mlp_depths=(1, 1))

case "fusion_transformer":
from bayesflow.networks import FusionTransformer

return FusionTransformer(
embed_dims=(8, 8), mlp_widths=(8, 16), mlp_depths=(2, 1), template_dim=8, bidirectional=False
)

case "time_series_network":
from bayesflow.networks import TimeSeriesNetwork

return TimeSeriesNetwork(filters=4, skip_steps=2)

case "deep_set":
from bayesflow.networks import DeepSet

return DeepSet(summary_dim=2, depth=1)

case "custom":
from bayesflow.networks import SummaryNetwork
from bayesflow.utils.serialization import serializable
import keras

@serializable("test", disable_module_check=True)
class Custom(SummaryNetwork):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.inner = keras.Sequential([keras.layers.LSTM(8), keras.layers.Dense(4)])

def call(self, x, **kwargs):
return self.inner(x, training=kwargs.get("stage") == "training")

return Custom()

case "flatten":
# very simple summary network for fast training
from bayesflow.networks import SummaryNetwork
from bayesflow.utils.serialization import serializable
import keras

@serializable("test", disable_module_check=True)
class FlattenSummaryNetwork(SummaryNetwork):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.inner = keras.layers.Flatten()

def call(self, x, **kwargs):
return self.inner(x, training=kwargs.get("stage") == "training")

return FlattenSummaryNetwork()

case "fusion_network":
from bayesflow.networks import FusionNetwork, DeepSet

return FusionNetwork({"a": DeepSet(), "b": keras.layers.Flatten()}, head=keras.layers.Dense(2))
case None:
return None
case _:
raise ValueError(f"Invalid request parameter for summary_network: {request.param}")


@pytest.fixture(params=["sir", "fusion"])
def simulator(request):
match request.param:
case "sir":
from bayesflow.simulators import SIR

return SIR()
case "lotka_volterra":
from bayesflow.simulators import LotkaVolterra

return LotkaVolterra()

case "two_moons":
from bayesflow.simulators import TwoMoons

return TwoMoons()
case "normal":
from tests.utils.normal_simulator import NormalSimulator

return NormalSimulator()
case "fusion":
from bayesflow.simulators import Simulator
from bayesflow.types import Shape, Tensor
from bayesflow.utils.decorators import allow_batch_size
import numpy as np

class FusionSimulator(Simulator):
@allow_batch_size
def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Tensor]:
mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
noise = np.random.standard_normal(batch_shape + (num_observations, 2))

x = mean[:, None] + noise

return dict(mean=mean, a=x, b=x)

return FusionSimulator()
case None:
return None
case _:
raise ValueError(f"Invalid request parameter for simulator: {request.param}")
Empty file.