[DO NOT MERGE YET] wip: CLIP Vision becomes its own thing #6161

Closed
2 changes: 2 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
CLIPVisionModel = "CLIPVisionModelField"
# endregion

# region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
clip_vision_model = "CLIP Vision Model to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
55 changes: 7 additions & 48 deletions invokeai/app/invocations/ip_adapter.py
@@ -1,22 +1,16 @@
from builtins import float
from typing import List, Literal, Union
from typing import List, Union

from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.model import CLIPVisionField, ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)
from invokeai.backend.model_manager.config import IPAdapterCheckpointConfig, IPAdapterInvokeAIConfig


class IPAdapterField(BaseModel):
@@ -49,9 +43,6 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@@ -65,9 +56,9 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="ViT-H",
clip_vision: CLIPVisionField = InputField(
description="The CLIP Vision model.",
title="CLIP Vision",
ui_order=2,
)
weight: Union[float, List[float]] = InputField(
@@ -96,45 +87,13 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput:
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))

if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)

return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
image_encoder_model=self.clip_vision.clip_vision,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)

def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)

if not len(image_encoder_models) > 0:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
)

installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)

if len(image_encoder_models) == 0:
context.logger.error("Error while fetching CLIP Vision Image Encoder")
assert len(image_encoder_models) == 1

return image_encoder_models[0]
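
Note: with this change the IP-Adapter node no longer resolves or auto-installs its own image encoder (the old code mapped the "ViT-H"/"ViT-G" literals to hosted InvokeAI repos and downloaded them on demand); it simply forwards whatever CLIP Vision model the new loader node supplies. A minimal standalone sketch of the new data flow, with a simplified stand-in for `ModelIdentifierField` (the real class in `invokeai.app.invocations.model` carries more metadata than just a key):

```python
from pydantic import BaseModel, Field

# Simplified stand-in for ModelIdentifierField (illustration only).
class ModelRef(BaseModel):
    key: str

class CLIPVisionFieldSketch(BaseModel):
    clip_vision: ModelRef = Field(description="Info to load clip vision model")

class IPAdapterFieldSketch(BaseModel):
    ip_adapter_model: ModelRef
    image_encoder_model: ModelRef

# The invocation now just forwards the user-selected CLIP Vision model
# instead of downloading a hard-coded encoder itself.
clip_vision = CLIPVisionFieldSketch(clip_vision=ModelRef(key="<clip-vision model key>"))
ip_adapter_field = IPAdapterFieldSketch(
    ip_adapter_model=ModelRef(key="<ip-adapter model key>"),
    image_encoder_model=clip_vision.clip_vision,
)
```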
2 changes: 1 addition & 1 deletion invokeai/app/invocations/metadata.py
@@ -35,7 +35,7 @@ class IPAdapterMetadataField(BaseModel):

image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
clip_vision_model: ModelIdentifierField = Field(description="The CLIP Vision model")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
36 changes: 30 additions & 6 deletions invokeai/app/invocations/model.py
@@ -8,12 +8,7 @@
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType

from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output


class ModelIdentifierField(BaseModel):
@@ -65,6 +60,10 @@ class VAEField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')


class CLIPVisionField(BaseModel):
clip_vision: ModelIdentifierField = Field(description="Info to load clip vision model")


@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field."""
@@ -368,3 +367,28 @@ class FreeUInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)


@invocation_output("clip_vision_output")
class CLIPVisionOutput(BaseInvocationOutput):
"""Output for CLIP Vision Model Loader"""

clip_vision: CLIPVisionField = OutputField(description=FieldDescriptions.clip_vision_model, title="CLIP Vision")


@invocation("clip_vision_model_loader", title="CLIP Vision Model", category="clip", version="1.0.0")
class CLIPVisionModelLoaderInvocation(BaseInvocation):
"""Loads the specified CLIP Vision Model"""

clip_vision_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_vision_model,
input=Input.Direct,
title="CLIP Vision",
ui_type=UIType.CLIPVisionModel,
)

def invoke(self, context: InvocationContext) -> CLIPVisionOutput:
if not context.models.exists(self.clip_vision_model.key):
raise Exception(f"Unknown model {self.clip_vision_model.key}")

return CLIPVisionOutput(clip_vision=CLIPVisionField(clip_vision=self.clip_vision_model))
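
Note: a hypothetical sketch of how the new loader node is meant to be wired in a workflow: its `clip_vision` output feeds the IP-Adapter node's `clip_vision` input. Only the node types and field names come from the invocations above; the node ids and the surrounding graph schema are simplified placeholders.

```python
# Illustrative wiring only: node ids and the workflow schema are placeholders;
# the node types and field names match the invocations in this PR.
nodes = {
    "clip_vision_loader": {
        "type": "clip_vision_model_loader",
        "clip_vision_model": {"key": "<clip-vision model key>"},
    },
    "ip_adapter": {
        "type": "ip_adapter",
        "weight": 1.0,
    },
}
edges = [
    {
        "source": {"node_id": "clip_vision_loader", "field": "clip_vision"},
        "destination": {"node_id": "ip_adapter", "field": "clip_vision"},
    },
]
```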
2 changes: 1 addition & 1 deletion invokeai/backend/ip_adapter/README.md
@@ -33,7 +33,7 @@ The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported

## InvokeAI Hosted IP-Adapters

Image Encoders:
CLIP Vision Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

17 changes: 12 additions & 5 deletions invokeai/backend/ip_adapter/ip_adapter.py
@@ -25,7 +25,10 @@ class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""

def __init__(
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
self,
cross_attention_dim: int = 1024,
clip_embeddings_dim: int = 1024,
clip_extra_context_tokens: int = 4,
):
super().__init__()

@@ -149,8 +152,10 @@ def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisi
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
image_prompt_embeds = self._image_proj_model(clip_image_embeds.to(device=self.device, dtype=self.dtype))
uncond_image_prompt_embeds = self._image_proj_model(
torch.zeros_like(clip_image_embeds.to(device=self.device, dtype=self.dtype))
)
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
@@ -178,8 +183,10 @@ def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisi
-2
]
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
image_prompt_embeds = self._image_proj_model(clip_image_embeds.to(device=self.device, dtype=self.dtype))
uncond_image_prompt_embeds = self._image_proj_model(
uncond_clip_image_embeds.to(device=self.device, dtype=self.dtype)
)
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
26 changes: 20 additions & 6 deletions invokeai/backend/model_manager/config.py
@@ -331,7 +331,7 @@ class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter diffusers format models."""

image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI

@staticmethod
def get_tag() -> Tag:
@@ -341,18 +341,31 @@ def get_tag() -> Tag:
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter checkpoint format models."""

format: Literal[ModelFormat.Checkpoint]
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")


class CLIPVisionDiffusersConfig(DiffusersConfigBase):
class CLIPVisionBaseConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision


class CLIPVisionCheckpointConfig(CLIPVisionBaseConfig):
"""Model config for CLIPVision."""

type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Checkpoint.value}")


class CLIPVisionDiffusersConfig(CLIPVisionBaseConfig):
"""Model config for CLIPVision."""

format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@staticmethod
def get_tag() -> Tag:
@@ -363,7 +376,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for T2I."""

type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@staticmethod
def get_tag() -> Tag:
@@ -407,6 +420,7 @@ def get_model_discriminator_value(v: Any) -> str:
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPVisionCheckpointConfig, CLIPVisionCheckpointConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
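
Note: the split into checkpoint and diffusers CLIP Vision configs leans on pydantic's callable-discriminator support: `get_model_discriminator_value` builds a `type.format` string and the matching `Tag` selects the config class; the added `= ModelFormat.Checkpoint` / `= ModelFormat.Diffusers` defaults also mean the format no longer has to be supplied explicitly. A self-contained sketch of that pattern, with plain string literals standing in for the real `ModelType`/`ModelFormat` enums:

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CLIPVisionCheckpointSketch(BaseModel):
    type: Literal["clip_vision"] = "clip_vision"
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class CLIPVisionDiffusersSketch(BaseModel):
    type: Literal["clip_vision"] = "clip_vision"
    format: Literal["diffusers"] = "diffusers"
    path: str


def discriminator_value(v: Any) -> str:
    # Mirrors get_model_discriminator_value: combine type and format into a tag.
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"


AnyConfigSketch = Annotated[
    Union[
        Annotated[CLIPVisionCheckpointSketch, Tag("clip_vision.checkpoint")],
        Annotated[CLIPVisionDiffusersSketch, Tag("clip_vision.diffusers")],
    ],
    Discriminator(discriminator_value),
]

adapter = TypeAdapter(AnyConfigSketch)
config = adapter.validate_python(
    {"type": "clip_vision", "format": "checkpoint", "path": "/models/vit_h.safetensors"}
)
assert isinstance(config, CLIPVisionCheckpointSketch)
```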
New file (98 additions)
@@ -0,0 +1,98 @@
import pathlib
from typing import Optional, TypedDict

import safetensors.torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection

from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device


class CLIPVisionConfigParams(TypedDict):
hidden_size: int
intermediate_size: int
projection_dim: int
num_hidden_layers: int
num_attention_heads: int
num_channels: int
image_size: int
patch_size: int
hidden_act: str
layer_norm_eps: float
attention_dropout: float
initializer_range: float
initializer_factor: float
torch_dtype: str


CLIP_VISION_STANDARD_CONFIG: CLIPVisionConfigParams = {
"hidden_size": 768,
"intermediate_size": 3072,
"projection_dim": 512,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"num_channels": 3,
"image_size": 224,
"patch_size": 32,
"hidden_act": "quick_gelu",
"layer_norm_eps": 1e-05,
"attention_dropout": 0.0,
"initializer_range": 0.02,
"initializer_factor": 1.0,
"torch_dtype": "float16",
}


CLIP_VISION_VIT_H_CONFIG: CLIPVisionConfigParams = {
**CLIP_VISION_STANDARD_CONFIG,
"hidden_size": 1280,
"intermediate_size": 5120,
"projection_dim": 1024,
"num_hidden_layers": 32,
"num_attention_heads": 16,
"patch_size": 14,
"hidden_act": "gelu",
"layer_norm_eps": 1e-05,
}

CLIP_VISION_VIT_G_CONFIG: CLIPVisionConfigParams = {
**CLIP_VISION_STANDARD_CONFIG,
"hidden_size": 1664,
"intermediate_size": 8192,
"projection_dim": 1280,
"num_hidden_layers": 48,
"num_attention_heads": 16,
"patch_size": 14,
"hidden_act": "gelu",
}


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Checkpoint)
class CLIPVisionModelLoader(ModelLoader):
"""Class to load CLIP Vision Checkpoint Models"""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
model_path = pathlib.Path(config.path)
clip_vision_state_dict = safetensors.torch.load_file(model_path, device=choose_torch_device().type)
clip_vision_keys = clip_vision_state_dict.keys()

if not any(key.startswith("vision_model.") for key in clip_vision_keys):
raise RuntimeError("Not a recognized CLIP Vision model.")

if "vision_model.encoder.layers.30.layer_norm1.weight" in clip_vision_keys:
clip_config = CLIPVisionConfig(**CLIP_VISION_VIT_H_CONFIG)
elif "vision_model.encoder.layers.47.layer_norm1.weight" in clip_vision_keys:
clip_config = CLIPVisionConfig(**CLIP_VISION_VIT_G_CONFIG)
else:
raise RuntimeError("Unrecognized CLIP Vision Model. Failed to load.")

clip_vision_model = CLIPVisionModelWithProjection(clip_config)
clip_vision_model.load_state_dict(clip_vision_state_dict, strict=False)

return clip_vision_model
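
Note: the loader distinguishes ViT-H from ViT-G checkpoints purely by probing for encoder-layer keys in the state dict (ViT-G has 48 encoder layers, ViT-H has 32), so the deeper ViT-G layer must be checked first because a ViT-G checkpoint also contains layer 30. A standalone sketch of that probe:

```python
# Standalone sketch of the state-dict probe above, assuming ViT-G has 48
# encoder layers (indices 0-47) and ViT-H has 32 (indices 0-31).
def identify_clip_vision_variant(state_dict_keys: set[str]) -> str:
    if not any(key.startswith("vision_model.") for key in state_dict_keys):
        raise RuntimeError("Not a recognized CLIP Vision model.")
    # Probe the deeper ViT-G layer first; a ViT-G checkpoint also has layer 30.
    if "vision_model.encoder.layers.47.layer_norm1.weight" in state_dict_keys:
        return "ViT-G"
    if "vision_model.encoder.layers.30.layer_norm1.weight" in state_dict_keys:
        return "ViT-H"
    raise RuntimeError("Unrecognized CLIP Vision model. Failed to load.")


# Example: a key set with 32 encoder layers resolves to ViT-H.
vit_h_keys = {f"vision_model.encoder.layers.{i}.layer_norm1.weight" for i in range(32)}
assert identify_clip_vision_variant(vit_h_keys) == "ViT-H"
```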