Skip to content

Commit 4475ea2

Browse files
committed
Refactor benchmarking: model_type in benchmark_videos()
1 parent 664835c commit 4475ea2

File tree

4 files changed

+53
-8
lines changed

4 files changed

+53
-8
lines changed

benchmarking/run_dlclive_benchmark.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import glob
1313

1414
from dlclive import benchmark_videos, download_benchmarking_data
15+
from dlclive.engine import Engine
1516

1617
datafolder = os.path.join(
1718
pathlib.Path(__file__).parent.absolute(), "Data-DLC-live-benchmark"
@@ -36,8 +37,22 @@
3637
if not os.path.isdir(out_dir):
3738
os.mkdir(out_dir)
3839

39-
for m in dog_models:
40-
benchmark_videos(m, dog_video, output=out_dir, n_frames=n_frames, pixels=pixels)
41-
42-
for m in mouse_models:
43-
benchmark_videos(m, mouse_video, output=out_dir, n_frames=n_frames, pixels=pixels)
40+
for model_path in dog_models:
41+
benchmark_videos(
42+
model_path=model_path,
43+
model_type="base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch",
44+
video_path=dog_video,
45+
output=out_dir,
46+
n_frames=n_frames,
47+
pixels=pixels
48+
)
49+
50+
for model_path in mouse_models:
51+
benchmark_videos(
52+
model_path=model_path,
53+
model_type="base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch",
54+
video_path=mouse_video,
55+
output=out_dir,
56+
n_frames=n_frames,
57+
pixels=pixels
58+
)

dlclive/benchmark.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def show_progress(count, block_size, total_size):
6666

6767
def benchmark_videos(
6868
model_path,
69+
model_type,
6970
video_path,
7071
output=None,
7172
n_frames=1000,
@@ -94,6 +95,9 @@ def benchmark_videos(
9495
----------
9596
model_path : str
9697
path to exported DeepLabCut model
98+
model_type: string, optional
99+
Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
100+
TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
97101
video_path : str or list
98102
path to video file or list of paths to video files
99103
output : str
@@ -169,7 +173,7 @@ def benchmark_videos(
169173

170174
this_inf_times, this_im_size, meta = benchmark(
171175
model_path=model_path,
172-
model_type="base",
176+
model_type=model_type,
173177
video_path=video,
174178
tf_config=tf_config,
175179
resize=resize[i],
@@ -375,6 +379,7 @@ def benchmark(
375379
Path to the DeepLabCut model.
376380
model_type : str
377381
Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
382+
TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
378383
video_path : str
379384
Path to the video file to be analyzed.
380385
TensorFlow engine, options are [`base`, `tensorrt`, `lite`].

dlclive/check_install/check_install.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1818

19-
from dlclive.benchmark_tf import benchmark_videos
19+
from dlclive.benchmark import benchmark_videos
20+
from dlclive.engine import Engine
2021

2122
MODEL_NAME = "superanimal_quadruped"
2223
SNAPSHOT_NAME = "snapshot-700000.pb"
@@ -77,7 +78,12 @@ def main():
7778
# run benchmark videos
7879
print("\n Running inference...\n")
7980
benchmark_videos(
80-
str(model_dir), video_file, display=display, resize=0.5, pcutoff=0.25
81+
model_path=str(model_dir),
82+
model_type="base" if Engine.from_model_path(model_dir) == Engine.TENSORFLOW else "pytorch",
83+
video_path=video_file,
84+
display=display,
85+
resize=0.5,
86+
pcutoff=0.25
8187
)
8288

8389
# deleting temporary files

dlclive/engine.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import Enum
2+
from pathlib import Path
23

34
class Engine(Enum):
45
TENSORFLOW = "tensorflow"
@@ -12,3 +13,21 @@ def from_model_type(cls, model_type: str) -> "Engine":
1213
return cls.TENSORFLOW
1314
else:
1415
raise ValueError(f"Unknown model type: {model_type}")
16+
17+
@classmethod
18+
def from_model_path(cls, model_path: str | Path) -> "Engine":
19+
path = Path(model_path)
20+
21+
if not path.exists():
22+
raise FileNotFoundError(f"Model path does not exist: {model_path}")
23+
24+
if path.is_dir():
25+
has_cfg = (path / "pose_cfg.yaml").is_file()
26+
has_pb = any(p.suffix == ".pb" for p in path.glob("*.pb"))
27+
if has_cfg and has_pb:
28+
return cls.TENSORFLOW
29+
elif path.is_file():
30+
if path.suffix == ".pt":
31+
return cls.PYTORCH
32+
33+
raise ValueError(f"Could not determine engine from model path: {model_path}")

0 commit comments

Comments
 (0)