Skip to content

Commit e994808

Browse files
committed
Merge branch 'main' into maxim/dlclive3
2 parents 4475ea2 + 44c960f commit e994808

File tree

4 files changed

+73
-3
lines changed

4 files changed

+73
-3
lines changed

.github/workflows/testing.yml

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,23 @@ jobs:
5353
- name: Install and test
5454
shell: bash -el {0} # Important: activates the conda environment
5555
run: |
56-
python -m pip install --upgrade pip wheel poetry
57-
python -m poetry install --extras "tf" --extras "pytorch"
58-
python -m poetry run dlc-live-test --nodisplay
56+
conda install pytables==3.8.0 "numpy<2"
57+
58+
- name: Install dependencies via Conda
59+
shell: bash -el {0}
60+
run: conda install -y "numpy>=1.26,<2.0"
61+
62+
- name: Install Poetry
63+
run: pip install --upgrade pip wheel poetry
64+
65+
- name: Regenerate Poetry lock
66+
run: poetry lock --no-cache
67+
68+
- name: Install project dependencies
69+
run: poetry install --with dev --extras "tf" --extras "pytorch"
70+
71+
- name: Run DLC Live Tests
72+
run: poetry run dlc-live-test --nodisplay
73+
74+
- name: Run Functional Benchmark Test
75+
run: poetry run pytest tests/test_benchmark_script.py

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ py-cpuinfo = ">=5.0.0"
3535
tqdm = "^4.62.3"
3636
pandas = ">=1.0.1,!=1.5.0"
3737
tables = "^3.8"
38+
pytest = "^8.0"
3839
dlclibrary = ">=0.0.6"
40+
3941
# PyTorch models
4042
scipy = ">=1.9"
4143
timm = { version = ">=1.0.7", optional = true }

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
functional: functional tests

tests/test_benchmark_script.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import glob
3+
import pathlib
4+
import pytest
5+
from dlclive import benchmark_videos, download_benchmarking_data
6+
from dlclive.engine import Engine
7+
8+
@pytest.mark.functional
9+
def test_benchmark_script_runs(tmp_path):
10+
datafolder = tmp_path / "Data-DLC-live-benchmark"
11+
download_benchmarking_data(str(datafolder))
12+
13+
dog_models = glob.glob(str(datafolder / "dog" / "*[!avi]"))
14+
dog_video = glob.glob(str(datafolder / "dog" / "*.avi"))[0]
15+
mouse_models = glob.glob(str(datafolder / "mouse_lick" / "*[!avi]"))
16+
mouse_video = glob.glob(str(datafolder / "mouse_lick" / "*.avi"))[0]
17+
18+
out_dir = tmp_path / "results"
19+
out_dir.mkdir(exist_ok=True)
20+
21+
pixels = [100, 400] #[2500, 10000]
22+
n_frames = 5
23+
24+
for model_path in dog_models:
25+
print(f"Running dog model: {model_path}")
26+
result = benchmark_videos(
27+
model_path=model_path,
28+
model_type="base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch",
29+
video_path=dog_video,
30+
output=str(out_dir),
31+
n_frames=n_frames,
32+
pixels=pixels
33+
)
34+
print("Dog model result:", result)
35+
36+
for model_path in mouse_models:
37+
print(f"Running mouse model: {model_path}")
38+
result = benchmark_videos(
39+
model_path=model_path,
40+
model_type="base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch",
41+
video_path=mouse_video,
42+
output=str(out_dir),
43+
n_frames=n_frames,
44+
pixels=pixels
45+
)
46+
print("Mouse model result:", result)
47+
48+
assert any(out_dir.iterdir())

0 commit comments

Comments
 (0)