Skip to content

Commit fb1d96a

Browse files
committed
Move get_obj_from_str() and instantiate_from_config() into package. I am copying commit opened in issue CompVis#173 of the original repo which was waiting to be merged
1 parent 2bf4d87 commit fb1d96a

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

main.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,10 @@
1111
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
1212
from pytorch_lightning.utilities import rank_zero_only
1313

14+
from taming import get_obj_from_str, instantiate_from_config
1415
from taming.data.utils import custom_collate
1516

1617

17-
def get_obj_from_str(string, reload=False):
18-
module, cls = string.rsplit(".", 1)
19-
if reload:
20-
module_imp = importlib.import_module(module)
21-
importlib.reload(module_imp)
22-
return getattr(importlib.import_module(module, package=None), cls)
23-
24-
2518
def get_parser(**parser_kwargs):
2619
def str2bool(v):
2720
if isinstance(v, bool):

taming/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import importlib
2+
3+
def get_obj_from_str(string, reload=False):
4+
module, cls = string.rsplit(".", 1)
5+
if reload:
6+
module_imp = importlib.import_module(module)
7+
importlib.reload(module_imp)
8+
return getattr(importlib.import_module(module, package=None), cls)
9+
10+
def instantiate_from_config(config):
11+
if not "target" in config:
12+
raise KeyError("Expected key `target` to instantiate.")
13+
return get_obj_from_str(config["target"])(**config.get("params", dict()))

taming/models/cond_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44
import pytorch_lightning as pl
55

6-
from main import instantiate_from_config
6+
from taming import instantiate_from_config
77
from taming.modules.util import SOSProvider
88

99

taming/models/vqgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn.functional as F
33
import pytorch_lightning as pl
44

5-
from main import instantiate_from_config
5+
from taming import instantiate_from_config
66

77
from taming.modules.diffusionmodules.model import Encoder, Decoder
88
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

0 commit comments

Comments
 (0)