Skip to content

Commit 664835c

Browse files
committed
WIP refactor benchmarking: extract Engine
1 parent db6a61d commit 664835c

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

dlclive/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from dlclive import DLCLive
2626
from dlclive import VERSION
2727
from dlclive import __file__ as dlcfile
28-
from dlclive.factory import Engine
28+
from dlclive.engine import Engine
2929
from dlclive.utils import decode_fourcc
3030

3131

dlclive/engine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from enum import Enum
2+
3+
class Engine(Enum):
4+
TENSORFLOW = "tensorflow"
5+
PYTORCH = "pytorch"
6+
7+
@classmethod
8+
def from_model_type(cls, model_type: str) -> "Engine":
9+
if model_type.lower() == "pytorch":
10+
return cls.PYTORCH
11+
elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
12+
return cls.TENSORFLOW
13+
else:
14+
raise ValueError(f"Unknown model type: {model_type}")

dlclive/factory.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Literal
66

77
from dlclive.core.runner import BaseRunner
8+
from dlclive.engine import Engine
89

910

1011
def build_runner(
@@ -54,19 +55,3 @@ def build_runner(
5455
def filter_keys(valid: set[str], kwargs: dict) -> dict:
5556
"""Filters the keys in kwargs, only keeping those in valid."""
5657
return {k: v for k, v in kwargs.items() if k in valid}
57-
58-
59-
from enum import Enum
60-
61-
class Engine(Enum):
62-
TENSORFLOW = "tensorflow"
63-
PYTORCH = "pytorch"
64-
65-
@classmethod
66-
def from_model_type(cls, model_type: str) -> "Engine":
67-
if model_type.lower() == "pytorch":
68-
return cls.PYTORCH
69-
elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
70-
return cls.TENSORFLOW
71-
else:
72-
raise ValueError(f"Unknown model type: {model_type}")

0 commit comments

Comments
 (0)