Commit 1f29afb

Test and fix errors in dalle_mini examples.

1 parent 7c2e754

File tree: 4 files changed (+42, -13 lines)


lit_nlp/examples/dalle_mini/data.py

Lines changed: 7 additions & 3 deletions
@@ -7,12 +7,16 @@
 class DallePrompts(lit_dataset.Dataset):
 
   def __init__(self, prompts: list[str]):
-    self.examples = []
+    self._examples = []
     for prompt in prompts:
-      self.examples.append({"prompt": prompt})
+      self._examples.append({"prompt": prompt})
 
   def spec(self) -> lit_types.Spec:
     return {"prompt": lit_types.TextSegment()}
 
   def __iter__(self):
-    return iter(self.examples)
+    return iter(self._examples)
+
+  @property
+  def examples(self):
+    return self._examples
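
The rename above stores the prompts in a private `_examples` list and exposes it through an `examples` property, matching the accessor pattern of LIT's Dataset base class. A minimal sketch of exercising the dataset on its own, using the canned prompts from demo.py (illustrative only):

from lit_nlp.examples.dalle_mini import data

# Build the dataset and read it back through the public `examples` property
# and the iterator, both of which are now backed by `_examples`.
ds = data.DallePrompts(["I have a dream", "I have a shiba dog named cola"])
print(ds.spec())     # {"prompt": TextSegment()}
print(ds.examples)   # [{"prompt": "I have a dream"}, ...]
for example in ds:
  print(example["prompt"])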

lit_nlp/examples/dalle_mini/demo.py

Lines changed: 31 additions & 3 deletions
@@ -1,9 +1,39 @@
 r"""Example for dalle-mini demo model.
 
+First run the following command to install the required packages:
+  pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt
+
 To run locally with a small number of examples:
   python -m lit_nlp.examples.dalle_mini.demo
 
-
+By default, this module uses the "cuda" device for image generation.
+The `requirements.txt` file installs a CUDA-enabled version of PyTorch for GPU acceleration.
+
+If you are running on a machine without a compatible GPU or CUDA drivers,
+you must switch the device to "cpu" and reinstall the CPU-only version of PyTorch.
+
+Usage:
+  - Default: device="cuda"
+  - On CPU-only machines:
+    1. Set device="cpu" during model initialization
+    2. Uninstall the CUDA version of PyTorch:
+       pip uninstall torch
+    3. Install the CPU-only version:
+       pip install torch==2.1.2+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+
+Example:
+  >>> model = MinDalle(..., device="cpu")
+
+Check CUDA availability:
+  >>> import torch
+  >>> torch.cuda.is_available()
+  False  # if no GPU support is present
+
+Error Handling:
+  - If CUDA is selected but unsupported, you will see:
+      AssertionError: Torch not compiled with CUDA enabled
+  - To fix this, either install the correct CUDA-enabled PyTorch or switch to CPU mode.
+
 Then navigate to localhost:5432 to access the demo UI.
 """
 

@@ -26,8 +56,6 @@
 _FLAGS.set_default("development_demo", True)
 _FLAGS.set_default("default_layout", "DALLE_LAYOUT")
 
-_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.")
-
 _MODELS = (["dalle-mini"],)
 
 _CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"]
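
The docstring added above recommends switching to the CPU when no compatible GPU is available. A minimal sketch of that device selection, assuming the `device` keyword shown in the docstring's MinDalle example and leaving the other MinDalle constructor arguments at their defaults:

import torch
from min_dalle import MinDalle

# Follow the docstring's guidance: use "cuda" only when the installed PyTorch
# build can actually see a GPU, otherwise fall back to "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MinDalle(device=device)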

lit_nlp/examples/dalle_mini/model.py

Lines changed: 2 additions & 7 deletions
@@ -97,13 +97,8 @@ def tensor_to_pil_image(tensor):
     return images
 
   def input_spec(self):
-    return {
-        "grid_size": lit_types.Scalar(),
-        "temperature": lit_types.Scalar(),
-        "top_k": lit_types.Scalar(),
-        "supercondition_factor": lit_types.Scalar(),
-    }
-
+    return {"prompt": lit_types.TextSegment()}
+
   def output_spec(self):
     return {
         "image": lit_types.ImageBytesList(),

lit_nlp/examples/dalle_mini/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -17,3 +17,5 @@
 
 # Dalle-Mini dependencies
 min_dalle==0.4.11
+torch==2.1.2+cu118
+--extra-index-url https://download.pytorch.org/whl/cu118
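
The new pins install a CUDA 11.8 build of PyTorch. An optional sanity check after installing from requirements.txt; the printed values are examples, not guaranteed output:

import torch

# The pinned wheel should report a "+cu118" build; a swapped-in CPU-only
# install reports "+cpu" instead.
print(torch.__version__)          # e.g. "2.1.2+cu118"
print(torch.cuda.is_available())  # False without a compatible GPU and driver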
