Skip to content

Commit 7c2e754

Browse files
llcourageLIT team
authored andcommitted
LIT Dalle-Mini demo.
PiperOrigin-RevId: 758468402
1 parent 151f83e commit 7c2e754

File tree

5 files changed

+273
-0
lines changed

5 files changed

+273
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
Dalle_Mini Demo for the Learning Interpretability Tool
2+
=======================================================
3+
4+
This demo showcases how LIT can be used in text-to-image generation mode. It is
5+
based on the mini-dalle Mini model
6+
(https://www.piwheels.org/project/dalle-mini/).
7+
8+
You will need a standalone virtual environment for the Python libraries, which
9+
you can set up using the following commands from the root of the LIT repo.
10+
11+
```sh
12+
# Create the virtual environment. You may want to use python3 or python3.10
13+
# depends on how many Python versions you have installed and their aliases.
14+
python -m venv .dalle-mini
15+
source .dalle-mini/bin/activate
16+
# This requirements.txt file will also install the core LIT library deps.
17+
pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt
18+
# The LIT web app still needs to be built in the usual way.
19+
(cd ./lit_nlp && yarn && yarn build)
20+
```
21+
22+
Once your virtual environment is setup, you can launch the demo with the
23+
following command.
24+
25+
```sh
26+
python -m lit_nlp.examples.dalle_mini.demo
27+
```
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Data loaders for dalle-mini model."""
2+
3+
from lit_nlp.api import dataset as lit_dataset
4+
from lit_nlp.api import types as lit_types
5+
6+
7+
class DallePrompts(lit_dataset.Dataset):
8+
9+
def __init__(self, prompts: list[str]):
10+
self.examples = []
11+
for prompt in prompts:
12+
self.examples.append({"prompt": prompt})
13+
14+
def spec(self) -> lit_types.Spec:
15+
return {"prompt": lit_types.TextSegment()}
16+
17+
def __iter__(self):
18+
return iter(self.examples)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
r"""Example for dalle-mini demo model.
2+
3+
To run locally with a small number of examples:
4+
python -m lit_nlp.examples.dalle_mini.demo
5+
6+
7+
Then navigate to localhost:5432 to access the demo UI.
8+
"""
9+
10+
from collections.abc import Sequence
11+
import sys
12+
from typing import Optional
13+
14+
from absl import app
15+
from absl import flags
16+
from lit_nlp import app as lit_app
17+
from lit_nlp import dev_server
18+
from lit_nlp import server_flags
19+
from lit_nlp.api import layout
20+
from lit_nlp.examples.dalle_mini import data as dalle_data
21+
from lit_nlp.examples.dalle_mini import model as dalle_model
22+
23+
24+
# NOTE: additional flags defined in server_flags.py
25+
_FLAGS = flags.FLAGS
26+
_FLAGS.set_default("development_demo", True)
27+
_FLAGS.set_default("default_layout", "DALLE_LAYOUT")
28+
29+
_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.")
30+
31+
_MODELS = (["dalle-mini"],)
32+
33+
_CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"]
34+
35+
# Custom frontend layout; see api/layout.py
36+
_modules = layout.LitModuleName
37+
_DALLE_LAYOUT = layout.LitCanonicalLayout(
38+
upper={
39+
"Main": [
40+
_modules.DataTableModule,
41+
_modules.DatapointEditorModule,
42+
]
43+
},
44+
lower={
45+
"Predictions": [
46+
_modules.GeneratedImageModule,
47+
_modules.GeneratedTextModule,
48+
],
49+
},
50+
description="Custom layout for Text to Image models.",
51+
)
52+
53+
54+
CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"DALLE_LAYOUT": _DALLE_LAYOUT}
55+
56+
57+
def get_wsgi_app() -> Optional[dev_server.LitServerType]:
58+
_FLAGS.set_default("server_type", "external")
59+
_FLAGS.set_default("demo_mode", True)
60+
# Parse flags without calling app.run(main), to avoid conflict with
61+
# gunicorn command line flags.
62+
unused = _FLAGS(sys.argv, known_only=True)
63+
return main(unused)
64+
65+
66+
def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
67+
if len(argv) > 1:
68+
raise app.UsageError("Too many command-line arguments.")
69+
70+
# Load models, according to the --models flag.
71+
models = {}
72+
73+
model_loaders: lit_app.ModelLoadersMap = {}
74+
model_loaders["dalle-mini"] = (
75+
dalle_model.DalleMiniModel,
76+
dalle_model.DalleMiniModel.init_spec(),
77+
)
78+
79+
datasets = {"examples": dalle_data.DallePrompts(_CANNED_PROMPTS)}
80+
dataset_loaders: lit_app.DatasetLoadersMap = {}
81+
dataset_loaders["text_to_image"] = (
82+
dalle_data.DallePrompts,
83+
dalle_data.DallePrompts.init_spec(),
84+
)
85+
86+
lit_demo = dev_server.Server(
87+
models=models,
88+
model_loaders=model_loaders,
89+
datasets=datasets,
90+
dataset_loaders=dataset_loaders,
91+
layouts=CUSTOM_LAYOUTS,
92+
**server_flags.get_flags(),
93+
)
94+
return lit_demo.serve()
95+
96+
97+
if __name__ == "__main__":
98+
app.run(main)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""LIT wrappers for MiniDalleModel."""
2+
3+
from collections.abc import Iterable
4+
5+
from lit_nlp.api import model as lit_model
6+
from lit_nlp.api import types as lit_types
7+
from lit_nlp.lib import image_utils
8+
from min_dalle import MinDalle
9+
import numpy as np
10+
from PIL import Image
11+
import torch
12+
13+
14+
class DalleMiniModel(lit_model.Model):
15+
"""LIT model wrapper for Dalle-Mini Text-to-Image model.
16+
17+
This wrapper simplifies the pipeline using Dalle-Mini for text-to-image
18+
generation.
19+
20+
21+
The basic flow within this model wrapper's predict() function is:
22+
23+
24+
1. Dalle-Mini processes the text prompt.
25+
2. Images are directly generated by Dalle-Mini.
26+
"""
27+
28+
def __init__(
29+
self,
30+
device: str = "cuda", # Use "cuda" for GPU or "cpu" for CPU
31+
grid_size: int = 4, # each batch will generate grid_size**2 images
32+
temperature: float = 0.5,
33+
top_k: int = 256,
34+
supercondition_factor: int = 32,
35+
):
36+
super().__init__()
37+
self.grid_size = grid_size
38+
self.temperature = temperature
39+
self.top_k = top_k
40+
self.supercondition_factor = supercondition_factor
41+
42+
# Load Dalle-Mini model
43+
self.model = MinDalle(
44+
models_root="./pretrained",
45+
dtype=torch.float32,
46+
device=device,
47+
is_mega=True,
48+
is_reusable=True,
49+
)
50+
51+
def max_minibatch_size(self) -> int:
52+
return 8
53+
54+
def predict(
55+
self, inputs: Iterable[lit_types.JsonDict], **unused_kw
56+
) -> Iterable[lit_types.JsonDict]:
57+
"""Generate images based on the input prompts."""
58+
59+
def tensor_to_pil_image(tensor):
60+
img_np = tensor.detach().cpu().numpy()
61+
img_np = np.squeeze(img_np)
62+
if img_np.ndim == 2:
63+
img_np = np.stack([img_np] * 3, axis=-1)
64+
elif img_np.ndim != 3 or img_np.shape[2] != 3:
65+
raise ValueError(
66+
f"Unexpected image shape: {img_np.shape}. Expected (H, W, 3)."
67+
)
68+
69+
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min()) * 255
70+
img_np = img_np.clip(0, 255).astype(np.uint8)
71+
return Image.fromarray(img_np)
72+
73+
prompts = [ex["prompt"] for ex in inputs]
74+
images = []
75+
for prompt in prompts:
76+
# Generate images using the model
77+
generated_images = self.model.generate_images(
78+
text=prompt,
79+
seed=-1,
80+
grid_size=self.grid_size,
81+
is_seamless=False,
82+
temperature=self.temperature,
83+
top_k=self.top_k,
84+
supercondition_factor=self.supercondition_factor,
85+
is_verbose=False,
86+
)
87+
pil_images = []
88+
for img_tensor in generated_images:
89+
pil_images.append(tensor_to_pil_image(img_tensor))
90+
images.append({
91+
"image": [
92+
image_utils.convert_pil_to_image_str(img) for img in pil_images
93+
],
94+
"prompt": prompt,
95+
})
96+
97+
return images
98+
99+
def input_spec(self):
100+
return {
101+
"grid_size": lit_types.Scalar(),
102+
"temperature": lit_types.Scalar(),
103+
"top_k": lit_types.Scalar(),
104+
"supercondition_factor": lit_types.Scalar(),
105+
}
106+
107+
def output_spec(self):
108+
return {
109+
"image": lit_types.ImageBytesList(),
110+
"prompt": lit_types.TextSegment(),
111+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
-r ../../../requirements.txt
17+
18+
# Dalle-Mini dependencies
19+
min_dalle==0.4.11

0 commit comments

Comments
 (0)