README.md (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@
<a href="https://modelscope.cn/studios/ZhipuAI/CogView4" target="_blank"> 🤖ModelScope Space</a>
<a href="https://zhipuaishengchan.datasink.sensorsdata.cn/t/4z" target="_blank"> 🛠️ZhipuAI MaaS(Faster)</a>
<br>
-<a href="resources/WECHAT.md" target="_blank"> 👋 WeChat Community</a> <a href="https://arxiv.org/abs/2403.05121" target="_blank">📚 CogView3 Paper</a>
+<a href="resources/WECHAT.md" target="_blank"> 👋 WeChat Community</a> <a href="https://arxiv.org/abs/2403.05121" target="_blank">📚 CogView3 Paper</a> <a href="https://replicate.com/lucataco/cogview4-6b" target="_blank">🤖 Replicate</a>
</p>

![showcase.png](resources/showcase.png)
README_ja.md (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@
<a href="https://modelscope.cn/studios/ZhipuAI/CogView4" target="_blank"> 🤖ModelScope Space</a>
<a href="https://zhipuaishengchan.datasink.sensorsdata.cn/t/4z" target="_blank"> 🛠️ZhipuAI MaaS(Faster)</a>
<br>
-<a href="resources/WECHAT.md" target="_blank"> 👋 WeChat Community</a> <a href="https://arxiv.org/abs/2403.05121" target="_blank">📚 CogView3 Paper</a>
+<a href="resources/WECHAT.md" target="_blank"> 👋 WeChat Community</a> <a href="https://arxiv.org/abs/2403.05121" target="_blank">📚 CogView3 Paper</a> <a href="https://replicate.com/lucataco/cogview4-6b" target="_blank"> 🤖 Replicate</a>
</p>

README_zh.md (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
<br>
<a href="resources/WECHAT.md" target="_blank"> 👋 微信社区</a>
<a href="https://arxiv.org/abs/2403.05121" target="_blank"> 📚 CogView3 论文</a>
-</p>
+<a href="https://replicate.com/lucataco/cogview4-6b" target="_blank"> 🤖 Replicate</a>

![showcase.png](resources/showcase.png)

cog.yaml (21 changes: 21 additions & 0 deletions, new file)
@@ -0,0 +1,21 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  gpu: true
  cuda: "12.1"
  python_version: "3.11"
  python_packages:
    - "torch==2.4"
    - "git+https://github.com/huggingface/diffusers.git@24c062aaa19f5626d03d058daf8afffa2dfd49f7"
    - "transformers==4.49.0"
    - "accelerate==1.4.0"
    - "safetensors==0.5.3"
    - "pillow==10.1.0"
    - "numpy<2"
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.9.1/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
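
Note that diffusers is pinned to a specific commit rather than a tagged release, presumably because CogView4 support had not yet shipped in a stable version when this config was written. A minimal sanity check for the built environment (a sketch; it only verifies that the pinned commit exposes the pipeline class predict.py relies on):

# Sanity check (sketch): the pinned diffusers commit must expose CogView4Pipeline;
# an ImportError here would mean the pin resolves to a version without it.
import diffusers
from diffusers import CogView4Pipeline

print(diffusers.__version__)
print(CogView4Pipeline.__name__)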
predict.py (108 changes: 108 additions & 0 deletions, new file)
@@ -0,0 +1,108 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import time
import torch
import subprocess
from diffusers import CogView4Pipeline
from cog import BasePredictor, Input, Path

MODEL_CACHE = "checkpoints"
MODEL_URL = "https://weights.replicate.delivery/default/THUDM/CogView4-6B/model.tar"

def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget -x downloads the tar archive and extracts it into dest
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Download weights if they don't exist
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # Load CogView4-6B model with bfloat16 precision as recommended
        self.pipe = CogView4Pipeline.from_pretrained(
            MODEL_CACHE,
            torch_dtype=torch.bfloat16
        )

        # Enable optimizations to reduce GPU memory usage and improve speed
        self.pipe.enable_model_cpu_offload()
        self.pipe.vae.enable_slicing()
        self.pipe.vae.enable_tiling()

    def predict(
        self,
        prompt: str = Input(
            description="Text prompt to generate an image from"
        ),
        negative_prompt: str = Input(
            description="Negative prompt to guide image generation away from certain concepts",
            default=None
        ),
        width: int = Input(
            description="Width of the generated image (must be between 512 and 2048, divisible by 32)",
            default=1024,
            ge=512,
            le=2048
        ),
        height: int = Input(
            description="Height of the generated image (must be between 512 and 2048, divisible by 32)",
            default=1024,
            ge=512,
            le=2048
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            default=50,
            ge=1,
            le=100
        ),
        guidance_scale: float = Input(
            description="Guidance scale for classifier-free guidance",
            default=3.5,
            ge=0.0,
            le=20.0
        ),
        seed: int = Input(
            description="Random seed for reproducible image generation",
            default=None
        )
    ) -> Path:
"""Run a single prediction on the model"""
# Validate dimensions
if width % 32 != 0 or height % 32 != 0:
raise ValueError("Width and height must be divisible by 32")
if width * height > 2**21:
raise ValueError(f"Resolution {width}x{height} exceeds maximum allowed pixels (2^21)")

# Set seed for reproducibility
generator = None
if seed is None:
seed = int.from_bytes(os.urandom(3), "big")
generator = torch.Generator().manual_seed(seed)
print("Using seed: ", seed)

# Generate image(s)
images = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
num_images_per_prompt=1,
).images

# Save the first generated image
output_path = Path(f"/tmp/output.png")
images[0].save(output_path)
return output_path
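
For local testing outside of Replicate, the predictor can be exercised directly. This is a sketch, not part of the PR; it assumes a CUDA GPU and either pre-fetched weights in ./checkpoints or pget on PATH for the first-run download:

# Minimal local smoke test (sketch): drives the Predictor class directly.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads weights on first run, then loads the pipeline
out = predictor.predict(
    prompt="A watercolor lighthouse at dusk",
    negative_prompt=None,
    width=1024,
    height=1024,
    num_inference_steps=50,
    guidance_scale=3.5,
    seed=42,
)
print(out)  # /tmp/output.png

On Replicate itself the same path runs through the Cog harness, e.g. `cog predict -i prompt="..."`, which supplies the declared defaults for any omitted inputs.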