Skip to content

Commit feae820

Browse files
Add files via upload
Code changes from CompVis/latent-diffusion#123
1 parent 69ae4b3 commit feae820

File tree

7 files changed

+23
-15
lines changed

7 files changed

+23
-15
lines changed

ldm/models/diffusion/ddim.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ def __init__(self, model, schedule="linear", **kwargs):
1818

1919
def register_buffer(self, name, attr):
2020
if type(attr) == torch.Tensor:
21-
if attr.device != torch.device("cuda"):
21+
if attr.device != torch.device("cuda") and torch.cuda.is_available():
2222
attr = attr.to(torch.device("cuda"))
2323
setattr(self, name, attr)
2424

@@ -238,4 +238,4 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco
238238
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239239
unconditional_guidance_scale=unconditional_guidance_scale,
240240
unconditional_conditioning=unconditional_conditioning)
241-
return x_dec
241+
return x_dec

ldm/models/diffusion/plms.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ def __init__(self, model, schedule="linear", **kwargs):
1717

1818
def register_buffer(self, name, attr):
1919
if type(attr) == torch.Tensor:
20-
if attr.device != torch.device("cuda"):
20+
if attr.device != torch.device("cuda") and torch.cuda.is_available():
2121
attr = attr.to(torch.device("cuda"))
2222
setattr(self, name, attr)
2323

ldm/modules/encoders/modules.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def forward(self, batch, key=None):
3535

3636
class TransformerEmbedder(AbstractEncoder):
3737
"""Some transformer encoder layers"""
38-
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
38+
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"):
3939
super().__init__()
4040
self.device = device
4141
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -52,7 +52,7 @@ def encode(self, x):
5252

5353
class BERTTokenizer(AbstractEncoder):
5454
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55-
def __init__(self, device="cuda", vq_interface=True, max_length=77):
55+
def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77):
5656
super().__init__()
5757
from transformers import BertTokenizerFast # TODO: add to reuquirements
5858
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -80,7 +80,7 @@ def decode(self, text):
8080
class BERTEmbedder(AbstractEncoder):
8181
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
8282
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83-
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
83+
device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0):
8484
super().__init__()
8585
self.use_tknz_fn = use_tokenizer
8686
if self.use_tknz_fn:
@@ -136,7 +136,7 @@ def encode(self, x):
136136

137137
class FrozenCLIPEmbedder(AbstractEncoder):
138138
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
139-
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
139+
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda" if torch.cuda.is_available() else "cpu", max_length=77):
140140
super().__init__()
141141
self.tokenizer = CLIPTokenizer.from_pretrained(version)
142142
self.transformer = CLIPTextModel.from_pretrained(version)
@@ -231,4 +231,4 @@ def forward(self, x):
231231
if __name__ == "__main__":
232232
from ldm.util import count_params
233233
model = FrozenCLIPEmbedder()
234-
count_params(model, verbose=True)
234+
count_params(model, verbose=True)

notebook_helpers.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@ def load_model_from_config(config, ckpt):
4444
sd = pl_sd["state_dict"]
4545
model = instantiate_from_config(config.model)
4646
m, u = model.load_state_dict(sd, strict=False)
47-
model.cuda()
47+
if torch.cuda.is_available():
48+
model.cuda()
4849
model.eval()
4950
return {"model": model}, global_step
5051

@@ -117,7 +118,8 @@ def get_cond(mode, selected_path):
117118
c = rearrange(c, '1 c h w -> 1 h w c')
118119
c = 2. * c - 1.
119120

120-
c = c.to(torch.device("cuda"))
121+
if torch.cuda.is_available():
122+
c = c.to(torch.device("cuda"))
121123
example["LR_image"] = c
122124
example["image"] = c_up
123125

@@ -267,4 +269,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
267269
log["sample"] = x_sample
268270
log["time"] = t1 - t0
269271

270-
return log
272+
return log

scripts/knn2img.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@ def load_model_from_config(config, ckpt, verbose=False):
5353
print("unexpected keys:")
5454
print(u)
5555

56-
model.cuda()
56+
if torch.cuda.is_available():
57+
model.cuda()
5758
model.eval()
5859
return model
5960

@@ -358,7 +359,10 @@ def __call__(self, x, n):
358359
uc = None
359360
if searcher is not None:
360361
nn_dict = searcher(c, opt.knn)
361-
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
362+
nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings'])
363+
if torch.cuda.is_available():
364+
nn_embeddings = nn_embeddings.cuda()
365+
c = torch.cat([c, nn_embeddings], dim=1)
362366
if opt.scale != 1.0:
363367
uc = torch.zeros_like(c)
364368
if isinstance(prompts, tuple):

scripts/sample_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,8 @@ def get_parser():
220220
def load_model_from_config(config, sd):
221221
model = instantiate_from_config(config)
222222
model.load_state_dict(sd,strict=False)
223-
model.cuda()
223+
if torch.cuda.is_available():
224+
model.cuda()
224225
model.eval()
225226
return model
226227

scripts/txt2img.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,8 @@ def load_model_from_config(config, ckpt, verbose=False):
6060
print("unexpected keys:")
6161
print(u)
6262

63-
model.cuda()
63+
if torch.cuda.is_available():
64+
model.cuda()
6465
model.eval()
6566
return model
6667

0 commit comments

Comments
 (0)