From dfc7835359dd68144495b8b7432b9e59826e9b4e Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Mon, 28 Apr 2025 22:45:30 -0400 Subject: [PATCH 01/14] Setup Probe and UI to accept bria main models --- invokeai/app/invocations/fields.py | 1 + .../backend/model_manager/legacy_probe.py | 3 + .../model_manager/load/model_loaders/bria.py | 56 +++++++++++++++++++ invokeai/backend/model_manager/taxonomy.py | 1 + .../Invocation/fields/InputFieldRenderer.tsx | 10 ++++ .../BriaMainModelFieldInputComponent.tsx | 44 +++++++++++++++ .../web/src/features/nodes/types/common.ts | 2 + .../web/src/features/nodes/types/constants.ts | 1 + .../web/src/features/nodes/types/field.ts | 30 ++++++++++ .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 16 ++++++ .../src/services/api/hooks/modelsByType.ts | 3 +- .../frontend/web/src/services/api/schema.ts | 4 +- .../frontend/web/src/services/api/types.ts | 4 ++ 14 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_loaders/bria.py create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index fb6d6af03d8..085f539426e 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum): MainModel = "MainModelField" CogView4MainModel = "CogView4MainModelField" FluxMainModel = "FluxMainModelField" + BriaMainModel = "BriaMainModelField" SD3MainModel = "SD3MainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 8a0e770d037..caff085b373 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ 
b/invokeai/backend/model_manager/legacy_probe.py @@ -125,6 +125,7 @@ class ModelProbe(object): } CLASS2TYPE = { + "BriaPipeline": ModelType.Main, "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, @@ -861,6 +862,8 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.StableDiffusion3 elif transformer_conf["_class_name"] == "CogView4Transformer2DModel": return BaseModelType.CogView4 + elif transformer_conf["_class_name"] == "BriaTransformer2DModel": + return BaseModelType.Bria else: raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py new file mode 100644 index 00000000000..6712e13896e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/bria.py @@ -0,0 +1,56 @@ +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + CheckpointConfigBase, + DiffusersConfigBase, +) +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.model_manager.taxonomy import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) + + +@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers) +class BriaDiffusersModel(GenericDiffusersLoader): + """Class to load Bria main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if isinstance(config, CheckpointConfigBase): + raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.") + + if submodel_type is None: + raise Exception("A submodel type must be provided when loading main 
pipelines.") + + model_path = Path(config.path) + load_class = self.get_hf_load_class(model_path, submodel_type) + repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + variant = repo_variant.value if repo_variant else None + model_path = model_path / submodel_type.value + + dtype = self._torch_dtype + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=dtype, + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str( + e + ): # try without the variant, just in case user's preferences changed + result = load_class.from_pretrained(model_path, torch_dtype=dtype) + else: + raise e + + return result diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index a353a44e765..a1d48ff4808 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -30,6 +30,7 @@ class BaseModelType(str, Enum): Imagen4 = "imagen4" ChatGPT4o = "chatgpt-4o" FluxKontext = "flux-kontext" + Bria = "bria" class ModelType(str, Enum): diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 493960fdba6..848172e4dde 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -29,6 +29,8 @@ import { isBooleanFieldInputTemplate, isChatGPT4oModelFieldInputInstance, isChatGPT4oModelFieldInputTemplate, + isBriaMainModelFieldInputInstance, + isBriaMainModelFieldInputTemplate, isCLIPEmbedModelFieldInputInstance, isCLIPEmbedModelFieldInputTemplate, isCLIPGEmbedModelFieldInputInstance, @@ -117,6 +119,7 @@ import { assert } from 'tsafe'; import BoardFieldInputComponent from 
'./inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; +import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent'; import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent'; import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent'; import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent'; @@ -448,6 +451,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) return ; } + if (isBriaMainModelFieldInputTemplate(template)) { + if (!isBriaMainModelFieldInputInstance(field)) { + return null; + } + return ; + } + if (isSD3MainModelFieldInputTemplate(template)) { if (!isSD3MainModelFieldInputInstance(field)) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx new file mode 100644 index 00000000000..8d8af426a56 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaMainModelFieldInputComponent.tsx @@ -0,0 +1,44 @@ +import { useAppDispatch } from 'app/store/storeHooks'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useBriaModels } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const 
BriaMainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useBriaModels(); + const onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + return ( + + ); +}; + +export default memo(BriaMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 1c9b4ec8ee4..7deeb7a7680 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -76,6 +76,7 @@ const zBaseModel = z.enum([ 'imagen4', 'chatgpt-4o', 'flux-kontext', + 'bria', ]); export type BaseModelType = z.infer; export const zMainModelBase = z.enum([ @@ -89,6 +90,7 @@ export const zMainModelBase = z.enum([ 'imagen4', 'chatgpt-4o', 'flux-kontext', + 'bria', ]); type MainModelBase = z.infer; export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index a8ab6d231e3..0e6131e4882 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = { LoRAModelField: 'teal.500', MainModelField: 'teal.500', FluxMainModelField: 'teal.500', + BriaMainModelField: 'teal.500', SD3MainModelField: 'teal.500', CogView4MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index af5132b394f..805a7d02f25 100644 
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -185,6 +185,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('FluxMainModelField'), originalType: zStatelessFieldType.optional(), }); +const zBriaMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('BriaMainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), @@ -325,6 +329,7 @@ const zStatefulFieldType = z.union([ zIntegerGeneratorFieldType, zStringGeneratorFieldType, zImageGeneratorFieldType, + zBriaMainModelFieldType, ]); export type StatefulFieldType = z.infer; const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); @@ -341,6 +346,7 @@ const modelFieldTypeNames = [ zSD3MainModelFieldType.shape.name.value, zCogView4MainModelFieldType.shape.name.value, zFluxMainModelFieldType.shape.name.value, + zBriaMainModelFieldType.shape.name.value, zSDXLRefinerModelFieldType.shape.name.value, zVAEModelFieldType.shape.name.value, zLoRAModelFieldType.shape.name.value, @@ -888,6 +894,26 @@ export const isFluxMainModelFieldInputTemplate = buildTemplateTypeGuard('FluxMainModelField'); // #endregion +// #region BriaMainModelField +const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. 
+const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zBriaMainModelFieldValue, +}); +const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBriaMainModelFieldType, + originalType: zFieldType.optional(), + default: zBriaMainModelFieldValue, +}); +const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBriaMainModelFieldType, +}); +export type BriaMainModelFieldInputInstance = z.infer; +export type BriaMainModelFieldInputTemplate = z.infer; +export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance); +export const isBriaMainModelFieldInputTemplate = + buildTemplateTypeGuard('BriaMainModelField'); +// #endregion + // #region SDXLRefinerModelField /** @alias */ // tells knip to ignore this duplicate export export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. @@ -1887,6 +1913,7 @@ export const zStatefulFieldValue = z.union([ zMainModelFieldValue, zSDXLMainModelFieldValue, zFluxMainModelFieldValue, + zBriaMainModelFieldValue, zSD3MainModelFieldValue, zCogView4MainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -1938,6 +1965,7 @@ const zStatefulFieldInputInstance = z.union([ zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, zFluxMainModelFieldInputInstance, + zBriaMainModelFieldInputInstance, zSD3MainModelFieldInputInstance, zCogView4MainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, @@ -1980,6 +2008,7 @@ const zStatefulFieldInputTemplate = z.union([ zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, zFluxMainModelFieldInputTemplate, + zBriaMainModelFieldInputTemplate, zSD3MainModelFieldInputTemplate, zCogView4MainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, @@ -2032,6 +2061,7 @@ const zStatefulFieldOutputTemplate = z.union([ zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, zFluxMainModelFieldOutputTemplate, 
+ zBriaMainModelFieldOutputTemplate, zSD3MainModelFieldOutputTemplate, zCogView4MainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 20ba67b24b3..c4c308a801b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = SchedulerField: 'dpmpp_3m_k', SDXLMainModelField: undefined, FluxMainModelField: undefined, + BriaMainModelField: undefined, SD3MainModelField: undefined, CogView4MainModelField: undefined, SDXLRefinerModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 4e3284f92a6..12d23df07e5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -4,6 +4,7 @@ import type { BoardFieldInputTemplate, BooleanFieldInputTemplate, ChatGPT4oModelFieldInputTemplate, + BriaMainModelFieldInputTemplate, CLIPEmbedModelFieldInputTemplate, CLIPGEmbedModelFieldInputTemplate, CLIPLEmbedModelFieldInputTemplate, @@ -342,6 +343,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: BriaMainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? 
undefined, + }; + + return template; +}; + const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -834,6 +849,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'main' && config.base === 'bria'; +}; + export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint'; }; From 7f3e8087baa7fede01115df1dc9211c33a7cdeea Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 9 Jul 2025 10:32:58 +0000 Subject: [PATCH 02/14] added support for loading bria transformer --- invokeai/backend/bria/__init__.py | 0 invokeai/backend/bria/bria_utils.py | 314 ++++++++++++ invokeai/backend/bria/pipeline.py | 459 ++++++++++++++++++ invokeai/backend/bria/transformer_bria.py | 320 ++++++++++++ .../load/model_loaders/generic_diffusers.py | 3 + 5 files changed, 1096 insertions(+) create mode 100644 invokeai/backend/bria/__init__.py create mode 100644 invokeai/backend/bria/bria_utils.py create mode 100644 invokeai/backend/bria/pipeline.py create mode 100644 invokeai/backend/bria/transformer_bria.py diff --git a/invokeai/backend/bria/__init__.py b/invokeai/backend/bria/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/bria/bria_utils.py b/invokeai/backend/bria/bria_utils.py new file mode 100644 index 00000000000..a821ebe7ba1 --- /dev/null +++ b/invokeai/backend/bria/bria_utils.py @@ -0,0 +1,314 @@ +import math +import os +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from diffusers.utils import logging +from transformers import ( + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_t5_prompt_embeds( + tokenizer: T5TokenizerFast, + text_encoder: T5EncoderModel, + 
prompt: Union[str, List[str], None] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + + prompt_embeds = prompt_embeds.to(device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +# in order the get the same sigmas as in training and sample from them +def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + 
sigmas = timesteps / num_train_timesteps + + inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] + new_sigmas = sigmas[inds] + return new_sigmas + + +def is_ng_none(negative_prompt): + return ( + negative_prompt is None + or negative_prompt == "" + or (isinstance(negative_prompt, list) and negative_prompt[0] is None) + or (type(negative_prompt) == list and negative_prompt[0] == "") + ) + + +class CudaTimerContext: + def __init__(self, times_arr): + self.times_arr = times_arr + + def __enter__(self): + self.before_event = torch.cuda.Event(enable_timing=True) + self.after_event = torch.cuda.Event(enable_timing=True) + self.before_event.record() + + def __exit__(self, type, value, traceback): + self.after_event.record() + torch.cuda.synchronize() + elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000 + self.times_arr.append(elapsed_time) + + +def get_env_prefix(): + env = os.environ.get("CLOUD_PROVIDER", "AWS").upper() + if env == "AWS": + return "SM_CHANNEL" + elif env == "AZURE": + return "AZUREML_DATAREFERENCE" + + raise Exception(f"Env {env} not supported") + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). 
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def initialize_distributed(): + # Initialize the process group for distributed training + dist.init_process_group("nccl") + + # Get the current process's rank (ID) and the total number of processes (world size) + rank = dist.get_rank() + world_size = dist.get_world_size() + + print(f"Initialized distributed training: Rank {rank}/{world_size}") + + +def get_clip_prompt_embeds( + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 77, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + assert max_sequence_length == tokenizer.model_max_length + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Define tokenizers and text encoders + tokenizers = [tokenizer, tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt] + for prompt, 
tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, pooled_prompt_embeds + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. 
+ use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +class FluxPosEmbed(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def 
__init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin diff --git a/invokeai/backend/bria/pipeline.py b/invokeai/backend/bria/pipeline.py new file mode 100644 index 00000000000..d62e695db73 --- /dev/null +++ b/invokeai/backend/bria/pipeline.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python +""" +Bria Text-to-Image Pipeline (GPU-ready) +Using your local Bria checkpoints. +""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn + +# Your bria_utils imports +from .bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from PIL import Image +from tqdm import tqdm # add this at the top of your file + +# Your custom transformer import +from .transformer_bria import BriaTransformer2DModel +from transformers import T5EncoderModel, T5TokenizerFast + + +# ----------------------------------------------------------------------------- +# 1. 
Model Loader +# ----------------------------------------------------------------------------- +class BriaModelLoader: + def __init__( + self, + transformer_ckpt: str, + vae_ckpt: str, + text_encoder_ckpt: str, + tokenizer_ckpt: str, + device: torch.device, + ): + self.device = device + + # print("Loading Bria Transformer from", transformer_ckpt) + # self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.bfloat16).to(device) + + # print("Loading VAE from", vae_ckpt) + # self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float32).to(device) + + # print("Loading T5 Encoder from", text_encoder_ckpt) + # self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device) + + # print("Loading Tokenizer from", tokenizer_ckpt) + # self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt, legacy=False) + self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.float16).to( + device + ) + self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float16).to(device) + self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device) + self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt) + + def get(self): + return { + "transformer": self.transformer, + "vae": self.vae, + "text_encoder": self.text_encoder, + "tokenizer": self.tokenizer, + } + + +# ----------------------------------------------------------------------------- +# 2. 
Text Encoder (uses bria_utils) +# ----------------------------------------------------------------------------- +class BriaTextEncoder: + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + device: torch.device, + max_length: int = 128, + ): + self.model = text_encoder.to(device) + self.tokenizer = tokenizer + self.device = device + self.max_length = max_length + + def encode( + self, + prompt: str, + negative_prompt: Optional[str] = None, + num_images_per_prompt: int = 1, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # 1) get positive embeddings + pos = get_t5_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.model, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=self.max_length, + device=self.device, + ) + # 2) get negative or zeros + if negative_prompt is None or is_ng_none(negative_prompt): + neg = torch.zeros_like(pos) + else: + neg = get_t5_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.model, + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=self.max_length, + device=self.device, + ) + + # 3) build text_ids: shape [S_text, 3] + # S_text = number of tokens = pos.shape[1] + S_text = pos.shape[1] + text_ids = torch.zeros((1, S_text, 3), device=self.device, dtype=torch.long) + text_ids = torch.zeros((S_text, 3), device=self.device, dtype=torch.long) + + print(f"Text embeds shapes ΓåÆ pos: {pos.shape}, neg: {neg.shape}, text_ids: {text_ids.shape}") + return pos, neg, text_ids + + +# ----------------------------------------------------------------------------- +# 3. 
Latent Sampler +# ----------------------------------------------------------------------------- +class BriaLatentSampler: + def __init__(self, transformer: BriaTransformer2DModel, vae: AutoencoderKL, device: torch.device): + self.device = device + self.latent_channels = transformer.config.in_channels + # self.latent_height = vae.config.sample_size + # self.latent_width = vae.config.sample_size + self.latent_height = 128 + self.latent_width = 128 + + @staticmethod + def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype): + # Build the same img_ids FluxPipeline.prepare_latents would use + latent_image_ids = torch.zeros((height, width, 3), device=device, dtype=dtype) + latent_image_ids[..., 1] = torch.arange(height, device=device)[:, None] + latent_image_ids[..., 2] = torch.arange(width, device=device)[None, :] + # reshape to [1, height*width, 3] then repeat for batch + latent_image_ids = latent_image_ids.view(1, height * width, 3) + return latent_image_ids.repeat(batch_size, 1, 1) + + def sample(self, batch_size: int = 1, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]: + gen = torch.Generator(device=self.device).manual_seed(seed) + + # 1) sample & pack the noise exactly as before + shrunk = self.latent_channels // 4 + noise4d = torch.randn( + (batch_size, shrunk, self.latent_height, self.latent_width), + device=self.device, + generator=gen, + ) + latents = ( + noise4d.view(batch_size, shrunk, self.latent_height // 2, 2, self.latent_width // 2, 2) + .permute(0, 2, 4, 1, 3, 5) + .reshape(batch_size, (self.latent_height // 2) * (self.latent_width // 2), shrunk * 4) + ) + + # 2) build the matching latent_image_ids + latent_image_ids = self._prepare_latent_image_ids( + batch_size, + self.latent_height // 2, + self.latent_width // 2, + device=self.device, + dtype=torch.long, + ) + if latent_image_ids.ndim == 3 and latent_image_ids.shape[0] == 1: + latent_image_ids = latent_image_ids[0] # [S_img , 3] + + 
latent_image_ids = latent_image_ids.squeeze(0) + + print(f"Sampled & packed latents: {latents.shape}") + return latents, latent_image_ids + + +# ----------------------------------------------------------------------------- +# 4. Denoising Loop (uses bria_utils for ╧â schedule) +# ----------------------------------------------------------------------------- +class BriaDenoise: + def __init__( + self, + transformer: nn.Module, + scheduler_name: str, + device: torch.device, + num_train_timesteps: int, + num_inference_steps: int, + **sched_kwargs, + ): + self.transformer = transformer.to(device) + self.device = device + + # Build scheduler + if scheduler_name == "flow_match": + from diffusers import FlowMatchEulerDiscreteScheduler + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(transformer.config, **sched_kwargs) + else: + from diffusers import DDIMScheduler + + self.scheduler = DDIMScheduler(**sched_kwargs) + + # Use your exact ╧â schedule from bria_utils + from bria_utils import get_original_sigmas + + sigmas = get_original_sigmas( + num_train_timesteps=num_train_timesteps, + num_inference_steps=num_inference_steps, + ) + self.scheduler.set_timesteps( + num_inference_steps=None, + timesteps=None, + sigmas=sigmas, + device=device, + ) + + # allow early exit + self.interrupt = False + # will be set in denoise() + self._guidance_scale = 1.0 + self._joint_attention_kwargs = {} + self.transformer = transformer.to(device) + self.device = device + + @property + def guidance_scale(self) -> float: + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + return self.guidance_scale > 1.0 + + @property + def joint_attention_kwargs(self) -> dict: + return self._joint_attention_kwargs + + @torch.no_grad() + def denoise( + self, + latents: torch.Tensor, # [B, seq_len, C_hidden] + latent_image_ids: torch.Tensor, # [B, seq_len, 3] + prompt_embeds: torch.Tensor, # [B, S_text, D] + negative_prompt_embeds: torch.Tensor, # [B, 
S_text, D] + text_ids: torch.Tensor, # [B, S_text, 3] + num_inference_steps: int = 30, + guidance_scale: float = 5.0, + normalize: bool = False, + clip_value: float | None = None, + seed: int = 0, + ) -> torch.Tensor: + # 0) Quick cast & setup + device = self.device + # ensure dtype matches transformer + target_dtype = next(self.transformer.parameters()).dtype + latents = latents.to(device, dtype=target_dtype) + prompt_embeds = prompt_embeds.to(device, dtype=target_dtype) + negative_prompt_embeds = negative_prompt_embeds.to(device, dtype=target_dtype) + + # replicate reference encode_prompt behaviour + if negative_prompt_embeds is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + if guidance_scale > 1.0: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + self._guidance_scale = guidance_scale + + # 1) Prepare FlowΓÇæMatch timesteps identical to reference pipeline + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and getattr( + self.scheduler.config, "use_dynamic_shifting", False + ): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift(image_seq_len, 256, 16_384, 0.25, 0.75) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, None, sigmas, mu=mu + ) + else: + sigmas = get_original_sigmas( + num_train_timesteps=self.scheduler.config.num_train_timesteps, + num_inference_steps=num_inference_steps, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, None, sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 2) Loop with progress bar + + with tqdm(total=num_inference_steps, desc="Denoising", unit="step") as progress_bar: + for i, t in enumerate(timesteps): + # a) expand for CFG? 
+ latent_model_input = torch.cat([latents] * 2, dim=0) if self.do_classifier_free_guidance else latents + + # b) scale model input if needed + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # c) broadcast timestep + timestep = t.expand(latent_model_input.shape[0]) + + # d) predict noise + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # e) classifierΓÇæfree guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + cfg_noise_pred_text = noise_pred_text.std() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # f) optional normalize/clip + if normalize: + noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred + + if clip_value: + noise_pred = noise_pred.clamp(-clip_value, clip_value) + + # g) scheduler step, inΓÇæplace + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if (i + 1) % 5 == 0 or i == len(timesteps) - 1: + progress_bar.update(5 if i + 1 < len(timesteps) else (len(timesteps) % 5)) + + # # j) XLA sync + # if XLA_AVAILABLE: + # xm.mark_step() + + # 3) Return the final packed latents (still [B, seq_len, C_hidden]) + return latents + + +# ----------------------------------------------------------------------------- +# 5. 
Latents ΓåÆ Image +# ----------------------------------------------------------------------------- +class BriaLatentsToImage: + def __init__(self, vae: AutoencoderKL, device: torch.device): + self.vae = vae.to(device) + self.device = device + + @torch.no_grad() + def decode(self, latents: torch.Tensor) -> list[Image.Image]: + """ + Accepts either of the two packed shapes that come out of the denoiser + + ΓÇó [B , S , 16] ΓÇô 3ΓÇæD, where S = H┬▓ (e.g. 16┬á384 for 1024├ù1024) + ΓÇó [B , 1 , S , 16] ΓÇô 4ΓÇæD misΓÇæordered (what caused your crash) + + Converts them to the VAEΓÇÖs expected shape [B , 4 , H , W] before decoding. + """ + # ---- 1. UnΓÇæpack to (B , 4 , H , W) ---------------------------------- + if latents.ndim == 3: # (B,S,16) + B, S, C = latents.shape + H2 = int(S**0.5) # 128 for 1024├ù1024 + latents = ( + latents.view(B, H2, H2, 4, 2, 2) # split channels into 4├ù(2├ù2) + .permute(0, 3, 1, 4, 2, 5) # (B,4,H2,2,W2,2) + .reshape(B, 4, H2 * 2, H2 * 2) # (B,4,H,W) + ) + + elif latents.ndim == 4 and latents.shape[1] == 1: # (B,1,S,16) + B, _, S, C = latents.shape + H2 = int(S**0.5) + latents = ( + latents.squeeze(1) # -> (B,S,16) + .view(B, H2, H2, 4, 2, 2) + .permute(0, 3, 1, 4, 2, 5) + .reshape(B, 4, H2 * 2, H2 * 2) + ) + # else: already (B,4,H,W) + + # ---- 2. Standard VAE decode ----------------------------------------- + shift = 0 if self.vae.config.shift_factor is None else self.vae.config.shift_factor + latents = (latents / self.vae.config.scaling_factor) + shift + + # 1. 
temporarily move VAE to fp32 for the forward pass + self.vae.to(dtype=torch.float32) + images = self.vae.decode(latents.to(torch.float32)).sample # fullΓÇæprecision decode + self.vae.to(dtype=torch.bfloat16) # cast to fp32 **after** decode + images = (images.clamp(-1, 1) + 1) / 2 # [0,1] fp32 + images = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8") + + return [Image.fromarray(img) for img in images] + + +# ----------------------------------------------------------------------------- +# Main: Assemble & Run +# ----------------------------------------------------------------------------- +def main(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Using device:", device) + + # ΓöÇΓöÇΓöÇ Use your actual checkpoint locations ΓöÇΓöÇΓöÇ + transformer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/transformer" + vae_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/vae" + text_encoder_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/text_encoder" + tokenizer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/tokenizer" + + # 1. Load models + loader = BriaModelLoader( + transformer_ckpt, + vae_ckpt, + text_encoder_ckpt, + tokenizer_ckpt, + device, + ) + mdl = loader.get() + # if diffusers.__version__ >= "0.27.0": + # mdl["transformer"].enable_xformers_memory_efficient_attention() # now safe + # else: + # mdl["transformer"].disable_xformers_memory_efficient_attention() # keep quality + + # 2. 
Encode prompt ΓÇö now capture text_ids as well + text_enc = BriaTextEncoder(mdl["text_encoder"], mdl["tokenizer"], device) + pos_embeds, neg_embeds, text_ids = text_enc.encode( + prompt="3d rendered image, landscape made out of ice cream, rich ice cream textures, ice cream-valley , with a milky ice cream river, the ice cream has rich texture with visible chocolate chunks and intricate details, in the background an air balloon floats over the vally, in the sky visible dramatic like clouds, brown-chocolate color white and pink pallet, drama, beautiful surreal landscape, polarizing lens, very high contrast, 3d rendered realistic", + negative_prompt=None, + num_images_per_prompt=1, + ) + + # 3. Sample initial noise ΓåÆ get both latents & latent_image_ids + sampler = BriaLatentSampler(mdl["transformer"], mdl["vae"], device) + init_latents, latent_image_ids = sampler.sample(batch_size=1, seed=1249141701) + + # 4. Denoise ΓÇö now passing latent_image_ids and text_ids + denoiser = BriaDenoise( + transformer=mdl["transformer"], + scheduler_name="flow_match", + device=device, + num_train_timesteps=1000, + num_inference_steps=30, + base_shift=0.5, + max_shift=1.15, + ) + final_latents = denoiser.denoise( + init_latents, + latent_image_ids, + pos_embeds, + neg_embeds, + text_ids, + num_inference_steps=30, + guidance_scale=5.0, + seed=1249141701, + ) + + # 5. 
Decode + decoder = BriaLatentsToImage(mdl["vae"], device) + images = decoder.decode(final_latents) + for i, img in enumerate(images): + img.save(f"bria_output_{i}.png") + + +if __name__ == "__main__": + main() diff --git a/invokeai/backend/bria/transformer_bria.py b/invokeai/backend/bria/transformer_bria.py new file mode 100644 index 00000000000..d19d11dc496 --- /dev/null +++ b/invokeai/backend/bria/transformer_bria.py @@ -0,0 +1,320 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from .bria_utils import FluxPosEmbed as EmbedND +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.time_theta = time_theta + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.time_theta, + ) + return t_emb + + +class TimestepProjEmbeddings(nn.Module): + def 
__init__(self, embedding_dim, time_theta): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta + ) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) + return timesteps_emb + + +""" +Based on FluxPipeline with several changes: +- no pooled embeddings +- We use zero padding for prompts +- No guidance embedding since this is not a distilled version +""" + + +class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = None, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + rope_theta=10000, + time_theta=10000, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + + self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + + # if pooled_projection_dim: + # self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu") + + if guidance_embeds: + self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def 
_set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) + else: + guidance = None + + # temb = ( + # self.time_text_embed(timestep, pooled_projections) + # if guidance is None + # else self.time_text_embed(timestep, guidance, pooled_projections) + # ) + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + + # if pooled_projections: + # temb+=self.pooled_text_embed(pooled_projections) + + if guidance: + temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if len(txt_ids.shape) == 2: + ids = torch.cat((txt_ids, img_ids), dim=0) + else: + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", 
"1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 8a690583d5d..245d812a014 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -80,7 +80,10 @@ def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # "transformers", "invokeai.backend.quantization.fast_quantized_transformers_model", "invokeai.backend.quantization.fast_quantized_diffusion_model", + "transformer_bria", ]: + if module == "transformer_bria": + module = "invokeai.backend.bria.transformer_bria" res_type = sys.modules[module] else: res_type = sys.modules["diffusers"].pipelines From 25a57326b37e3b8ac01383fceac5f4f6ebbe96b5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 9 Jul 2025 18:20:30 +0000 Subject: [PATCH 03/14] front end support for bria --- .../subpanels/ModelManagerPanel/ModelBaseBadge.tsx | 1 + .../frontend/web/src/features/parameters/types/constants.ts | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index 59b7f022e22..1877a35e4a4 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -20,6 +20,7 @@ export const BASE_COLOR_MAP: Record = { imagen4: 'pink', 'chatgpt-4o': 'pink', 'flux-kontext': 'pink', 
+ bria: 'purple', }; const ModelBaseBadge = ({ base }: Props) => { diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index d00ff1b1fa8..99ebfe91484 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -17,6 +17,7 @@ export const MODEL_TYPE_MAP: Record = { imagen4: 'Imagen4', 'chatgpt-4o': 'ChatGPT 4o', 'flux-kontext': 'Flux Kontext', + bria: 'Bria AI', }; /** @@ -35,6 +36,7 @@ export const MODEL_TYPE_SHORT_MAP: Record = { imagen4: 'Imagen4', 'chatgpt-4o': 'ChatGPT 4o', 'flux-kontext': 'Flux Kontext', + bria: 'Bria', }; /** @@ -89,6 +91,10 @@ export const CLIP_SKIP_MAP: Record Date: Wed, 9 Jul 2025 18:21:43 +0000 Subject: [PATCH 04/14] addded bria nodes for bria3.1 and bria3.2 --- invokeai/nodes/__init__.py | 1 + invokeai/nodes/bria_nodes/bria_decoder.py | 46 ++++++ invokeai/nodes/bria_nodes/bria_denoiser.py | 133 ++++++++++++++++++ .../nodes/bria_nodes/bria_latent_sampler.py | 79 +++++++++++ .../nodes/bria_nodes/bria_model_loader.py | 60 ++++++++ .../nodes/bria_nodes/bria_text_encoder.py | 90 ++++++++++++ 6 files changed, 409 insertions(+) create mode 100644 invokeai/nodes/__init__.py create mode 100644 invokeai/nodes/bria_nodes/bria_decoder.py create mode 100644 invokeai/nodes/bria_nodes/bria_denoiser.py create mode 100644 invokeai/nodes/bria_nodes/bria_latent_sampler.py create mode 100644 invokeai/nodes/bria_nodes/bria_model_loader.py create mode 100644 invokeai/nodes/bria_nodes/bria_text_encoder.py diff --git a/invokeai/nodes/__init__.py b/invokeai/nodes/__init__.py new file mode 100644 index 00000000000..f6b74417535 --- /dev/null +++ b/invokeai/nodes/__init__.py @@ -0,0 +1 @@ +from .bria_nodes import * \ No newline at end of file diff --git a/invokeai/nodes/bria_nodes/bria_decoder.py b/invokeai/nodes/bria_nodes/bria_decoder.py new file mode 100644 index 
00000000000..38dbac0a0bb --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_decoder.py @@ -0,0 +1,46 @@ +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from PIL import Image + +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation + + +@invocation( + "bria_decoder", + title="Bria Decoder", + tags=["image", "bria"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class BriaDecoderInvocation(BaseInvocation): + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128) + + with context.models.load(self.vae.vae) as vae: + assert isinstance(vae, AutoencoderKL) + latents = (latents / vae.config.scaling_factor) + latents = latents.to(device=vae.device, dtype=vae.dtype) + + decoded_output = vae.decode(latents) + image = decoded_output.sample + + # Convert to numpy with proper gradient handling + image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0] + img = Image.fromarray(image) + image_dto = context.images.save(image=img) + return ImageOutput.build(image_dto) diff --git a/invokeai/nodes/bria_nodes/bria_denoiser.py b/invokeai/nodes/bria_nodes/bria_denoiser.py new file mode 100644 index 00000000000..081c3392f47 --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_denoiser.py @@ -0,0 +1,133 @@ +import torch 
+from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + +from invokeai.app.invocations.fields import Input, InputField +from invokeai.app.invocations.model import SubModelType, TransformerField +from invokeai.app.invocations.primitives import ( + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + LatentsField, + OutputField, +) +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output + +from invokeai.backend.bria.pipeline import get_original_sigmas, retrieve_timesteps +from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel + +@invocation_output("bria_denoise_output") +class BriaDenoiseInvocationOutput(BaseInvocationOutput): + latents: LatentsField = OutputField(description=FieldDescriptions.latents) + + +@invocation( + "bria_denoise", + title="Denoise - Bria", + tags=["image", "bria"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class BriaDenoiseInvocation(BaseInvocation): + num_steps: int = InputField( + default=30, title="Number of Steps", description="The number of steps to use for the denoiser" + ) + guidance_scale: float = InputField( + default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser" + ) + + transformer: TransformerField = InputField( + description="Bria model (Transformer) to load", + input=Input.Connection, + title="Transformer", + ) + latents: LatentsField = InputField( + description="Latents to denoise", + input=Input.Connection, + title="Latents", + ) + latent_image_ids: LatentsField = InputField( + description="Latent Image IDs to denoise", + input=Input.Connection, + title="Latent Image IDs", + ) + pos_embeds: LatentsField = InputField( + description="Positive Prompt Embeds", + input=Input.Connection, + title="Positive Prompt Embeds", + ) + neg_embeds: 
LatentsField = InputField( + description="Negative Prompt Embeds", + input=Input.Connection, + title="Negative Prompt Embeds", + ) + text_ids: LatentsField = InputField( + description="Text IDs", + input=Input.Connection, + title="Text IDs", + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: + latents = context.tensors.load(self.latents.latents_name) + pos_embeds = context.tensors.load(self.pos_embeds.latents_name) + neg_embeds = context.tensors.load(self.neg_embeds.latents_name) + text_ids = context.tensors.load(self.text_ids.latents_name) + latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name) + scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler}) + + device = None + dtype = None + with ( + context.models.load(self.transformer.transformer) as transformer, + context.models.load(scheduler_identifier) as scheduler, + ): + assert isinstance(transformer, BriaTransformer2DModel) + assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler) + dtype = transformer.dtype + device = transformer.device + latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds)) + prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds + + sigmas = get_original_sigmas(1000, self.num_steps) + timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0) + + for t in timesteps: + # Prepare model input efficiently + if self.guidance_scale > 1: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + + # Prepare timestep tensor efficiently + if isinstance(t, torch.Tensor): + timestep_tensor = t.expand(latent_model_input.shape[0]) + else: + timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32) + + noise_pred = transformer( + latent_model_input, + 
encoder_hidden_states=prompt_embeds, + timestep=timestep_tensor, + img_ids=latent_image_ids, + txt_ids=text_ids, + guidance=None, + return_dict=False, + )[0] + + if self.guidance_scale > 1: + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + self.guidance_scale * (noise_text - noise_uncond) + + # Convert timestep for scheduler + t_step = float(t.item()) if isinstance(t, torch.Tensor) else float(t) + + # Use scheduler step with proper dtypes + latents = scheduler.step(noise_pred, t_step, latents, return_dict=False)[0] + + assert isinstance(latents, torch.Tensor) + saved_input_latents_tensor = context.tensors.save(latents) + latents_output = LatentsField(latents_name=saved_input_latents_tensor) + return BriaDenoiseInvocationOutput(latents=latents_output) diff --git a/invokeai/nodes/bria_nodes/bria_latent_sampler.py b/invokeai/nodes/bria_nodes/bria_latent_sampler.py new file mode 100644 index 00000000000..36170ff5d90 --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_latent_sampler.py @@ -0,0 +1,79 @@ +import torch + +from invokeai.app.invocations.fields import Input, InputField +from invokeai.app.invocations.model import TransformerField +from invokeai.app.invocations.primitives import ( + BaseInvocationOutput, + FieldDescriptions, + Input, + LatentsField, + OutputField, +) +from invokeai.backend.model_manager.config import MainDiffusersConfig +from invokeai.invocation_api import ( + BaseInvocation, + Classification, + InputField, + InvocationContext, + invocation, + invocation_output, +) + + +@invocation_output("bria_latent_sampler_output") +class BriaLatentSamplerInvocationOutput(BaseInvocationOutput): + """Base class for nodes that output a CogView text conditioning tensor.""" + + latents: LatentsField = OutputField(description=FieldDescriptions.cond) + latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond) + + +@invocation( + "bria_latent_sampler", + title="Latent Sampler - Bria", + tags=["image", "bria"], + 
category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class BriaLatentSamplerInvocation(BaseInvocation): + seed: int = InputField( + default=42, + title="Seed", + description="The seed to use for the latent sampler", + ) + transformer: TransformerField = InputField( + description="Bria model (Transformer) to load", + input=Input.Connection, + title="Transformer", + ) + + def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput: + device = torch.device("cuda") + transformer_config = context.models.get_config(self.transformer.transformer) + if not isinstance(transformer_config, MainDiffusersConfig): + raise ValueError("Transformer config is not a MainDiffusersConfig") + # TODO: get latent channels from transformer config + latent_channels = 16 + latent_height, latent_width = 128, 128 + shrunk = latent_channels // 4 + gen = torch.Generator(device=device).manual_seed(self.seed) + + noise4d = torch.randn((1, shrunk, latent_height, latent_width), device=device, generator=gen) + latents = noise4d.view(1, shrunk, latent_height // 2, 2, latent_width // 2, 2).permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(1, (latent_height // 2) * (latent_width // 2), shrunk * 4) + + latent_image_ids = torch.zeros((latent_height // 2, latent_width // 2, 3), device=device, dtype=torch.long) + latent_image_ids[..., 1] = torch.arange(latent_height // 2, device=device)[:, None] + latent_image_ids[..., 2] = torch.arange(latent_width // 2, device=device)[None, :] + latent_image_ids = latent_image_ids.view(-1, 3) + + saved_latents_tensor = context.tensors.save(latents) + saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids) + latents_output = LatentsField(latents_name=saved_latents_tensor) + latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor) + + return BriaLatentSamplerInvocationOutput( + latents=latents_output, + latent_image_ids=latent_image_ids_output, + ) diff --git 
a/invokeai/nodes/bria_nodes/bria_model_loader.py b/invokeai/nodes/bria_nodes/bria_model_loader.py new file mode 100644 index 00000000000..b8b20f4f511 --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_model_loader.py @@ -0,0 +1,60 @@ +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.invocations.model import ( + ModelIdentifierField, + SubModelType, + T5EncoderField, + TransformerField, + VAEField, +) +from invokeai.invocation_api import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + InputField, + InvocationContext, + OutputField, + invocation, + invocation_output, +) + + +@invocation_output("bria_model_loader_output") +class BriaModelLoaderOutput(BaseInvocationOutput): + """Bria base model loader output""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + + +@invocation( + "bria_model_loader", + title="Main Model - Bria", + tags=["model", "bria"], + version="1.0.0", + classification=Classification.Prototype, +) +class BriaModelLoaderInvocation(BaseInvocation): + """Loads a bria base model, outputting its submodels.""" + + model: ModelIdentifierField = InputField( + description="Bria model (Transformer) to load", + ui_type=UIType.BriaMainModel, + input=Input.Direct, + ) + + def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput: + for key in [self.model.key]: + if not context.models.exists(key): + raise ValueError(f"Unknown model: {key}") + + transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + vae = 
self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + + return BriaModelLoaderOutput( + transformer=TransformerField(transformer=transformer, loras=[]), + t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]), + vae=VAEField(vae=vae), + ) diff --git a/invokeai/nodes/bria_nodes/bria_text_encoder.py b/invokeai/nodes/bria_nodes/bria_text_encoder.py new file mode 100644 index 00000000000..143a873bb10 --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_text_encoder.py @@ -0,0 +1,90 @@ +from typing import Optional + +import torch +from transformers import ( + T5EncoderModel, + T5TokenizerFast, +) + +from invokeai.app.invocations.model import T5EncoderField +from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.invocation_api import ( + BaseInvocation, + Classification, + InputField, + LatentsField, + invocation, + invocation_output, +) + +from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, is_ng_none + + +@invocation_output("bria_text_encoder_output") +class BriaTextEncoderInvocationOutput(BaseInvocationOutput): + """Base class for nodes that output a CogView text conditioning tensor.""" + + pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond) + neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond) + text_ids: LatentsField = OutputField(description=FieldDescriptions.cond) + + +@invocation( + "bria_text_encoder", + title="Prompt - Bria", + tags=["prompt", "conditioning", "bria"], + category="conditioning", + version="1.0.0", + classification=Classification.Prototype, +) +class BriaTextEncoderInvocation(BaseInvocation): + prompt: str = InputField( + title="Prompt", + description="The prompt to encode", + ) + negative_prompt: Optional[str] = InputField( + title="Negative Prompt", + description="The negative prompt to encode", + ) 
+ max_length: int = InputField( + default=128, + title="Max Length", + description="The maximum length of the prompt", + ) + t5_encoder: T5EncoderField = InputField( + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput: + t5_encoder_info = context.models.load(self.t5_encoder.text_encoder) + t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) + with ( + t5_encoder_info as text_encoder, + t5_tokenizer_info as tokenizer, + ): + assert isinstance(tokenizer, T5TokenizerFast) + assert isinstance(text_encoder, T5EncoderModel) + pos = get_t5_prompt_embeds(tokenizer, text_encoder, self.prompt, 1, self.max_length, text_encoder.device) + neg = ( + torch.zeros_like(pos) + if is_ng_none(self.negative_prompt) + else get_t5_prompt_embeds( + tokenizer, text_encoder, self.negative_prompt, 1, self.max_length, text_encoder.device + ) + ) + text_ids = torch.zeros((pos.shape[1], 3), device=text_encoder.device, dtype=torch.long) + saved_pos_tensor = context.tensors.save(pos) + saved_neg_tensor = context.tensors.save(neg) + saved_text_ids_tensor = context.tensors.save(text_ids) + pos_embeds_output = LatentsField(latents_name=saved_pos_tensor) + neg_embeds_output = LatentsField(latents_name=saved_neg_tensor) + text_ids_output = LatentsField(latents_name=saved_text_ids_tensor) + return BriaTextEncoderInvocationOutput( + pos_embeds=pos_embeds_output, + neg_embeds=neg_embeds_output, + text_ids=text_ids_output, + ) From 8b08af3949dd700deb43f264c5ea160bb3357461 Mon Sep 17 00:00:00 2001 From: Ilan Tchenak Date: Wed, 9 Jul 2025 23:45:08 +0300 Subject: [PATCH 05/14] Setup Probe and UI to accept bria controlnet models --- invokeai/app/invocations/fields.py | 1 + .../backend/model_manager/legacy_probe.py | 4 ++ .../model_manager/load/model_loaders/bria.py | 41 ++++++++++++++++ .../Invocation/fields/InputFieldRenderer.tsx | 14 +++++- 
...BriaControlNetModelFieldInputComponent.tsx | 47 +++++++++++++++++++ .../web/src/features/nodes/types/constants.ts | 1 + .../web/src/features/nodes/types/field.ts | 30 ++++++++++++ .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 18 ++++++- .../src/services/api/hooks/modelsByType.ts | 3 ++ .../frontend/web/src/services/api/schema.ts | 2 +- .../frontend/web/src/services/api/types.ts | 4 ++ 12 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 085f539426e..15c75d996b5 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -43,6 +43,7 @@ class UIType(str, Enum, metaclass=MetaEnum): CogView4MainModel = "CogView4MainModelField" FluxMainModel = "FluxMainModelField" BriaMainModel = "BriaMainModelField" + BriaControlNetModel = "BriaControlNetModelField" SD3MainModel = "SD3MainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index caff085b373..aeee6fd42f4 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -126,6 +126,7 @@ class ModelProbe(object): CLASS2TYPE = { "BriaPipeline": ModelType.Main, + "BriaControlNetModel": ModelType.ControlNet, "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, @@ -1013,6 +1014,9 @@ def get_base_type(self) -> BaseModelType: if config.get("_class_name", None) == "FluxControlNetModel": return BaseModelType.Flux + if config.get("_class_name", None) == "BriaControlNetModel": + return BaseModelType.Bria + # no obvious way to distinguish between 
sd2-base and sd2-768 dimension = config["cross_attention_dim"] if dimension == 768: diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py index 6712e13896e..02a2c0835fc 100644 --- a/invokeai/backend/model_manager/load/model_loaders/bria.py +++ b/invokeai/backend/model_manager/load/model_loaders/bria.py @@ -5,6 +5,8 @@ AnyModelConfig, CheckpointConfigBase, DiffusersConfigBase, + ControlNetDiffusersConfig, + ControlNetCheckpointConfig, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader @@ -17,6 +19,45 @@ ) +@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +class BriaControlNetDiffusersModel(GenericDiffusersLoader): + """Class to load Bria control net models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if isinstance(config, ControlNetCheckpointConfig): + raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.") + + if submodel_type is None: + raise Exception("A submodel type must be provided when loading control net pipelines.") + + model_path = Path(config.path) + load_class = self.get_hf_load_class(model_path, submodel_type) + repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None + variant = repo_variant.value if repo_variant else None + model_path = model_path / submodel_type.value + + dtype = self._torch_dtype + + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=dtype, + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str( + e + ): # try without the variant, just in case user's preferences changed + result = load_class.from_pretrained(model_path, torch_dtype=dtype) + else: + 
raise e + + return result + @ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers) class BriaDiffusersModel(GenericDiffusersLoader): """Class to load Bria main models.""" diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 848172e4dde..468ba2a8152 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -27,10 +27,12 @@ import { isBoardFieldInputTemplate, isBooleanFieldInputInstance, isBooleanFieldInputTemplate, - isChatGPT4oModelFieldInputInstance, - isChatGPT4oModelFieldInputTemplate, + isBriaControlNetModelFieldInputInstance, + isBriaControlNetModelFieldInputTemplate, isBriaMainModelFieldInputInstance, isBriaMainModelFieldInputTemplate, + isChatGPT4oModelFieldInputInstance, + isChatGPT4oModelFieldInputTemplate, isCLIPEmbedModelFieldInputInstance, isCLIPEmbedModelFieldInputTemplate, isCLIPGEmbedModelFieldInputInstance, @@ -119,6 +121,7 @@ import { assert } from 'tsafe'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; +import BriaControlNetModelFieldInputComponent from './inputs/BriaControlNetModelFieldInputComponent'; import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent'; import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent'; import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent'; @@ -458,6 +461,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) return ; } + if (isBriaControlNetModelFieldInputTemplate(template)) { + if 
(!isBriaControlNetModelFieldInputInstance(field)) { + return null; + } + return ; + } + if (isSD3MainModelFieldInputTemplate(template)) { if (!isSD3MainModelFieldInputInstance(field)) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx new file mode 100644 index 00000000000..31f732abb24 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx @@ -0,0 +1,47 @@ +import { useAppDispatch } from 'app/store/storeHooks'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { + BriaControlNetModelFieldInputInstance, + BriaControlNetModelFieldInputTemplate, +} from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useBriaModels } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const BriaControlNetModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useBriaModels(); + const onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + return ( + + ); +}; + +export default memo(BriaControlNetModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts 
b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 0e6131e4882..0c6f3c1de21 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -53,6 +53,7 @@ export const FIELD_COLORS: { [key: string]: string } = { MainModelField: 'teal.500', FluxMainModelField: 'teal.500', BriaMainModelField: 'teal.500', + BriaControlNetModelField: 'teal.500', SD3MainModelField: 'teal.500', CogView4MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 805a7d02f25..dd17471cac4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -189,6 +189,10 @@ const zBriaMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('BriaMainModelField'), originalType: zStatelessFieldType.optional(), }); +const zBriaControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('BriaControlNetModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), @@ -330,6 +334,7 @@ const zStatefulFieldType = z.union([ zStringGeneratorFieldType, zImageGeneratorFieldType, zBriaMainModelFieldType, + zBriaControlNetModelFieldType, ]); export type StatefulFieldType = z.infer; const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); @@ -347,6 +352,7 @@ const modelFieldTypeNames = [ zCogView4MainModelFieldType.shape.name.value, zFluxMainModelFieldType.shape.name.value, zBriaMainModelFieldType.shape.name.value, + zBriaControlNetModelFieldType.shape.name.value, zSDXLRefinerModelFieldType.shape.name.value, zVAEModelFieldType.shape.name.value, zLoRAModelFieldType.shape.name.value, @@ -914,6 +920,26 @@ export const 
isBriaMainModelFieldInputTemplate = buildTemplateTypeGuard('BriaMainModelField'); // #endregion +// #region BriaControlNetModelField +const zBriaControlNetModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +const zBriaControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zBriaControlNetModelFieldValue, +}); +const zBriaControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBriaControlNetModelFieldType, + originalType: zFieldType.optional(), + default: zBriaControlNetModelFieldValue, +}); +const zBriaControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBriaControlNetModelFieldType, +}); +export type BriaControlNetModelFieldInputInstance = z.infer; +export type BriaControlNetModelFieldInputTemplate = z.infer; +export const isBriaControlNetModelFieldInputInstance = buildInstanceTypeGuard(zBriaControlNetModelFieldInputInstance); +export const isBriaControlNetModelFieldInputTemplate = + buildTemplateTypeGuard('BriaControlNetModelField'); +// #endregion + // #region SDXLRefinerModelField /** @alias */ // tells knip to ignore this duplicate export export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. 
@@ -1914,6 +1940,7 @@ export const zStatefulFieldValue = z.union([ zSDXLMainModelFieldValue, zFluxMainModelFieldValue, zBriaMainModelFieldValue, + zBriaControlNetModelFieldValue, zSD3MainModelFieldValue, zCogView4MainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -1966,6 +1993,7 @@ const zStatefulFieldInputInstance = z.union([ zMainModelFieldInputInstance, zFluxMainModelFieldInputInstance, zBriaMainModelFieldInputInstance, + zBriaControlNetModelFieldInputInstance, zSD3MainModelFieldInputInstance, zCogView4MainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, @@ -2009,6 +2037,7 @@ const zStatefulFieldInputTemplate = z.union([ zMainModelFieldInputTemplate, zFluxMainModelFieldInputTemplate, zBriaMainModelFieldInputTemplate, + zBriaControlNetModelFieldInputTemplate, zSD3MainModelFieldInputTemplate, zCogView4MainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, @@ -2062,6 +2091,7 @@ const zStatefulFieldOutputTemplate = z.union([ zMainModelFieldOutputTemplate, zFluxMainModelFieldOutputTemplate, zBriaMainModelFieldOutputTemplate, + zBriaControlNetModelFieldOutputTemplate, zSD3MainModelFieldOutputTemplate, zCogView4MainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index c4c308a801b..7b2b19a0c1d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = SDXLMainModelField: undefined, FluxMainModelField: undefined, BriaMainModelField: undefined, + BriaControlNetModelField: undefined, SD3MainModelField: undefined, CogView4MainModelField: undefined, SDXLRefinerModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts 
b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 12d23df07e5..5520fb519e2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -3,8 +3,9 @@ import { FieldParseError } from 'features/nodes/types/error'; import type { BoardFieldInputTemplate, BooleanFieldInputTemplate, - ChatGPT4oModelFieldInputTemplate, + BriaControlNetModelFieldInputTemplate, BriaMainModelFieldInputTemplate, + ChatGPT4oModelFieldInputTemplate, CLIPEmbedModelFieldInputTemplate, CLIPGEmbedModelFieldInputTemplate, CLIPLEmbedModelFieldInputTemplate, @@ -357,6 +358,20 @@ const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: BriaControlNetModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -850,6 +865,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'controlnet' && config.base === 'bria'; +}; + export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint'; }; From 75ca44d5f989c26a9b3c0b72d223b02c47aaeaf9 Mon Sep 17 00:00:00 2001 From: Ilan Tchenak Date: Thu, 10 Jul 2025 17:11:03 +0300 Subject: [PATCH 06/14] Add Bria text to image model and controlnet support --- .../backend/bria/controlnet_aux/__init__.py | 5 + .../bria/controlnet_aux/canny/__init__.py | 36 + .../bria/controlnet_aux/open_pose/LICENSE | 108 +++ .../bria/controlnet_aux/open_pose/__init__.py | 234 ++++++ .../bria/controlnet_aux/open_pose/body.py | 260 +++++++ .../bria/controlnet_aux/open_pose/face.py | 364 ++++++++++ 
.../bria/controlnet_aux/open_pose/hand.py | 90 +++ .../bria/controlnet_aux/open_pose/model.py | 217 ++++++ .../bria/controlnet_aux/open_pose/util.py | 383 ++++++++++ invokeai/backend/bria/controlnet_aux/util.py | 146 ++++ invokeai/backend/bria/controlnet_bria.py | 543 ++++++++++++++ invokeai/backend/bria/controlnet_utils.py | 69 ++ invokeai/backend/bria/pipeline_bria.py | 647 +++++++++++++++++ .../backend/bria/pipeline_bria_controlnet.py | 672 ++++++++++++++++++ .../backend/model_manager/legacy_probe.py | 4 +- .../model_manager/load/model_loaders/bria.py | 8 +- .../load/model_loaders/generic_diffusers.py | 3 + .../backend/model_manager/load/model_util.py | 5 + ...BriaControlNetModelFieldInputComponent.tsx | 8 +- invokeai/nodes/bria_nodes/__init__.py | 6 + invokeai/nodes/bria_nodes/bria_controlnet.py | 145 ++++ invokeai/nodes/bria_nodes/bria_denoiser.py | 150 ++-- .../nodes/bria_nodes/bria_latent_sampler.py | 32 +- .../nodes/bria_nodes/bria_text_encoder.py | 27 +- 24 files changed, 4073 insertions(+), 89 deletions(-) create mode 100644 invokeai/backend/bria/controlnet_aux/__init__.py create mode 100644 invokeai/backend/bria/controlnet_aux/canny/__init__.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/LICENSE create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/__init__.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/body.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/face.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/hand.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/model.py create mode 100644 invokeai/backend/bria/controlnet_aux/open_pose/util.py create mode 100644 invokeai/backend/bria/controlnet_aux/util.py create mode 100644 invokeai/backend/bria/controlnet_bria.py create mode 100644 invokeai/backend/bria/controlnet_utils.py create mode 100644 invokeai/backend/bria/pipeline_bria.py create mode 100644 
invokeai/backend/bria/pipeline_bria_controlnet.py create mode 100644 invokeai/nodes/bria_nodes/__init__.py create mode 100644 invokeai/nodes/bria_nodes/bria_controlnet.py diff --git a/invokeai/backend/bria/controlnet_aux/__init__.py b/invokeai/backend/bria/controlnet_aux/__init__.py new file mode 100644 index 00000000000..0536dca4bbe --- /dev/null +++ b/invokeai/backend/bria/controlnet_aux/__init__.py @@ -0,0 +1,5 @@ +__version__ = "0.0.9" + +from .canny import CannyDetector +from .open_pose import OpenposeDetector + diff --git a/invokeai/backend/bria/controlnet_aux/canny/__init__.py b/invokeai/backend/bria/controlnet_aux/canny/__init__.py new file mode 100644 index 00000000000..aca9ae3a34b --- /dev/null +++ b/invokeai/backend/bria/controlnet_aux/canny/__init__.py @@ -0,0 +1,36 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from ..util import HWC3, resize_image + +class CannyDetector: + def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + detected_map = cv2.Canny(input_image, low_threshold, high_threshold) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git 
a/invokeai/backend/bria/controlnet_aux/open_pose/LICENSE b/invokeai/backend/bria/controlnet_aux/open_pose/LICENSE new file mode 100644 index 00000000000..6f60b76d35f --- /dev/null +++ b/invokeai/backend/bria/controlnet_aux/open_pose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 
+ +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. 
+ +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 
+ +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. 
We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py b/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py new file mode 100644 index 00000000000..e463316aa60 --- /dev/null +++ b/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py @@ -0,0 +1,234 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) +# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) +# This preprocessor is licensed by CMU for non-commercial use only. + + +import os + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import json +import warnings +from typing import Callable, List, NamedTuple, Tuple, Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from . 
HandResult = List[Keypoint]
FaceResult = List[Keypoint]


class PoseResult(NamedTuple):
    # Detection result for one person. The hand and face entries stay None
    # when that part was not requested or no usable detection was produced.
    body: BodyResult
    left_hand: Union[HandResult, None]
    right_hand: Union[HandResult, None]
    face: Union[FaceResult, None]


def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
    """
    Draw the detected poses on an empty canvas.

    Args:
        poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
        H (int): The height of the canvas.
        W (int): The width of the canvas.
        draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
        draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
        draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.

    Returns:
        numpy.ndarray: A (H, W, 3) uint8 array representing the canvas with the drawn poses.
    """
    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

    for pose in poses:
        if draw_body:
            canvas = util.draw_bodypose(canvas, pose.body.keypoints)

        if draw_hand:
            canvas = util.draw_handpose(canvas, pose.left_hand)
            canvas = util.draw_handpose(canvas, pose.right_hand)

        if draw_face:
            canvas = util.draw_facepose(canvas, pose.face)

    return canvas


class OpenposeDetector:
    """
    A class for detecting human poses in images using the Openpose model.

    Attributes:
        body_estimation: Body keypoint estimator (required).
        hand_estimation: Hand keypoint estimator, or None. It is called whenever
            hand detection is requested, so it must be set in that case.
        face_estimation: Face keypoint estimator, or None. It is called whenever
            face detection is requested, so it must be set in that case.
    """

    def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
        self.body_estimation = body_estimation
        self.hand_estimation = hand_estimation
        self.face_estimation = face_estimation

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False):
        """
        Build a detector from a local directory or a Hugging Face Hub repo id.

        Args:
            pretrained_model_or_path: Local directory or Hub repo id holding the weights.
            filename: Body checkpoint file name; a default is chosen when None.
            hand_filename: Hand checkpoint file name; a default is chosen when None.
            face_filename: Face checkpoint file name; defaults to "facenet.pth".
            cache_dir: Forwarded to `hf_hub_download`.
            local_files_only: Forwarded to `hf_hub_download`.

        Returns:
            OpenposeDetector: A detector with body, hand and face estimators loaded.
        """
        if pretrained_model_or_path == "lllyasviel/ControlNet":
            filename = filename or "annotator/ckpts/body_pose_model.pth"
            hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
            face_filename = face_filename or "facenet.pth"

            # For this repo id the face weights are fetched from a different repo.
            face_pretrained_model_or_path = "lllyasviel/Annotators"
        else:
            filename = filename or "body_pose_model.pth"
            hand_filename = hand_filename or "hand_pose_model.pth"
            face_filename = face_filename or "facenet.pth"

            face_pretrained_model_or_path = pretrained_model_or_path

        if os.path.isdir(pretrained_model_or_path):
            body_model_path = os.path.join(pretrained_model_or_path, filename)
            hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
            face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
        else:
            body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
            hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only)
            face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only)

        body_estimation = Body(body_model_path)
        hand_estimation = Hand(hand_model_path)
        face_estimation = Face(face_model_path)

        return cls(body_estimation, hand_estimation, face_estimation)

    def to(self, device):
        """Move every configured estimator to `device` and return self."""
        self.body_estimation.to(device)
        # The constructor allows the hand/face estimators to be omitted, so
        # only move the ones that are actually present.
        if self.hand_estimation is not None:
            self.hand_estimation.to(device)
        if self.face_estimation is not None:
            self.face_estimation.to(device)
        return self

    def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
        """
        Detect the hands belonging to `body`.

        Returns:
            (left_hand, right_hand); each entry is a list of keypoints with
            coordinates divided by the image width/height, or None when
            `util.handDetect` proposed no region for that hand.
        """
        left_hand = None
        right_hand = None
        H, W, _ = oriImg.shape
        for x, y, w, is_left in util.handDetect(body, oriImg):
            peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
            if peaks.ndim == 2 and peaks.shape[1] == 2:
                # Peaks are relative to the crop: shift them back by the crop
                # origin. Coordinates below 1e-6 are replaced by -1 before the
                # division by the image size.
                peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
                peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)

                hand_result = [
                    Keypoint(x=peak[0], y=peak[1])
                    for peak in peaks
                ]

                if is_left:
                    left_hand = hand_result
                else:
                    right_hand = hand_result

        return left_hand, right_hand

    def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
        """
        Detect the face landmarks belonging to `body`.

        Returns:
            A list of keypoints with coordinates divided by the image
            width/height, or None when no face region was proposed or the
            estimator produced no (N, 2) peak array.
        """
        face = util.faceDetect(body, oriImg)
        if face is None:
            return None

        x, y, w = face
        H, W, _ = oriImg.shape
        heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
        peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
        if peaks.ndim == 2 and peaks.shape[1] == 2:
            peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
            peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
            return [
                Keypoint(x=peak[0], y=peak[1])
                for peak in peaks
            ]

        return None

    def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
        """
        Detect poses in the given image.

        Args:
            oriImg (numpy.ndarray): The input image for pose detection.
            include_hand (bool, optional): Whether to include hand detection. Defaults to False.
            include_face (bool, optional): Whether to include face detection. Defaults to False.

        Returns:
            List[PoseResult]: A list of PoseResult objects containing the detected poses.
        """
        # Reverse the channel order (presumably RGB -> BGR for the estimators;
        # confirm against the callers).
        oriImg = oriImg[:, :, ::-1].copy()
        H, W = oriImg.shape[:2]
        with torch.no_grad():
            candidate, subset = self.body_estimation(oriImg)
            bodies = self.body_estimation.format_body_result(candidate, subset)

            results = []
            for body in bodies:
                left_hand, right_hand, face = (None,) * 3
                if include_hand:
                    left_hand, right_hand = self.detect_hands(body, oriImg)
                if include_face:
                    face = self.detect_face(body, oriImg)

                # Body keypoints come back in pixel units; divide by the image
                # size so they match the hand/face results.
                results.append(PoseResult(BodyResult(
                    keypoints=[
                        Keypoint(
                            x=keypoint.x / float(W),
                            y=keypoint.y / float(H)
                        ) if keypoint is not None else None
                        for keypoint in body.keypoints
                    ],
                    total_score=body.total_score,
                    total_parts=body.total_parts
                ), left_hand, right_hand, face))

            return results

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs):
        """
        Run pose detection on `input_image` and render the result.

        Args:
            input_image: Image as a numpy array or anything `np.array` can convert.
            detect_resolution: Resolution passed to `resize_image` before detection.
            image_resolution: Resolution passed to `resize_image` to size the output map.
            include_body: Whether to draw body keypoints.
            include_hand: Whether to detect and draw hands.
            include_face: Whether to detect and draw the face.
            hand_and_face: Deprecated; overrides both include_hand and include_face.
            output_type: "pil" to return a PIL image; any other value returns a numpy array.
            **kwargs: Only the deprecated `return_pil` flag is read.

        Returns:
            The rendered pose map as a PIL image or numpy array.
        """
        if hand_and_face is not None:
            warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
            include_hand = hand_and_face
            include_face = hand_and_face

        if "return_pil" in kwargs:
            warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
            output_type = "pil" if kwargs["return_pil"] else "np"
        if isinstance(output_type, bool):
            warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
            if output_type:
                output_type = "pil"

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)
        H, W = input_image.shape[:2]

        poses = self.detect_poses(input_image, include_hand, include_face)
        canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)

        detected_map = canvas
        detected_map = HWC3(detected_map)

        # Size the output map to match the input resized to image_resolution.
        img = resize_image(input_image, image_resolution)
        H, W = img.shape[:2]

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
class Keypoint(NamedTuple):
    x: float
    y: float
    score: float = 1.0
    id: int = -1


class BodyResult(NamedTuple):
    # Note: Using `Union` instead of `|` operator as the latter is a Python
    # 3.10 feature.
    # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
    # Python 3.8 environment.
    # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
    keypoints: List[Union[Keypoint, None]]
    total_score: float
    total_parts: int


class Body(object):
    """Body keypoint estimator wrapping `bodypose_model`."""

    def __init__(self, model_path):
        """Build the network and load the checkpoint stored at `model_path`."""
        self.model = bodypose_model()
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def to(self, device):
        """Move the underlying model to `device` and return self."""
        self.model.to(device)
        return self

    def __call__(self, oriImg):
        """
        Estimate body keypoints and group them into people.

        Args:
            oriImg (np.ndarray): Input image, indexed as (height, width, channel).

        Returns:
            Tuple (candidate, subset):
                candidate: one row per detected peak: x, y, score, id.
                subset: one row per person with 20 columns; columns 0-17 index
                    into `candidate` (-1 when the part is missing), column 18
                    is the total score and column 19 the number of parts.
        """
        device = next(iter(self.model.parameters())).device
        # Only one scale is evaluated; the upstream alternative is
        # [0.5, 1.0, 1.5, 2.0].
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            data = data.to(device)
            with torch.no_grad():
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
            heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))

            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
            paf = util.smart_resize_k(paf, fx=stride, fy=stride)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))

            # Accumulate the per-scale mean. The previous form
            # `heatmap_avg += heatmap_avg + ...` doubled the running sum on
            # every scale after the first.
            heatmap_avg += heatmap / len(multiplier)
            paf_avg += paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(18):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            # Shifted copies of the smoothed map, used to find local maxima
            # against the four direct neighbours.
            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
                   [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correspondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22],
                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52],
                  [55, 56], [37, 38], [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num),
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0]
                                          for I in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1]
                                          for I in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                # Greedy assignment: best-scoring candidates first, each peak
                # used at most once per limb type.
                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3] and j not in connection[:, 4]):
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])
        # delete some rows of subset which has few parts occur
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        return candidate, subset

    @staticmethod
    def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
        """
        Format the body results from the candidate and subset arrays into a list of BodyResult objects.

        Args:
            candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
                for each body part.
            subset (np.ndarray): An array of subsets containing indices to the candidate array for each
                person detected. The last two columns of each row hold the total score and total parts
                of the person.

        Returns:
            List[BodyResult]: A list of BodyResult objects, where each object represents a person with
                detected keypoints, total score, and total parts.
        """
        return [
            BodyResult(
                keypoints=[
                    Keypoint(
                        x=candidate[candidate_index][0],
                        y=candidate[candidate_index][1],
                        score=candidate[candidate_index][2],
                        id=candidate[candidate_index][3]
                    ) if candidate_index != -1 else None
                    for candidate_index in person[:18].astype(int)
                ],
                total_score=person[18],
                total_parts=person[19]
            )
            for person in subset
        ]
class FaceNet(Module):
    """Cascaded heatmap network.

    A convolutional feature extractor (128 output channels) feeds six
    prediction stages. Each stage emits 71 heatmap channels; stages 2-6 take
    the previous stage's heatmaps concatenated with the shared feature map
    (71 + 128 = 199 input channels).
    """

    # (attribute name, in_channels, out_channels, kernel_size), listed in the
    # order the layers are registered and applied. Every convolution uses
    # stride 1 and padding kernel_size // 2.
    _FEATURE_LAYERS = (
        ("conv1_1", 3, 64, 3),
        ("conv1_2", 64, 64, 3),
        ("conv2_1", 64, 128, 3),
        ("conv2_2", 128, 128, 3),
        ("conv3_1", 128, 256, 3),
        ("conv3_2", 256, 256, 3),
        ("conv3_3", 256, 256, 3),
        ("conv3_4", 256, 256, 3),
        ("conv4_1", 256, 512, 3),
        ("conv4_2", 512, 512, 3),
        ("conv4_3", 512, 512, 3),
        ("conv4_4", 512, 512, 3),
        ("conv5_1", 512, 512, 3),
        ("conv5_2", 512, 512, 3),
        ("conv5_3_CPM", 512, 128, 3),
    )
    # Feature layers that are followed by a 2x2 max-pool.
    _POOL_AFTER = ("conv1_2", "conv2_2", "conv3_4")
    # Stage 1 head.
    _STAGE1_LAYERS = (
        ("conv6_1_CPM", 128, 512, 1),
        ("conv6_2_CPM", 512, 71, 1),
    )
    # (layer index, in_channels, out_channels, kernel_size) shared by stages 2-6.
    _REFINE_LAYERS = (
        (1, 199, 128, 7),
        (2, 128, 128, 7),
        (3, 128, 128, 7),
        (4, 128, 128, 7),
        (5, 128, 128, 7),
        (6, 128, 128, 1),
        (7, 128, 71, 1),
    )
    _REFINE_STAGES = (2, 3, 4, 5, 6)

    def __init__(self):
        super().__init__()
        self.relu = ReLU()
        self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)

        # Feature extractor followed by the stage-1 head.
        for layer_name, n_in, n_out, ksize in self._FEATURE_LAYERS + self._STAGE1_LAYERS:
            setattr(self, layer_name, self._conv(n_in, n_out, ksize))

        # Refinement stages 2-6 share one layout.
        for stage in self._REFINE_STAGES:
            for index, n_in, n_out, ksize in self._REFINE_LAYERS:
                setattr(self, f"Mconv{index}_stage{stage}", self._conv(n_in, n_out, ksize))

        # Zero every convolution bias.
        for module in self.modules():
            if isinstance(module, Conv2d):
                init.constant_(module.bias, 0)

    @staticmethod
    def _conv(n_in, n_out, ksize):
        """Create a stride-1 convolution with padding ksize // 2."""
        return Conv2d(
            in_channels=n_in, out_channels=n_out, kernel_size=ksize, stride=1,
            padding=ksize // 2)

    def forward(self, x):
        """Return a list of heatmaps, one entry per stage (six in total)."""
        h = x
        for layer_name, _, _, _ in self._FEATURE_LAYERS:
            h = self.relu(getattr(self, layer_name)(h))
            if layer_name in self._POOL_AFTER:
                h = self.max_pooling_2d(h)
        feature_map = h

        # stage1: the final convolution is not followed by a ReLU
        h = self.relu(self.conv6_1_CPM(h))
        h = self.conv6_2_CPM(h)
        heatmaps = [h]

        # stage2 .. stage6
        for stage in self._REFINE_STAGES:
            h = torch.cat([h, feature_map], dim=1)  # channel concat
            for index in range(1, 7):
                h = self.relu(getattr(self, f"Mconv{index}_stage{stage}")(h))
            h = getattr(self, f"Mconv7_stage{stage}")(h)
            heatmaps.append(h)

        return heatmaps
class Face(object):
    """
    The OpenPose face landmark detector model.

    Args:
        face_model_path: path of the FaceNet checkpoint to load
        inference_size: set the size of the inference image size, suggested:
            368, 736, 1312, default 736
        gaussian_sigma: blur the heatmaps, default 2.5
        heatmap_peak_thresh: return landmark if over threshold, default 0.1

    """
    def __init__(self, face_model_path,
                 inference_size=None,
                 gaussian_sigma=None,
                 heatmap_peak_thresh=None):
        # Fall back to the module defaults only when an argument is omitted.
        # An `or` fallback would also discard explicit falsy values such as
        # a threshold of 0.
        self.inference_size = params["inference_img_size"] if inference_size is None else inference_size
        self.sigma = params['gaussian_sigma'] if gaussian_sigma is None else gaussian_sigma
        self.threshold = params["heatmap_peak_thresh"] if heatmap_peak_thresh is None else heatmap_peak_thresh
        self.model = FaceNet()
        self.model.load_state_dict(torch.load(face_model_path))
        self.model.eval()

    def to(self, device):
        """Move the underlying model to `device` and return self."""
        self.model.to(device)
        return self

    def __call__(self, face_img):
        """
        Compute landmark heatmaps for a face crop.

        Args:
            face_img (np.ndarray): Face crop indexed as (height, width, channel).

        Returns:
            np.ndarray: Heatmaps from the last network stage, resized back to
            the crop's height and width, with the channel axis first.
        """
        device = next(iter(self.model.parameters())).device
        H, W, _ = face_img.shape

        # The network is always run on a fixed 384x384 resize of the crop.
        w_size = 384
        x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5

        x_data = x_data.to(device)

        with torch.no_grad():
            hs = self.model(x_data[None, ...])
            heatmaps = F.interpolate(
                hs[-1],
                (H, W),
                mode='bilinear', align_corners=True).cpu().numpy()[0]
        return heatmaps

    def compute_peaks_from_heatmaps(self, heatmaps):
        """
        Pick the strongest location in each heatmap channel.

        Args:
            heatmaps (np.ndarray): Array indexed as (channel, row, column).

        Returns:
            np.ndarray: One [x, y] row per channel that has at least one value
            above 0.05. Channels without such a value are skipped, so the
            result can have fewer rows than there are channels (and is an
            empty 1-D array when every channel is skipped).
        """
        all_peaks = []
        for part in range(heatmaps.shape[0]):
            map_ori = heatmaps[part].copy()
            binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)

            if np.sum(binary) == 0:
                continue

            positions = np.where(binary > 0.5)
            intensities = map_ori[positions]
            mi = np.argmax(intensities)
            y, x = positions[0][mi], positions[1][mi]
            all_peaks.append([x, y])

        return np.array(all_peaks)
import gaussian_filter +from skimage.measure import label + +from . import util +from .model import handpose_model + + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, oriImgRaw): + device = next(iter(self.model.parameters())).device + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(device) + + with torch.no_grad(): + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, 
class bodypose_model(nn.Module):
    """Six-stage, two-branch convolutional body-pose network.

    ``model0`` is a shared feature extractor producing 128 channels.  Every
    stage then runs two branches on its input: branch ``_1`` (layer suffix
    ``L1``) emits 38 channels and branch ``_2`` (suffix ``L2``) emits 19.
    Stages 2-6 consume the previous stage's two outputs concatenated with the
    shared features (38 + 19 + 128 = 185 channels).

    NOTE(review): in the OpenPose design these are presumably the
    part-affinity fields (L1) and keypoint heatmaps (L2) - confirm against
    the checkpoint this is loaded with.

    Attribute names (``model0``, ``model<stage>_<branch>``) and the layer
    names inside each ``nn.Sequential`` define the ``state_dict`` keys and
    must not change.
    """

    def __init__(self):
        super().__init__()

        # Output convolutions of every stage/branch are linear (no ReLU).
        # The last entry used to repeat 'Mconv7_stage6_L1', which left a ReLU
        # on the final L2 output; ReLU has no parameters, so fixing the name
        # does not affect checkpoint loading.
        no_relu_layers = [
            'conv5_5_CPM_L1', 'conv5_5_CPM_L2',
            'Mconv7_stage2_L1', 'Mconv7_stage2_L2',
            'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
            'Mconv7_stage4_L1', 'Mconv7_stage4_L2',
            'Mconv7_stage5_L1', 'Mconv7_stage5_L2',
            'Mconv7_stage6_L1', 'Mconv7_stage6_L2',
        ]

        # Conv spec: [in, out, kernel, stride, padding]; pool spec:
        # [kernel, stride, padding] (see make_layers).
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1]),
        ])

        # (branch index, number of output channels)
        branches = ((1, 38), (2, 19))

        blocks = {}

        # Stage 1
        for branch, out_channels in branches:
            blocks['block1_%d' % branch] = OrderedDict([
                ('conv5_1_CPM_L%d' % branch, [128, 128, 3, 1, 1]),
                ('conv5_2_CPM_L%d' % branch, [128, 128, 3, 1, 1]),
                ('conv5_3_CPM_L%d' % branch, [128, 128, 3, 1, 1]),
                ('conv5_4_CPM_L%d' % branch, [128, 512, 1, 1, 0]),
                ('conv5_5_CPM_L%d' % branch, [512, out_channels, 1, 1, 0]),
            ])

        # Built before the later stages so module construction (and hence
        # random initialisation) order stays model0, block1_1, block1_2, ...
        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            for branch, out_channels in branches:
                tag = 'stage%d_L%d' % (i, branch)
                blocks['block%d_%d' % (i, branch)] = OrderedDict([
                    ('Mconv1_' + tag, [185, 128, 7, 1, 3]),
                    ('Mconv2_' + tag, [128, 128, 7, 1, 3]),
                    ('Mconv3_' + tag, [128, 128, 7, 1, 3]),
                    ('Mconv4_' + tag, [128, 128, 7, 1, 3]),
                    ('Mconv5_' + tag, [128, 128, 7, 1, 3]),
                    ('Mconv6_' + tag, [128, 128, 1, 1, 0]),
                    ('Mconv7_' + tag, [128, out_channels, 1, 1, 0]),
                ])

        for k in blocks:
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        # Registration order is kept (all L1 branches, then all L2 branches)
        # so parameters() enumerates exactly as before.
        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):
        """Run all six stages and return the last stage's ``(L1, L2)`` outputs."""
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2
1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/util.py b/invokeai/backend/bria/controlnet_aux/open_pose/util.py new file mode 100644 index 00000000000..f10ca2dfcbf --- /dev/null +++ b/invokeai/backend/bria/controlnet_aux/open_pose/util.py @@ -0,0 
def padRightDownCorner(img, stride, padValue):
    """Pad an image on the bottom and right up to a multiple of ``stride``.

    Args:
        img: 3-D array indexed ``[row, col, channel]``.
        stride: positive integer the padded height and width must divide by.
        padValue: constant written into the padded area (cast to ``img``'s
            dtype).

    Returns:
        ``(img_padded, pad)`` where ``img_padded`` is a new array and ``pad``
        is ``[up, left, down, right]``; ``up`` and ``left`` are always 0.

    The previous implementation built the padding from ``img[-2:-1]`` /
    ``img[:, -2:-1]``, which is an empty slice when the image is one pixel
    high or wide, so nothing was appended although ``pad`` claimed otherwise.
    """
    h, w = img.shape[0], img.shape[1]

    # -n % stride is 0 when n is already aligned, else the distance to the
    # next multiple of stride.
    pad = [0, 0, -h % stride, -w % stride]

    img_padded = np.pad(
        img,
        ((pad[0], pad[2]), (pad[1], pad[3]), (0, 0)),
        mode='constant',
        constant_values=padValue,
    )
    return img_padded, pad
model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: + """ + Draw keypoints and limbs representing body pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose. + keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + H, W, C = canvas.shape + stickwidth = 4 + + limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], + [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], + [2, 1], [1, 15], [15, 17], [1, 16], + [16, 18], + ] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for (k1_index, k2_index), color in zip(limbSeq, colors): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + continue + + Y = np.array([keypoint1.x, keypoint2.x]) * float(W) + X = np.array([keypoint1.y, keypoint2.y]) * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) + + for keypoint, color in zip(keypoints, colors): + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, 
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
    """
    Draw keypoints and connections representing hand pose on a given canvas.

    Args:
        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
                                          or None if no keypoints are present. Individual entries may be None
                                          and are skipped.

    Returns:
        np.ndarray: The same canvas object, modified in place with the drawn hand pose.

    Note:
        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
    """
    if not keypoints:
        return canvas

    # Imported lazily, and only when there is something to draw.  The
    # submodule is imported explicitly: a bare `import matplotlib` does not
    # guarantee that `matplotlib.colors` has been loaded.
    import matplotlib.colors

    H, W = canvas.shape[:2]

    # Index pairs into `keypoints`; every finger chain starts at index 0.
    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    for ie, (e1, e2) in enumerate(edges):
        k1 = keypoints[e1]
        k2 = keypoints[e2]
        if k1 is None or k2 is None:
            continue

        x1 = int(k1.x * W)
        y1 = int(k1.y * H)
        x2 = int(k2.x * W)
        y2 = int(k2.y * H)
        # Points that land on pixel row/column 0 are not drawn.
        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
            # One hue per edge, spread evenly around the colour wheel.
            cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)

    for keypoint in keypoints:
        # Same guard as the edge loop above; previously a None entry raised
        # AttributeError here.
        if keypoint is None:
            continue
        x, y = keypoint.x, keypoint.y
        x = int(x * W)
        y = int(y * H)
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, C = canvas.shape + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: + """ + Detect hands in the input body pose keypoints and calculate the bounding box for each hand. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left + corner of the bounding box, the width (height) of the bounding box, and + a boolean flag indicating whether the hand is a left hand (True) or a + right hand (False). + + Notes: + - The width and height of the bounding boxes are equal since the network requires squared input. + - The minimum bounding box size is 20 pixels. 
+ """ + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + left_shoulder = keypoints[5] + left_elbow = keypoints[6] + left_wrist = keypoints[7] + right_shoulder = keypoints[2] + right_elbow = keypoints[3] + right_wrist = keypoints[4] + + # if any of three not detected + has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) + has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) + if not (has_left or has_right): + return [] + + hands = [] + #left hand + if has_left: + hands.append([ + left_shoulder.x, left_shoulder.y, + left_elbow.x, left_elbow.y, + left_wrist.x, left_wrist.y, + True + ]) + # right hand + if has_right: + hands.append([ + right_shoulder.x, right_shoulder.y, + right_elbow.x, right_elbow.y, + right_wrist.x, right_wrist.y, + False + ]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= 
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
    """
    Detect the face in the input body pose keypoints and calculate the bounding box for the face.

    Args:
        body (BodyResult): A BodyResult object containing the detected body pose keypoints.
        oriImg (numpy.ndarray): A 3D numpy array representing the original input image.

    Returns:
        Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
                                     bounding box and the width (height) of the bounding box, or None if the
                                     face is not detected or the bounding box width is less than 20 pixels.

    Notes:
        - The width and height of the bounding box are equal.
        - The minimum bounding box size is 20 pixels.
        - The box is shrunk so that it never extends past the right or bottom image edge.
    """
    image_height, image_width = oriImg.shape[0:2]

    keypoints = body.keypoints
    head = keypoints[0]
    # (keypoint, scale): indices 14/15 are the eyes, 16/17 the ears.  Eye
    # distances are scaled by 3.0 and ear distances by 1.5 to estimate the
    # half-size of the face box.
    landmarks = (
        (keypoints[14], 3.0),
        (keypoints[15], 3.0),
        (keypoints[16], 1.5),
        (keypoints[17], 1.5),
    )

    if head is None or all(keypoint is None for keypoint, _ in landmarks):
        return None

    x0, y0 = head.x, head.y

    half = 0.0
    for keypoint, scale in landmarks:
        if keypoint is None:
            continue
        d = max(abs(x0 - keypoint.x), abs(y0 - keypoint.y))
        half = max(half, d * scale)

    # Top-left corner, clamped to the image.
    x = max(x0 - half, 0)
    y = max(y0 - half, 0)

    # The box side is twice the half-size.  The overflow test must use the
    # full side; it previously compared x + half (the box centre) against
    # the image bounds, so boxes could run past the right/bottom edge.
    side = half * 2
    width = min(side, image_width - x, image_height - y)

    if width >= 20:
        return int(x), int(y), int(width)
    return None
def min_max_norm(x):
    """Linearly rescale ``x`` so its minimum becomes 0 and its maximum 1.

    Args:
        x: numeric array (anything ``np.asarray`` accepts).

    Returns:
        The normalised array.  Floating-point ``ndarray`` input is
        normalised **in place** and the same object is returned.  Integer or
        boolean input is first converted to a new ``float32`` array, leaving
        the caller's data untouched; such input used to be shifted in place
        by ``-=`` and then fail on the in-place division.

    The divisor is floored at 1e-5 so a constant array maps to all zeros
    instead of dividing by zero.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        # In-place true division is not defined for integer dtypes.
        x = x.astype(np.float32)

    x -= np.min(x)
    x /= np.maximum(np.max(x), 1e-5)
    return x
def torch_gc():
    """Ask PyTorch to release cached CUDA memory; a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + diff --git a/invokeai/backend/bria/controlnet_bria.py b/invokeai/backend/bria/controlnet_bria.py new file mode 100644 index 00000000000..a845afbcf2e --- /dev/null +++ b/invokeai/backend/bria/controlnet_bria.py @@ -0,0 +1,543 @@ +# type: ignore +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BriaControlModes(Enum):
    """Integer id for each control mode name listed in ``BRIA_CONTROL_MODES``.

    NOTE(review): presumably these ids index the ControlNet's mode embedding
    (``controlnet_mode``) - confirm against the caller.
    """

    depth = 0
    canny = 1
    colorgrid = 2
    recolor = 3
    tile = 4
    pose = 5
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor used in the model.

    Returns:
        `dict` of attention processors: maps `"<dotted module path>.processor"` to the object returned by that
        module's `get_processor()`, for every descendant module that defines `get_processor`.
    """
    # set recursively
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        # A module counts as an attention layer if it exposes `get_processor`.
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()

        # Descend, extending the dotted path with each child's name.
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + num_single_layers: int = 10, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = transformer.config + config["num_layers"] = num_layers + config["num_single_layers"] = num_single_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + controlnet.single_transformer_blocks.load_state_dict( + transformer.single_transformer_blocks.state_dict(), strict=False + ) + + controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) + + return controlnet + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, + conditioning_scale: float = 1.0, + 
encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if guidance is not None: + print("guidance is not supported in BriaControlNetModel") + if pooled_projections is not None: + print("pooled_projections is not supported in BriaControlNetModel") + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + # Convert controlnet_cond to the same dtype as the model weights + controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype) + + # add + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + + timestep = timestep.to(hidden_states.dtype) # Original code was * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) # Original code was * 1000 + else: + guidance = None + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + + # Validate controlnet_mode values are within the valid range + if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode): + raise ValueError(f"`controlnet_mode` values must be in range [0, {self.num_mode-1}], but got values outside this range") + + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch + controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + + txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0) + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + block_samples = block_samples + (hidden_states,) + + hidden_states = 
torch.cat([encoder_hidden_states, hidden_states], dim=1) + + single_block_samples = () + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + + # controlnet block + controlnet_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + + controlnet_single_block_samples = () + for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + single_block_sample = controlnet_block(single_block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) + + # scaling + controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] + controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] + + controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples + controlnet_single_block_samples = ( + None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT 
layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_samples, controlnet_single_block_samples) + + return BriaControlNetOutput( + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) + + +class BriaMultiControlNetModel(ModelMixin): + r""" + `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel + This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be + compatible with `BriaControlNetModel`. + Args: + controlnets (`List[BriaControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `BriaControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + controlnet_mode: List[torch.tensor], + conditioning_scale: List[float], + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[BriaControlNetOutput, Tuple]: + # ControlNet-Union with multiple conditions + # only load one ControlNet for saving memories + if len(self.nets) == 1 and self.nets[0].union: + controlnet = self.nets[0] + + for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + 
img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + # Regular Multi-ControlNets + # load all ControlNets into memories + else: + for i, (image, mode, scale, controlnet) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) + ): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + return control_block_samples, control_single_block_samples \ No newline at end of file diff --git 
@torch.no_grad()
def prepare_control_images(
    vae: "AutoencoderKL",
    control_images: "list[Image.Image]",
    control_modes: list[int],
    width: int,
    height: int,
    device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Encode control images into packed latents and build their mode tensors.

    Args:
        vae: Autoencoder used to encode each preprocessed image; its ``dtype`` and
            ``config.scaling_factor`` are read.
        control_images: Images to encode, one per control condition.
        control_modes: Integer mode id for each image, matched by position.
        width: Target width handed to the image preprocessor.
        height: Target height handed to the image preprocessor.
        device: Device the image and mode tensors are moved to.

    Returns:
        ``(packed_latents, modes)``: two lists with one entry per control image. Each mode entry is
        a ``torch.long`` tensor with one element per batch item of its packed latents.

    Raises:
        ValueError: If fewer modes than images are supplied.
    """
    if len(control_modes) < len(control_images):
        raise ValueError(
            f"Got {len(control_images)} control images but only {len(control_modes)} control modes."
        )

    tensored_control_images = []
    tensored_control_modes = []
    for control_image, control_mode in zip(control_images, control_modes):
        pixels = _prepare_image(
            image=control_image,
            width=width,
            height=height,
            device=device,
            dtype=vae.dtype,
        )
        # Later images are requested at the size the preprocessor actually produced.
        height, width = pixels.shape[-2:]

        # vae encode
        latents = vae.encode(pixels).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # pack
        latent_height, latent_width = latents.shape[2:]
        packed = _pack_latents(latents, latent_height, latent_width)

        tensored_control_images.append(packed)
        tensored_control_modes.append(
            torch.tensor(control_mode).expand(packed.shape[0]).to(device, dtype=torch.long)
        )

    return tensored_control_images, tensored_control_modes


def _prepare_image(
    image: "Image.Image",
    width: int,
    height: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Convert a PIL image to an RGB tensor of the requested size on ``device`` with ``dtype``."""
    image = image.convert("RGB")
    image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
    # The former `repeat_interleave(1, dim=0)` only copied the tensor and has been dropped.
    return image.to(device=device, dtype=dtype)


def _pack_latents(latents, height, width):
    """Fold every 2x2 spatial patch of ``latents`` into the feature dimension.

    Args:
        latents: Tensor of shape ``(batch, channels, height, width)``.
        height: Latent height; expected to be even.
        width: Latent width; expected to be even.

    Returns:
        Tensor of shape ``(batch, (height // 2) * (width // 2), channels * 4)``.
    """
    # Batch size and channel count come from the tensor instead of being fixed to 1 and 4.
    batch_size, num_channels = latents.shape[:2]
    latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)

    return latents
) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("sd3.png") + ``` +""" + +T5_PRECISION = torch.float16 + +""" +Based on FluxPipeline with several changes: +- no pooled embeddings +- We use zero padding for prompts +- No guidance embedding since this is not a distilled version +""" +class BriaPipeline(FluxPipeline): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
+ """ + + def __init__( + self, + transformer: BriaTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers], + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast + ): + self.register_modules( + vae=vae, + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + # TODO - why different than offical flux (-1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k + + # T5 is senstive to precision so we use the precision used for precompute and cast as needed + + if self.vae.config.shift_factor is None: + self.vae.config.shift_factor=0 + self.vae.to(dtype=torch.float32) + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = get_t5_prompt_embeds( + self.tokenizer, + self.text_encoder, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=self.transformer.dtype) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + if not is_ng_none(negative_prompt): + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_t5_prompt_embeds( + self.tokenizer, + self.text_encoder, + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=self.transformer.dtype) + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, negative_prompt_embeds, text_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + clip_value:Union[None,float] = None, + normalize:bool = False + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. 
+ + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] # Shift by height - Why just height? + print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}") + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + else: + # 4. Prepare timesteps + # Sample from training sigmas + if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None) + else: + sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Supprot different diffusers versions + if diffusers.__version__>='0.32.0': + latent_image_ids=latent_image_ids[0] + text_ids=text_ids[0] + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if type(self.scheduler)!=FlowMatchEulerDiscreteScheduler: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + cfg_noise_pred_text = noise_pred_text.std() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if normalize: + noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred + + if clip_value: + assert clip_value>0 + noise_pred = noise_pred.clip(-clip_value,clip_value) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in 
self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def to(self, *args, **kwargs): + DiffusionPipeline.to(self, *args, **kwargs) + # T5 is senstive to precision so we use the precision used for precompute and cast as needed + self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) + for block in self.text_encoder.encoder.block: + block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + + if self.vae.config.shift_factor == 0 and self.vae.dtype!=torch.float32: + self.vae.to(dtype=torch.float32) + + + return self + + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor ) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + 
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + + + + + + diff --git a/invokeai/backend/bria/pipeline_bria_controlnet.py b/invokeai/backend/bria/pipeline_bria_controlnet.py new file mode 100644 index 00000000000..b6106cd02ca --- /dev/null +++ b/invokeai/backend/bria/pipeline_bria_controlnet.py @@ -0,0 +1,672 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers +import torch +from transformers import ( + T5EncoderModel, + T5TokenizerFast, +) +from diffusers.image_processor import PipelineImageInput + +from diffusers import AutoencoderKL # Waiting for diffusers udpdate +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import logging, USE_PEFT_BACKEND +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from invokeai.backend.bria.controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel +from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift +from invokeai.backend.bria.pipeline_bria import BriaPipeline +from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel +from invokeai.backend.bria.bria_utils import get_original_sigmas +import numpy as np +import diffusers +from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none +from diffusers.utils.torch_utils import randn_tensor + +XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class BriaControlNetPipeline(BriaPipeline): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. 
Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( # EYAL - removed clip text encoder + tokenizer + self, + transformer: BriaTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + controlnet: BriaControlNetModel, + ): + super().__init__( + transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer + ) + self.register_modules(controlnet=controlnet) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode): + num_channels_latents = self.transformer.config.in_channels // 4 + control_image = 
self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # Here we ensure that `control_mode` has the same length as the control_image. + if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) + + return control_image, control_mode + + def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode): + num_channels_latents = self.transformer.config.in_channels // 4 + control_images = [] + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + 
batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + control_images.append(control_image_) + + control_image = control_images + + # Here we ensure that `control_mode` has the same length as the control_image. + if isinstance(control_mode, list) and len(control_mode) != len(control_image): + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) + # set control mode + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes + + return control_image, control_mode + + def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end): + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps) + return controlnet_keep + + def get_control_start_end(self, control_guidance_start, control_guidance_end): + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = 1 # TODO - why is this 1? 
+ control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + return control_guidance_start, control_guidance_end + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_image: Optional[PipelineImageInput] = None, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + latent_image_ids: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + text_ids: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + Examples: + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + control_guidance_start, control_guidance_end = self.get_control_start_end( + control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + device = self._execution_device + + + # 4. 
Prepare timesteps + if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + + # Determine image sequence length + if control_image is not None: + if type(control_image) == list: + image_seq_len = control_image[0].shape[1] + else: + image_seq_len = control_image.shape[1] + else: + # Use latents sequence length when no control image is provided + image_seq_len = latents.shape[1] + + print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}") + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + else: + # 5. Prepare timesteps + sigmas = get_original_sigmas( + num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Create tensor stating which controlnets to keep + if control_image is not None: + controlnet_keep = self.get_controlnet_keep( + timesteps=timesteps, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + if diffusers.__version__>='0.32.0': + latent_image_ids=latent_image_ids[0] + text_ids=text_ids[0] + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # EYAL - added the CFG loop + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # Handling ControlNet + if control_image is not None: + if isinstance(controlnet_keep[i], list): + if isinstance(controlnet_conditioning_scale, list): + cond_scale = controlnet_conditioning_scale + else: + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep, + # guidance=guidance, + # pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + else: + controlnet_block_samples, controlnet_single_block_samples = None, None + + # This is predicts "v" from flow-matching + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + 
img_ids=latent_image_ids, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + +def encode_prompt( + prompt: Union[str, List[str]], + tokenizer: T5TokenizerFast, + text_encoder: T5EncoderModel, + device: Optional[torch.device] = None, + 
num_images_per_prompt: int = 1, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + device = device or torch.device("cuda") + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + # dynamically adjust the LoRA scale + if text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + dtype = text_encoder.dtype if text_encoder is not None else torch.float32 + if prompt_embeds is None: + prompt_embeds = get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=dtype) + + if negative_prompt_embeds is None: + if not is_ng_none(negative_prompt): + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=dtype) + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + if text_encoder is not None: + if USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(text_encoder, lora_scale) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, negative_prompt_embeds, text_ids + + +def prepare_latents( + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, +): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + vae_scale_factor = 16 + height = 2 * (int(height) // vae_scale_factor) + width = 2 * (int(width) // vae_scale_factor ) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + +def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + + + +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents \ No newline at end of file diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index aeee6fd42f4..19252ae7f96 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -126,7 +126,7 @@ class ModelProbe(object): CLASS2TYPE = { "BriaPipeline": ModelType.Main, - "BriaControlNetModel": ModelType.ControlNet, + "BriaTransformer2DModel": ModelType.ControlNet, "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, @@ -1014,7 +1014,7 @@ def get_base_type(self) -> BaseModelType: if 
config.get("_class_name", None) == "FluxControlNetModel": return BaseModelType.Flux - if config.get("_class_name", None) == "BriaControlNetModel": + if config.get("_class_name", None) == "BriaTransformer2DModel": return BaseModelType.Bria # no obvious way to distinguish between sd2-base and sd2-768 diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py index 02a2c0835fc..c5d6ec6f433 100644 --- a/invokeai/backend/model_manager/load/model_loaders/bria.py +++ b/invokeai/backend/model_manager/load/model_loaders/bria.py @@ -31,14 +31,11 @@ def _load_model( if isinstance(config, ControlNetCheckpointConfig): raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.") - if submodel_type is None: - raise Exception("A submodel type must be provided when loading control net pipelines.") - model_path = Path(config.path) - load_class = self.get_hf_load_class(model_path, submodel_type) + load_class = self.get_hf_load_class(model_path) repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None variant = repo_variant.value if repo_variant else None - model_path = model_path / submodel_type.value + model_path = model_path dtype = self._torch_dtype @@ -47,6 +44,7 @@ def _load_model( model_path, torch_dtype=dtype, variant=variant, + use_safetensors=False, ) except OSError as e: if variant and "no file named" in str( diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 245d812a014..837f75a8bb0 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -84,6 +84,9 @@ def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # ]: if module == "transformer_bria": module = "invokeai.backend.bria.transformer_bria" + elif 
class_name == "BriaTransformer2DModel": + class_name = "BriaControlNetModel" + module = "invokeai.backend.bria.controlnet_bria" res_type = sys.modules[module] else: res_type = sys.modules["diffusers"].pipelines diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 47c6eb70bcb..5d396ffad2c 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -12,6 +12,9 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast +from invokeai.backend.bria.controlnet_aux.open_pose.body import Body +from invokeai.backend.bria.controlnet_aux.open_pose.face import Face +from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline @@ -62,6 +65,8 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: else: # If neither is available, return 0 return 0 + elif isinstance(model, (Body, Hand, Face)): + return calc_module_size(model.model) elif isinstance( model, ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx index 31f732abb24..bafba67ceaa 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx @@ 
-6,8 +6,8 @@ import type { BriaControlNetModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useBriaModels } from 'services/api/hooks/modelsByType'; -import type { MainModelConfig } from 'services/api/types'; +import { useBriaControlNetModels } from 'services/api/hooks/modelsByType'; +import type { ControlNetModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -16,9 +16,9 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const [modelConfigs, { isLoading }] = useBriaModels(); + const [modelConfigs, { isLoading }] = useBriaControlNetModels(); const onChange = useCallback( - (value: MainModelConfig | null) => { + (value: ControlNetModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/nodes/bria_nodes/__init__.py b/invokeai/nodes/bria_nodes/__init__.py new file mode 100644 index 00000000000..a6049b4362e --- /dev/null +++ b/invokeai/nodes/bria_nodes/__init__.py @@ -0,0 +1,6 @@ +from .bria_decoder import BriaDecoderInvocation +from .bria_model_loader import BriaModelLoaderInvocation +from .bria_denoiser import BriaDenoiseInvocation +from .bria_latent_sampler import BriaLatentSamplerInvocation +from .bria_text_encoder import BriaTextEncoderInvocation +from .bria_full import BriaFullInvocation diff --git a/invokeai/nodes/bria_nodes/bria_controlnet.py b/invokeai/nodes/bria_nodes/bria_controlnet.py new file mode 100644 index 00000000000..20d49988da6 --- /dev/null +++ b/invokeai/nodes/bria_nodes/bria_controlnet.py @@ -0,0 +1,145 @@ +from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES +from pydantic import BaseModel, Field +from invokeai.invocation_api import ImageOutput +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, 
InputField, OutputField, UIType +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.services.shared.invocation_context import InvocationContext +import numpy as np +import cv2 +from PIL import Image + +from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline +from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector, Body, Hand, Face + +DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf" +HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + +class BriaControlNetField(BaseModel): + image: ImageField = Field(description="The control image") + model: ModelIdentifierField = Field(description="The ControlNet model to use") + mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet") + conditioning_scale: float = Field(description="The weight given to the ControlNet") + +@invocation_output("flux_controlnet_output") +class BriaControlNetOutput(BaseInvocationOutput): + """FLUX ControlNet info""" + + control: BriaControlNetField = OutputField(description=FieldDescriptions.control) + preprocessed_images: ImageField = OutputField(description="The preprocessed control image") + + +@invocation( + "bria_controlnet", + title="Bria ControlNet", + tags=["controlnet", "bria"], + category="controlnet", + version="1.0.0", +) +class BriaControlNetInvocation(BaseInvocation): + """Collect Bria ControlNet info to pass to denoiser node.""" + + control_image: ImageField = InputField(description="The control image") + control_model: ModelIdentifierField = InputField( + description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel + ) + control_mode: BRIA_CONTROL_MODES = InputField( + default="depth", description="The mode of the ControlNet" + ) + control_weight: float = InputField( + default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" + ) + + def invoke(self, context: InvocationContext) -> 
BriaControlNetOutput: + image_in = resize_img(context.images.get_pil(self.control_image.image_name)) + if self.control_mode == "canny": + control_image = extract_canny(image_in) + elif self.control_mode == "depth": + control_image = extract_depth(image_in, context) + elif self.control_mode == "pose": + control_image = extract_openpose(image_in, context) + elif self.control_mode == "colorgrid": + control_image = tile(64, image_in) + elif self.control_mode == "recolor": + control_image = convert_to_grayscale(image_in) + elif self.control_mode == "tile": + control_image = tile(16, image_in) + + control_image = resize_img(control_image) + image_dto = context.images.save(image=control_image) + image_output = ImageOutput.build(image_dto) + return BriaControlNetOutput( + preprocessed_images=image_output.image, + control=BriaControlNetField( + image=ImageField(image_name=image_dto.image_name), + model=self.control_model, + mode=self.control_mode, + conditioning_scale=self.control_weight, + ), + ) + + +RATIO_CONFIGS_1024 = { + 0.6666666666666666: {"width": 832, "height": 1248}, + 0.7432432432432432: {"width": 880, "height": 1184}, + 0.8028169014084507: {"width": 912, "height": 1136}, + 1.0: {"width": 1024, "height": 1024}, + 1.2456140350877194: {"width": 1136, "height": 912}, + 1.3454545454545455: {"width": 1184, "height": 880}, + 1.4339622641509433: {"width": 1216, "height": 848}, + 1.5: {"width": 1248, "height": 832}, + 1.5490196078431373: {"width": 1264, "height": 816}, + 1.62: {"width": 1296, "height": 800}, + 1.7708333333333333: {"width": 1360, "height": 768}, +} + +def extract_depth(image: Image.Image, context: InvocationContext): + loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model) + + with loaded_model as depth_anything_detector: + assert isinstance(depth_anything_detector, DepthAnythingPipeline) + depth_map = depth_anything_detector.generate_depth(image) + return depth_map + +def extract_openpose(image: 
Image.Image, context: InvocationContext): + body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body) + hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand) + face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face) + + with body_model as body_model, hand_model as hand_model, face_model as face_model: + open_pose_model = OpenposeDetector(body_model, hand_model, face_model) + processed_image_open_pose = open_pose_model(image, hand_and_face=True) + + processed_image_open_pose = processed_image_open_pose.resize(image.size) + return processed_image_open_pose + + +def extract_canny(input_image): + image = np.array(input_image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = Image.fromarray(image) + return canny_image + + +def convert_to_grayscale(image): + gray_image = image.convert('L').convert('RGB') + return gray_image + +def tile(downscale_factor, input_image): + control_image = input_image.resize((input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)).resize(input_image.size, Image.Resampling.NEAREST) + return control_image + +def resize_img(control_image): + image_ratio = control_image.width / control_image.height + ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio)) + to_height = RATIO_CONFIGS_1024[ratio]["height"] + to_width = RATIO_CONFIGS_1024[ratio]["width"] + resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS) + return resized_image diff --git a/invokeai/nodes/bria_nodes/bria_denoiser.py b/invokeai/nodes/bria_nodes/bria_denoiser.py index 081c3392f47..834e290de4f 100644 --- a/invokeai/nodes/bria_nodes/bria_denoiser.py +++ b/invokeai/nodes/bria_nodes/bria_denoiser.py @@ -1,20 +1,18 @@ +from typing import List, Tuple +from diffusers.models.autoencoders.autoencoder_kl import 
AutoencoderKL +from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel +from invokeai.backend.bria.controlnet_utils import prepare_control_images +from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline +from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField + import torch from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler -from invokeai.app.invocations.fields import Input, InputField -from invokeai.app.invocations.model import SubModelType, TransformerField -from invokeai.app.invocations.primitives import ( - BaseInvocationOutput, - FieldDescriptions, - Input, - InputField, - LatentsField, - OutputField, -) +from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField +from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField +from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output - -from invokeai.backend.bria.pipeline import get_original_sigmas, retrieve_timesteps from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel @invocation_output("bria_denoise_output") @@ -43,6 +41,16 @@ class BriaDenoiseInvocation(BaseInvocation): input=Input.Connection, title="Transformer", ) + t5_encoder: T5EncoderField = InputField( + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + title="VAE", + ) latents: LatentsField = InputField( description="Latents to denoise", input=Input.Connection, @@ -68,6 +76,12 @@ class BriaDenoiseInvocation(BaseInvocation): input=Input.Connection, title="Text IDs", ) + control: 
BriaControlNetField | list[BriaControlNetField] | None = InputField( + description="ControlNet", + input=Input.Connection, + title="ControlNet", + default = None, + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: @@ -83,51 +97,89 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: with ( context.models.load(self.transformer.transformer) as transformer, context.models.load(scheduler_identifier) as scheduler, + context.models.load(self.vae.vae) as vae, + context.models.load(self.t5_encoder.text_encoder) as t5_encoder, + context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer, ): assert isinstance(transformer, BriaTransformer2DModel) assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler) + assert isinstance(vae, AutoencoderKL) dtype = transformer.dtype device = transformer.device latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds)) - prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds - - sigmas = get_original_sigmas(1000, self.num_steps) - timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0) - - for t in timesteps: - # Prepare model input efficiently - if self.guidance_scale > 1: - latent_model_input = torch.cat([latents] * 2) - else: - latent_model_input = latents - - # Prepare timestep tensor efficiently - if isinstance(t, torch.Tensor): - timestep_tensor = t.expand(latent_model_input.shape[0]) - else: - timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32) - - noise_pred = transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep_tensor, - img_ids=latent_image_ids, - txt_ids=text_ids, - guidance=None, - return_dict=False, - )[0] - - if self.guidance_scale > 1: - noise_uncond, noise_text = noise_pred.chunk(2) - noise_pred = noise_uncond + self.guidance_scale * 
(noise_text - noise_uncond) - - # Convert timestep for scheduler - t_step = float(t.item()) if isinstance(t, torch.Tensor) else float(t) - - # Use scheduler step with proper dtypes - latents = scheduler.step(noise_pred, t_step, latents, return_dict=False)[0] + + control_model, control_images, control_modes, control_scales = None, None, None, None + if self.control is not None: + control_model, control_images, control_modes, control_scales = self._prepare_multi_control( + context=context, + vae=vae, + width=1024, + height=1024, + device=vae.device, + ) + + pipeline = BriaControlNetPipeline( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=t5_encoder, + tokenizer=t5_tokenizer, + controlnet=control_model, + ) + pipeline.to(device=transformer.device, dtype=transformer.dtype) + + latents = pipeline( + control_image=control_images, + control_mode=control_modes, + width=1024, + height=1024, + controlnet_conditioning_scale=control_scales, + num_inference_steps=self.num_steps, + max_sequence_length=128, + guidance_scale=self.guidance_scale, + latents=latents, + latent_image_ids=latent_image_ids, + text_ids=text_ids, + prompt_embeds=pos_embeds, + negative_prompt_embeds=neg_embeds, + output_type="latent", + )[0] assert isinstance(latents, torch.Tensor) saved_input_latents_tensor = context.tensors.save(latents) latents_output = LatentsField(latents_name=saved_input_latents_tensor) return BriaDenoiseInvocationOutput(latents=latents_output) + + + def _prepare_multi_control( + self, + context: InvocationContext, + vae: AutoencoderKL, + width: int, + height: int, + device: torch.device + ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]: + + control = self.control if isinstance(self.control, list) else [self.control] + control_images, control_models, control_modes, control_scales = [], [], [], [] + for controlnet in control: + if controlnet is not None: + 
control_models.append(context.models.load(controlnet.model).model) + control_modes.append(BriaControlModes[controlnet.mode].value) + control_scales.append(controlnet.conditioning_scale) + try: + control_images.append(context.images.get_pil(controlnet.image.image_name)) + except: + raise FileNotFoundError(f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline.") + + control_model = BriaMultiControlNetModel(control_models).to(device) + tensored_control_images, tensored_control_modes = prepare_control_images( + vae=vae, + control_images=control_images, + control_modes=control_modes, + width=width, + height=height, + device=device, + ) + return control_model, tensored_control_images, tensored_control_modes, control_scales + \ No newline at end of file diff --git a/invokeai/nodes/bria_nodes/bria_latent_sampler.py b/invokeai/nodes/bria_nodes/bria_latent_sampler.py index 36170ff5d90..f3ce74729b2 100644 --- a/invokeai/nodes/bria_nodes/bria_latent_sampler.py +++ b/invokeai/nodes/bria_nodes/bria_latent_sampler.py @@ -9,7 +9,7 @@ LatentsField, OutputField, ) -from invokeai.backend.model_manager.config import MainDiffusersConfig +from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents from invokeai.invocation_api import ( BaseInvocation, Classification, @@ -50,23 +50,19 @@ class BriaLatentSamplerInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput: device = torch.device("cuda") - transformer_config = context.models.get_config(self.transformer.transformer) - if not isinstance(transformer_config, MainDiffusersConfig): - raise ValueError("Transformer config is not a MainDiffusersConfig") - # TODO: get latent channels from transformer config - latent_channels = 16 - latent_height, latent_width = 128, 128 - shrunk = latent_channels // 4 - gen = torch.Generator(device=device).manual_seed(self.seed) - - noise4d = torch.randn((1, 
shrunk, latent_height, latent_width), device=device, generator=gen) - latents = noise4d.view(1, shrunk, latent_height // 2, 2, latent_width // 2, 2).permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(1, (latent_height // 2) * (latent_width // 2), shrunk * 4) - - latent_image_ids = torch.zeros((latent_height // 2, latent_width // 2, 3), device=device, dtype=torch.long) - latent_image_ids[..., 1] = torch.arange(latent_height // 2, device=device)[:, None] - latent_image_ids[..., 2] = torch.arange(latent_width // 2, device=device)[None, :] - latent_image_ids = latent_image_ids.view(-1, 3) + height, width = 1024, 1024 + generator = torch.Generator(device=device).manual_seed(self.seed) + + num_channels_latents = 4 # due to patch=2, we devide by 4 + latents, latent_image_ids = prepare_latents( + batch_size=1, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=torch.float32, + device=device, + generator=generator, + ) saved_latents_tensor = context.tensors.save(latents) saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids) diff --git a/invokeai/nodes/bria_nodes/bria_text_encoder.py b/invokeai/nodes/bria_nodes/bria_text_encoder.py index 143a873bb10..03a9a1d50b4 100644 --- a/invokeai/nodes/bria_nodes/bria_text_encoder.py +++ b/invokeai/nodes/bria_nodes/bria_text_encoder.py @@ -9,6 +9,7 @@ from invokeai.app.invocations.model import T5EncoderField from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt from invokeai.invocation_api import ( BaseInvocation, Classification, @@ -46,6 +47,7 @@ class BriaTextEncoderInvocation(BaseInvocation): negative_prompt: Optional[str] = InputField( title="Negative Prompt", description="The negative prompt to encode", + default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn 
hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate", ) max_length: int = InputField( default=128, @@ -68,17 +70,20 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput: ): assert isinstance(tokenizer, T5TokenizerFast) assert isinstance(text_encoder, T5EncoderModel) - pos = get_t5_prompt_embeds(tokenizer, text_encoder, self.prompt, 1, self.max_length, text_encoder.device) - neg = ( - torch.zeros_like(pos) - if is_ng_none(self.negative_prompt) - else get_t5_prompt_embeds( - tokenizer, text_encoder, self.negative_prompt, 1, self.max_length, text_encoder.device - ) - ) - text_ids = torch.zeros((pos.shape[1], 3), device=text_encoder.device, dtype=torch.long) - saved_pos_tensor = context.tensors.save(pos) - saved_neg_tensor = context.tensors.save(neg) + + (prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt( + prompt=self.prompt, + tokenizer=tokenizer, + text_encoder=text_encoder, + negative_prompt=self.negative_prompt, + device=text_encoder.device, + num_images_per_prompt=1, + max_sequence_length=self.max_length, + lora_scale=1.0, + ) + + saved_pos_tensor = context.tensors.save(prompt_embeds) + saved_neg_tensor = context.tensors.save(negative_prompt_embeds) saved_text_ids_tensor = context.tensors.save(text_ids) pos_embeds_output = LatentsField(latents_name=saved_pos_tensor) neg_embeds_output = LatentsField(latents_name=saved_neg_tensor) From 9131c456455ea758bc4bbab34672d37af46a1458 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 14 Jul 2025 13:23:04 +0000 Subject: [PATCH 07/14] Added scikit-image required for Bria's OpenposeDetector model --- pyproject.toml | 3 +- uv.lock | 100 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e2d827763d..33f66aca5bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,8 @@ dependencies = [ "pypatchmatch", "python-multipart", "requests", 
- "semver~=3.0.1" + "semver~=3.0.1", + "scikit-image" ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 0d82784a0f9..b1634c4d67b 100644 --- a/uv.lock +++ b/uv.lock @@ -948,6 +948,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "imageio" +version = "2.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0c/47/57e897fb7094afb2d26e8b2e4af9a45c7cf1a405acdeeca001fdf2c98501/imageio-2.37.0.tar.gz", hash = "sha256:71b57b3669666272c818497aebba2b4c5f20d5b37c81720e5e1a56d59c492996", size = 389963, upload-time = "2025-01-20T02:42:37.089Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/bd/b394387b598ed84d8d0fa90611a90bee0adc2021820ad5729f7ced74a8e2/imageio-2.37.0-py3-none-any.whl", hash = "sha256:11efa15b87bc7871b61590326b2d635439acc321cf7f8ce996f812543ce10eed", size = 315796, upload-time = "2025-01-20T02:42:34.931Z" }, +] + [[package]] name = "importlib-metadata" version = "8.7.0" @@ -1017,6 +1030,7 @@ dependencies = [ { name = "python-socketio" }, { name = "requests" }, { name = "safetensors" }, + { name = "scikit-image" }, { name = "semver" }, { name = "sentencepiece" }, { name = "spandrel" }, @@ -1127,6 +1141,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'test'", specifier = "~=0.11.2" }, { name = "ruff-lsp", marker = "extra == 'test'", specifier = "~=0.0.62" }, { name = "safetensors" }, + { name = "scikit-image" }, { name = "semver", specifier = "~=3.0.1" }, { name = "sentencepiece" }, { name = "snakeviz", marker = "extra == 'dev'" }, @@ -1356,6 +1371,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762, upload-time = "2024-12-24T18:30:48.903Z" }, ] +[[package]] +name = "lazy-loader" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/6b/c875b30a1ba490860c93da4cabf479e03f584eba06fe5963f6f6644653d8/lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1", size = 15431, upload-time = "2024-04-05T13:03:12.261Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/60/d497a310bde3f01cb805196ac61b7ad6dc5dcf8dce66634dc34364b20b4f/lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc", size = 12097, upload-time = "2024-04-05T13:03:10.514Z" }, +] + [[package]] name = "lsprotocol" version = "2023.0.1" @@ -2947,6 +2974,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11", size = 308878, upload-time = "2025-02-26T09:15:14.99Z" }, ] +[[package]] +name = "scikit-image" +version = "0.25.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "imageio" }, + { name = "lazy-loader" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "scipy", version = "1.15.3", 
source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tifffile", version = "2025.5.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "tifffile", version = "2025.6.11", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/a8/3c0f256012b93dd2cb6fda9245e9f4bff7dc0486880b248005f15ea2255e/scikit_image-0.25.2.tar.gz", hash = "sha256:e5a37e6cd4d0c018a7a55b9d601357e3382826d3888c10d0213fc63bff977dde", size = 22693594, upload-time = "2025-02-18T18:05:24.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/cb/016c63f16065c2d333c8ed0337e18a5cdf9bc32d402e4f26b0db362eb0e2/scikit_image-0.25.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d3278f586793176599df6a4cf48cb6beadae35c31e58dc01a98023af3dc31c78", size = 13988922, upload-time = "2025-02-18T18:04:11.069Z" }, + { url = "https://files.pythonhosted.org/packages/30/ca/ff4731289cbed63c94a0c9a5b672976603118de78ed21910d9060c82e859/scikit_image-0.25.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5c311069899ce757d7dbf1d03e32acb38bb06153236ae77fcd820fd62044c063", size = 13192698, upload-time = "2025-02-18T18:04:15.362Z" }, + { url = "https://files.pythonhosted.org/packages/39/6d/a2aadb1be6d8e149199bb9b540ccde9e9622826e1ab42fe01de4c35ab918/scikit_image-0.25.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be455aa7039a6afa54e84f9e38293733a2622b8c2fb3362b822d459cc5605e99", size = 14153634, upload-time = "2025-02-18T18:04:18.496Z" }, + { url = "https://files.pythonhosted.org/packages/96/08/916e7d9ee4721031b2f625db54b11d8379bd51707afaa3e5a29aecf10bc4/scikit_image-0.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:a4c464b90e978d137330be433df4e76d92ad3c5f46a22f159520ce0fdbea8a09", size = 14767545, upload-time = "2025-02-18T18:04:22.556Z" }, + { url = "https://files.pythonhosted.org/packages/5f/ee/c53a009e3997dda9d285402f19226fbd17b5b3cb215da391c4ed084a1424/scikit_image-0.25.2-cp310-cp310-win_amd64.whl", hash = "sha256:60516257c5a2d2f74387c502aa2f15a0ef3498fbeaa749f730ab18f0a40fd054", size = 12812908, upload-time = "2025-02-18T18:04:26.364Z" }, + { url = "https://files.pythonhosted.org/packages/c4/97/3051c68b782ee3f1fb7f8f5bb7d535cf8cb92e8aae18fa9c1cdf7e15150d/scikit_image-0.25.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f4bac9196fb80d37567316581c6060763b0f4893d3aca34a9ede3825bc035b17", size = 14003057, upload-time = "2025-02-18T18:04:30.395Z" }, + { url = "https://files.pythonhosted.org/packages/19/23/257fc696c562639826065514d551b7b9b969520bd902c3a8e2fcff5b9e17/scikit_image-0.25.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d989d64ff92e0c6c0f2018c7495a5b20e2451839299a018e0e5108b2680f71e0", size = 13180335, upload-time = "2025-02-18T18:04:33.449Z" }, + { url = "https://files.pythonhosted.org/packages/ef/14/0c4a02cb27ca8b1e836886b9ec7c9149de03053650e9e2ed0625f248dd92/scikit_image-0.25.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2cfc96b27afe9a05bc92f8c6235321d3a66499995675b27415e0d0c76625173", size = 14144783, upload-time = "2025-02-18T18:04:36.594Z" }, + { url = "https://files.pythonhosted.org/packages/dd/9b/9fb556463a34d9842491d72a421942c8baff4281025859c84fcdb5e7e602/scikit_image-0.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24cc986e1f4187a12aa319f777b36008764e856e5013666a4a83f8df083c2641", size = 14785376, upload-time = "2025-02-18T18:04:39.856Z" }, + { url = "https://files.pythonhosted.org/packages/de/ec/b57c500ee85885df5f2188f8bb70398481393a69de44a00d6f1d055f103c/scikit_image-0.25.2-cp311-cp311-win_amd64.whl", hash = 
"sha256:b4f6b61fc2db6340696afe3db6b26e0356911529f5f6aee8c322aa5157490c9b", size = 12791698, upload-time = "2025-02-18T18:04:42.868Z" }, + { url = "https://files.pythonhosted.org/packages/35/8c/5df82881284459f6eec796a5ac2a0a304bb3384eec2e73f35cfdfcfbf20c/scikit_image-0.25.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8db8dd03663112783221bf01ccfc9512d1cc50ac9b5b0fe8f4023967564719fb", size = 13986000, upload-time = "2025-02-18T18:04:47.156Z" }, + { url = "https://files.pythonhosted.org/packages/ce/e6/93bebe1abcdce9513ffec01d8af02528b4c41fb3c1e46336d70b9ed4ef0d/scikit_image-0.25.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:483bd8cc10c3d8a7a37fae36dfa5b21e239bd4ee121d91cad1f81bba10cfb0ed", size = 13235893, upload-time = "2025-02-18T18:04:51.049Z" }, + { url = "https://files.pythonhosted.org/packages/53/4b/eda616e33f67129e5979a9eb33c710013caa3aa8a921991e6cc0b22cea33/scikit_image-0.25.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d1e80107bcf2bf1291acfc0bf0425dceb8890abe9f38d8e94e23497cbf7ee0d", size = 14178389, upload-time = "2025-02-18T18:04:54.245Z" }, + { url = "https://files.pythonhosted.org/packages/6b/b5/b75527c0f9532dd8a93e8e7cd8e62e547b9f207d4c11e24f0006e8646b36/scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a17e17eb8562660cc0d31bb55643a4da996a81944b82c54805c91b3fe66f4824", size = 15003435, upload-time = "2025-02-18T18:04:57.586Z" }, + { url = "https://files.pythonhosted.org/packages/34/e3/49beb08ebccda3c21e871b607c1cb2f258c3fa0d2f609fed0a5ba741b92d/scikit_image-0.25.2-cp312-cp312-win_amd64.whl", hash = "sha256:bdd2b8c1de0849964dbc54037f36b4e9420157e67e45a8709a80d727f52c7da2", size = 12899474, upload-time = "2025-02-18T18:05:01.166Z" }, +] + [[package]] name = "scipy" version = "1.15.3" @@ -3197,6 +3260,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", 
hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tifffile" +version = "2025.5.10" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", +] +dependencies = [ + { name = "numpy", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/44/d0/18fed0fc0916578a4463f775b0fbd9c5fed2392152d039df2fb533bfdd5d/tifffile-2025.5.10.tar.gz", hash = "sha256:018335d34283aa3fd8c263bae5c3c2b661ebc45548fde31504016fcae7bf1103", size = 365290, upload-time = "2025-05-10T19:22:34.386Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/06/bd0a6097da704a7a7c34a94cfd771c3ea3c2f405dd214e790d22c93f6be1/tifffile-2025.5.10-py3-none-any.whl", hash = "sha256:e37147123c0542d67bc37ba5cdd67e12ea6fbe6e86c52bee037a9eb6a064e5ad", size = 226533, upload-time = "2025-05-10T19:22:27.279Z" }, +] + +[[package]] +name = "tifffile" +version = "2025.6.11" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' 
and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", +] +dependencies = [ + { name = "numpy", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/9e/636e3e433c24da41dd639e0520db60750dbf5e938d023b83af8097382ea3/tifffile-2025.6.11.tar.gz", hash = "sha256:0ece4c2e7a10656957d568a093b07513c0728d30c1bd8cc12725901fffdb7143", size = 370125, upload-time = "2025-06-12T04:49:38.839Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/d8/1ba8f32bfc9cb69e37edeca93738e883f478fbe84ae401f72c0d8d507841/tifffile-2025.6.11-py3-none-any.whl", hash = "sha256:32effb78b10b3a283eb92d4ebf844ae7e93e151458b0412f38518b4e6d2d7542", size = 230800, upload-time = "2025-06-12T04:49:37.458Z" }, +] + [[package]] name = "tokenizers" version = "0.21.2" From efc5a762fce148219e0afd0b2cecb7b9cfb62516 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 14 Jul 2025 13:40:24 +0000 Subject: [PATCH 08/14] removed unused file --- invokeai/backend/bria/pipeline.py | 459 ------------------------------ 1 file changed, 459 deletions(-) delete mode 100644 invokeai/backend/bria/pipeline.py diff --git a/invokeai/backend/bria/pipeline.py b/invokeai/backend/bria/pipeline.py deleted file mode 100644 index d62e695db73..00000000000 --- a/invokeai/backend/bria/pipeline.py +++ /dev/null @@ -1,459 +0,0 @@ -#!/usr/bin/env python -""" -Bria TextΓÇætoΓÇæImage Pipeline (GPUΓÇæready) -Using your local Bria checkpoints. 
-""" - -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn - -# Your bria_utils imports -from .bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler -from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps -from PIL import Image -from tqdm import tqdm # add this at the top of your file - -# Your custom transformer import -from .transformer_bria import BriaTransformer2DModel -from transformers import T5EncoderModel, T5TokenizerFast - - -# ----------------------------------------------------------------------------- -# 1. Model Loader -# ----------------------------------------------------------------------------- -class BriaModelLoader: - def __init__( - self, - transformer_ckpt: str, - vae_ckpt: str, - text_encoder_ckpt: str, - tokenizer_ckpt: str, - device: torch.device, - ): - self.device = device - - # print("Loading Bria Transformer from", transformer_ckpt) - # self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.bfloat16).to(device) - - # print("Loading VAE from", vae_ckpt) - # self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float32).to(device) - - # print("Loading T5 Encoder from", text_encoder_ckpt) - # self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device) - - # print("Loading Tokenizer from", tokenizer_ckpt) - # self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt, legacy=False) - self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.float16).to( - device - ) - self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float16).to(device) - self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device) - self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt) - - def get(self): 
- return { - "transformer": self.transformer, - "vae": self.vae, - "text_encoder": self.text_encoder, - "tokenizer": self.tokenizer, - } - - -# ----------------------------------------------------------------------------- -# 2. Text Encoder (uses bria_utils) -# ----------------------------------------------------------------------------- -class BriaTextEncoder: - def __init__( - self, - text_encoder: T5EncoderModel, - tokenizer: T5TokenizerFast, - device: torch.device, - max_length: int = 128, - ): - self.model = text_encoder.to(device) - self.tokenizer = tokenizer - self.device = device - self.max_length = max_length - - def encode( - self, - prompt: str, - negative_prompt: Optional[str] = None, - num_images_per_prompt: int = 1, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # 1) get positive embeddings - pos = get_t5_prompt_embeds( - tokenizer=self.tokenizer, - text_encoder=self.model, - prompt=prompt, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=self.max_length, - device=self.device, - ) - # 2) get negative or zeros - if negative_prompt is None or is_ng_none(negative_prompt): - neg = torch.zeros_like(pos) - else: - neg = get_t5_prompt_embeds( - tokenizer=self.tokenizer, - text_encoder=self.model, - prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=self.max_length, - device=self.device, - ) - - # 3) build text_ids: shape [S_text, 3] - # S_text = number of tokens = pos.shape[1] - S_text = pos.shape[1] - text_ids = torch.zeros((1, S_text, 3), device=self.device, dtype=torch.long) - text_ids = torch.zeros((S_text, 3), device=self.device, dtype=torch.long) - - print(f"Text embeds shapes ΓåÆ pos: {pos.shape}, neg: {neg.shape}, text_ids: {text_ids.shape}") - return pos, neg, text_ids - - -# ----------------------------------------------------------------------------- -# 3. 
Latent Sampler -# ----------------------------------------------------------------------------- -class BriaLatentSampler: - def __init__(self, transformer: BriaTransformer2DModel, vae: AutoencoderKL, device: torch.device): - self.device = device - self.latent_channels = transformer.config.in_channels - # self.latent_height = vae.config.sample_size - # self.latent_width = vae.config.sample_size - self.latent_height = 128 - self.latent_width = 128 - - @staticmethod - def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype): - # Build the same img_ids FluxPipeline.prepare_latents would use - latent_image_ids = torch.zeros((height, width, 3), device=device, dtype=dtype) - latent_image_ids[..., 1] = torch.arange(height, device=device)[:, None] - latent_image_ids[..., 2] = torch.arange(width, device=device)[None, :] - # reshape to [1, height*width, 3] then repeat for batch - latent_image_ids = latent_image_ids.view(1, height * width, 3) - return latent_image_ids.repeat(batch_size, 1, 1) - - def sample(self, batch_size: int = 1, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]: - gen = torch.Generator(device=self.device).manual_seed(seed) - - # 1) sample & pack the noise exactly as before - shrunk = self.latent_channels // 4 - noise4d = torch.randn( - (batch_size, shrunk, self.latent_height, self.latent_width), - device=self.device, - generator=gen, - ) - latents = ( - noise4d.view(batch_size, shrunk, self.latent_height // 2, 2, self.latent_width // 2, 2) - .permute(0, 2, 4, 1, 3, 5) - .reshape(batch_size, (self.latent_height // 2) * (self.latent_width // 2), shrunk * 4) - ) - - # 2) build the matching latent_image_ids - latent_image_ids = self._prepare_latent_image_ids( - batch_size, - self.latent_height // 2, - self.latent_width // 2, - device=self.device, - dtype=torch.long, - ) - if latent_image_ids.ndim == 3 and latent_image_ids.shape[0] == 1: - latent_image_ids = latent_image_ids[0] # [S_img , 3] - - 
latent_image_ids = latent_image_ids.squeeze(0) - - print(f"Sampled & packed latents: {latents.shape}") - return latents, latent_image_ids - - -# ----------------------------------------------------------------------------- -# 4. Denoising Loop (uses bria_utils for ╧â schedule) -# ----------------------------------------------------------------------------- -class BriaDenoise: - def __init__( - self, - transformer: nn.Module, - scheduler_name: str, - device: torch.device, - num_train_timesteps: int, - num_inference_steps: int, - **sched_kwargs, - ): - self.transformer = transformer.to(device) - self.device = device - - # Build scheduler - if scheduler_name == "flow_match": - from diffusers import FlowMatchEulerDiscreteScheduler - - self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(transformer.config, **sched_kwargs) - else: - from diffusers import DDIMScheduler - - self.scheduler = DDIMScheduler(**sched_kwargs) - - # Use your exact ╧â schedule from bria_utils - from bria_utils import get_original_sigmas - - sigmas = get_original_sigmas( - num_train_timesteps=num_train_timesteps, - num_inference_steps=num_inference_steps, - ) - self.scheduler.set_timesteps( - num_inference_steps=None, - timesteps=None, - sigmas=sigmas, - device=device, - ) - - # allow early exit - self.interrupt = False - # will be set in denoise() - self._guidance_scale = 1.0 - self._joint_attention_kwargs = {} - self.transformer = transformer.to(device) - self.device = device - - @property - def guidance_scale(self) -> float: - return self._guidance_scale - - @property - def do_classifier_free_guidance(self) -> bool: - return self.guidance_scale > 1.0 - - @property - def joint_attention_kwargs(self) -> dict: - return self._joint_attention_kwargs - - @torch.no_grad() - def denoise( - self, - latents: torch.Tensor, # [B, seq_len, C_hidden] - latent_image_ids: torch.Tensor, # [B, seq_len, 3] - prompt_embeds: torch.Tensor, # [B, S_text, D] - negative_prompt_embeds: torch.Tensor, # [B, 
S_text, D] - text_ids: torch.Tensor, # [B, S_text, 3] - num_inference_steps: int = 30, - guidance_scale: float = 5.0, - normalize: bool = False, - clip_value: float | None = None, - seed: int = 0, - ) -> torch.Tensor: - # 0) Quick cast & setup - device = self.device - # ensure dtype matches transformer - target_dtype = next(self.transformer.parameters()).dtype - latents = latents.to(device, dtype=target_dtype) - prompt_embeds = prompt_embeds.to(device, dtype=target_dtype) - negative_prompt_embeds = negative_prompt_embeds.to(device, dtype=target_dtype) - - # replicate reference encode_prompt behaviour - if negative_prompt_embeds is None: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - if guidance_scale > 1.0: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - self._guidance_scale = guidance_scale - - # 1) Prepare FlowΓÇæMatch timesteps identical to reference pipeline - if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and getattr( - self.scheduler.config, "use_dynamic_shifting", False - ): - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] - mu = calculate_shift(image_seq_len, 256, 16_384, 0.25, 0.75) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, None, sigmas, mu=mu - ) - else: - sigmas = get_original_sigmas( - num_train_timesteps=self.scheduler.config.num_train_timesteps, - num_inference_steps=num_inference_steps, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, None, sigmas - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - # 2) Loop with progress bar - - with tqdm(total=num_inference_steps, desc="Denoising", unit="step") as progress_bar: - for i, t in enumerate(timesteps): - # a) expand for CFG? 
- latent_model_input = torch.cat([latents] * 2, dim=0) if self.do_classifier_free_guidance else latents - - # b) scale model input if needed - if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # c) broadcast timestep - timestep = t.expand(latent_model_input.shape[0]) - - # d) predict noise - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - txt_ids=text_ids, - img_ids=latent_image_ids, - )[0] - - # e) classifierΓÇæfree guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - cfg_noise_pred_text = noise_pred_text.std() - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # f) optional normalize/clip - if normalize: - noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred - - if clip_value: - noise_pred = noise_pred.clamp(-clip_value, clip_value) - - # g) scheduler step, inΓÇæplace - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if (i + 1) % 5 == 0 or i == len(timesteps) - 1: - progress_bar.update(5 if i + 1 < len(timesteps) else (len(timesteps) % 5)) - - # # j) XLA sync - # if XLA_AVAILABLE: - # xm.mark_step() - - # 3) Return the final packed latents (still [B, seq_len, C_hidden]) - return latents - - -# ----------------------------------------------------------------------------- -# 5. 
Latents ΓåÆ Image -# ----------------------------------------------------------------------------- -class BriaLatentsToImage: - def __init__(self, vae: AutoencoderKL, device: torch.device): - self.vae = vae.to(device) - self.device = device - - @torch.no_grad() - def decode(self, latents: torch.Tensor) -> list[Image.Image]: - """ - Accepts either of the two packed shapes that come out of the denoiser - - ΓÇó [B , S , 16] ΓÇô 3ΓÇæD, where S = H┬▓ (e.g. 16┬á384 for 1024├ù1024) - ΓÇó [B , 1 , S , 16] ΓÇô 4ΓÇæD misΓÇæordered (what caused your crash) - - Converts them to the VAEΓÇÖs expected shape [B , 4 , H , W] before decoding. - """ - # ---- 1. UnΓÇæpack to (B , 4 , H , W) ---------------------------------- - if latents.ndim == 3: # (B,S,16) - B, S, C = latents.shape - H2 = int(S**0.5) # 128 for 1024├ù1024 - latents = ( - latents.view(B, H2, H2, 4, 2, 2) # split channels into 4├ù(2├ù2) - .permute(0, 3, 1, 4, 2, 5) # (B,4,H2,2,W2,2) - .reshape(B, 4, H2 * 2, H2 * 2) # (B,4,H,W) - ) - - elif latents.ndim == 4 and latents.shape[1] == 1: # (B,1,S,16) - B, _, S, C = latents.shape - H2 = int(S**0.5) - latents = ( - latents.squeeze(1) # -> (B,S,16) - .view(B, H2, H2, 4, 2, 2) - .permute(0, 3, 1, 4, 2, 5) - .reshape(B, 4, H2 * 2, H2 * 2) - ) - # else: already (B,4,H,W) - - # ---- 2. Standard VAE decode ----------------------------------------- - shift = 0 if self.vae.config.shift_factor is None else self.vae.config.shift_factor - latents = (latents / self.vae.config.scaling_factor) + shift - - # 1. 
temporarily move VAE to fp32 for the forward pass - self.vae.to(dtype=torch.float32) - images = self.vae.decode(latents.to(torch.float32)).sample # fullΓÇæprecision decode - self.vae.to(dtype=torch.bfloat16) # cast to fp32 **after** decode - images = (images.clamp(-1, 1) + 1) / 2 # [0,1] fp32 - images = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8") - - return [Image.fromarray(img) for img in images] - - -# ----------------------------------------------------------------------------- -# Main: Assemble & Run -# ----------------------------------------------------------------------------- -def main(): - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.benchmark = True - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print("Using device:", device) - - # ΓöÇΓöÇΓöÇ Use your actual checkpoint locations ΓöÇΓöÇΓöÇ - transformer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/transformer" - vae_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/vae" - text_encoder_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/text_encoder" - tokenizer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/tokenizer" - - # 1. Load models - loader = BriaModelLoader( - transformer_ckpt, - vae_ckpt, - text_encoder_ckpt, - tokenizer_ckpt, - device, - ) - mdl = loader.get() - # if diffusers.__version__ >= "0.27.0": - # mdl["transformer"].enable_xformers_memory_efficient_attention() # now safe - # else: - # mdl["transformer"].disable_xformers_memory_efficient_attention() # keep quality - - # 2. 
Encode prompt ΓÇö now capture text_ids as well - text_enc = BriaTextEncoder(mdl["text_encoder"], mdl["tokenizer"], device) - pos_embeds, neg_embeds, text_ids = text_enc.encode( - prompt="3d rendered image, landscape made out of ice cream, rich ice cream textures, ice cream-valley , with a milky ice cream river, the ice cream has rich texture with visible chocolate chunks and intricate details, in the background an air balloon floats over the vally, in the sky visible dramatic like clouds, brown-chocolate color white and pink pallet, drama, beautiful surreal landscape, polarizing lens, very high contrast, 3d rendered realistic", - negative_prompt=None, - num_images_per_prompt=1, - ) - - # 3. Sample initial noise ΓåÆ get both latents & latent_image_ids - sampler = BriaLatentSampler(mdl["transformer"], mdl["vae"], device) - init_latents, latent_image_ids = sampler.sample(batch_size=1, seed=1249141701) - - # 4. Denoise ΓÇö now passing latent_image_ids and text_ids - denoiser = BriaDenoise( - transformer=mdl["transformer"], - scheduler_name="flow_match", - device=device, - num_train_timesteps=1000, - num_inference_steps=30, - base_shift=0.5, - max_shift=1.15, - ) - final_latents = denoiser.denoise( - init_latents, - latent_image_ids, - pos_embeds, - neg_embeds, - text_ids, - num_inference_steps=30, - guidance_scale=5.0, - seed=1249141701, - ) - - # 5. 
Decode - decoder = BriaLatentsToImage(mdl["vae"], device) - images = decoder.decode(final_latents) - for i, img in enumerate(images): - img.save(f"bria_output_{i}.png") - - -if __name__ == "__main__": - main() From cad97d3da3e5f3c944d7b8dedc61db8bbf853d6c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 14 Jul 2025 14:56:39 +0000 Subject: [PATCH 09/14] Small cosmetic fixes --- invokeai/nodes/bria_nodes/bria_controlnet.py | 2 +- invokeai/nodes/bria_nodes/bria_decoder.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/nodes/bria_nodes/bria_controlnet.py b/invokeai/nodes/bria_nodes/bria_controlnet.py index 20d49988da6..514b99e9879 100644 --- a/invokeai/nodes/bria_nodes/bria_controlnet.py +++ b/invokeai/nodes/bria_nodes/bria_controlnet.py @@ -36,7 +36,7 @@ class BriaControlNetOutput(BaseInvocationOutput): @invocation( "bria_controlnet", - title="Bria ControlNet", + title="ControlNet - Bria", tags=["controlnet", "bria"], category="controlnet", version="1.0.0", diff --git a/invokeai/nodes/bria_nodes/bria_decoder.py b/invokeai/nodes/bria_nodes/bria_decoder.py index 38dbac0a0bb..b333a8c8c7b 100644 --- a/invokeai/nodes/bria_nodes/bria_decoder.py +++ b/invokeai/nodes/bria_nodes/bria_decoder.py @@ -10,21 +10,21 @@ @invocation( "bria_decoder", - title="Bria Decoder", + title="Decoder - Bria", tags=["image", "bria"], category="image", version="1.0.0", classification=Classification.Prototype, ) class BriaDecoderInvocation(BaseInvocation): - latents: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) vae: VAEField = InputField( description=FieldDescriptions.vae, input=Input.Connection, ) + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: From 8523ea88f2fe594cf3e9dfd26245b10eccb6ee6d Mon Sep 17 00:00:00 2001 From: Ilan Tchenak Date: Tue, 15 Jul 2025 17:24:41 
+0300 Subject: [PATCH 10/14] moved bria's nodes to invocations folder --- .../bria_nodes => app/invocations}/bria_controlnet.py | 0 .../{nodes/bria_nodes => app/invocations}/bria_decoder.py | 0 .../{nodes/bria_nodes => app/invocations}/bria_denoiser.py | 0 .../bria_nodes => app/invocations}/bria_latent_sampler.py | 0 .../bria_nodes => app/invocations}/bria_model_loader.py | 0 .../bria_nodes => app/invocations}/bria_text_encoder.py | 0 invokeai/nodes/__init__.py | 1 - invokeai/nodes/bria_nodes/__init__.py | 6 ------ 8 files changed, 7 deletions(-) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_controlnet.py (100%) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_decoder.py (100%) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_denoiser.py (100%) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_latent_sampler.py (100%) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_model_loader.py (100%) rename invokeai/{nodes/bria_nodes => app/invocations}/bria_text_encoder.py (100%) delete mode 100644 invokeai/nodes/bria_nodes/__init__.py diff --git a/invokeai/nodes/bria_nodes/bria_controlnet.py b/invokeai/app/invocations/bria_controlnet.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_controlnet.py rename to invokeai/app/invocations/bria_controlnet.py diff --git a/invokeai/nodes/bria_nodes/bria_decoder.py b/invokeai/app/invocations/bria_decoder.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_decoder.py rename to invokeai/app/invocations/bria_decoder.py diff --git a/invokeai/nodes/bria_nodes/bria_denoiser.py b/invokeai/app/invocations/bria_denoiser.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_denoiser.py rename to invokeai/app/invocations/bria_denoiser.py diff --git a/invokeai/nodes/bria_nodes/bria_latent_sampler.py b/invokeai/app/invocations/bria_latent_sampler.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_latent_sampler.py rename to 
invokeai/app/invocations/bria_latent_sampler.py diff --git a/invokeai/nodes/bria_nodes/bria_model_loader.py b/invokeai/app/invocations/bria_model_loader.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_model_loader.py rename to invokeai/app/invocations/bria_model_loader.py diff --git a/invokeai/nodes/bria_nodes/bria_text_encoder.py b/invokeai/app/invocations/bria_text_encoder.py similarity index 100% rename from invokeai/nodes/bria_nodes/bria_text_encoder.py rename to invokeai/app/invocations/bria_text_encoder.py diff --git a/invokeai/nodes/__init__.py b/invokeai/nodes/__init__.py index f6b74417535..e69de29bb2d 100644 --- a/invokeai/nodes/__init__.py +++ b/invokeai/nodes/__init__.py @@ -1 +0,0 @@ -from .bria_nodes import * \ No newline at end of file diff --git a/invokeai/nodes/bria_nodes/__init__.py b/invokeai/nodes/bria_nodes/__init__.py deleted file mode 100644 index a6049b4362e..00000000000 --- a/invokeai/nodes/bria_nodes/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .bria_decoder import BriaDecoderInvocation -from .bria_model_loader import BriaModelLoaderInvocation -from .bria_denoiser import BriaDenoiseInvocation -from .bria_latent_sampler import BriaLatentSamplerInvocation -from .bria_text_encoder import BriaTextEncoderInvocation -from .bria_full import BriaFullInvocation From 282df322d52be29db198d3d874295b6572a0d1f9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 17 Jul 2025 10:31:27 +0000 Subject: [PATCH 11/14] fixed node issue --- invokeai/app/invocations/bria_controlnet.py | 11 ++++++----- invokeai/app/invocations/bria_denoiser.py | 4 ++-- invokeai/app/invocations/bria_latent_sampler.py | 10 +++++++--- .../backend/bria/pipeline_bria_controlnet.py | 16 ++++++++-------- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/invokeai/app/invocations/bria_controlnet.py b/invokeai/app/invocations/bria_controlnet.py index 514b99e9879..61338d6cf47 100644 --- a/invokeai/app/invocations/bria_controlnet.py +++ 
b/invokeai/app/invocations/bria_controlnet.py @@ -1,13 +1,13 @@ from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES from pydantic import BaseModel, Field -from invokeai.invocation_api import ImageOutput +from invokeai.invocation_api import ImageOutput, Classification from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, invocation, invocation_output, ) -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.shared.invocation_context import InvocationContext import numpy as np @@ -26,9 +26,9 @@ class BriaControlNetField(BaseModel): mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet") conditioning_scale: float = Field(description="The weight given to the ControlNet") -@invocation_output("flux_controlnet_output") +@invocation_output("bria_controlnet_output") class BriaControlNetOutput(BaseInvocationOutput): - """FLUX ControlNet info""" + """Bria ControlNet info""" control: BriaControlNetField = OutputField(description=FieldDescriptions.control) preprocessed_images: ImageField = OutputField(description="The preprocessed control image") @@ -40,8 +40,9 @@ class BriaControlNetOutput(BaseInvocationOutput): tags=["controlnet", "bria"], category="controlnet", version="1.0.0", + classification=Classification.Prototype, ) -class BriaControlNetInvocation(BaseInvocation): +class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard): """Collect Bria ControlNet info to pass to denoiser node.""" control_image: ImageField = InputField(description="The control image") diff --git a/invokeai/app/invocations/bria_denoiser.py b/invokeai/app/invocations/bria_denoiser.py index 834e290de4f..30b61ccd351 100644 --- 
a/invokeai/app/invocations/bria_denoiser.py +++ b/invokeai/app/invocations/bria_denoiser.py @@ -3,7 +3,7 @@ from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel from invokeai.backend.bria.controlnet_utils import prepare_control_images from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline -from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField +from invokeai.app.invocations.bria_controlnet import BriaControlNetField import torch from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler @@ -12,7 +12,7 @@ from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output +from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel @invocation_output("bria_denoise_output") diff --git a/invokeai/app/invocations/bria_latent_sampler.py b/invokeai/app/invocations/bria_latent_sampler.py index f3ce74729b2..d3531e3125f 100644 --- a/invokeai/app/invocations/bria_latent_sampler.py +++ b/invokeai/app/invocations/bria_latent_sampler.py @@ -48,18 +48,22 @@ class BriaLatentSamplerInvocation(BaseInvocation): title="Transformer", ) + @torch.no_grad() def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput: - device = torch.device("cuda") + with context.models.load(self.transformer.transformer) as transformer: + device = transformer.device + dtype = transformer.dtype + height, width = 1024, 1024 generator = torch.Generator(device=device).manual_seed(self.seed) - num_channels_latents = 4 # due to patch=2, we devide by 
4 + num_channels_latents = 4 latents, latent_image_ids = prepare_latents( batch_size=1, num_channels_latents=num_channels_latents, height=height, width=width, - dtype=torch.float32, + dtype=dtype, device=device, generator=generator, ) diff --git a/invokeai/backend/bria/pipeline_bria_controlnet.py b/invokeai/backend/bria/pipeline_bria_controlnet.py index b6106cd02ca..fb80fce3bff 100644 --- a/invokeai/backend/bria/pipeline_bria_controlnet.py +++ b/invokeai/backend/bria/pipeline_bria_controlnet.py @@ -612,14 +612,14 @@ def encode_prompt( def prepare_latents( - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator, + latents: Optional[torch.FloatTensor] = None, ): # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. 
From 1ac5a24a8ad6fe90934db956f1714f73e1f42cd9 Mon Sep 17 00:00:00 2001 From: Ilan Tchenak Date: Sun, 20 Jul 2025 15:03:57 +0300 Subject: [PATCH 12/14] ruff fix --- invokeai/app/invocations/bria_controlnet.py | 52 +-- invokeai/app/invocations/bria_decoder.py | 8 +- invokeai/app/invocations/bria_denoiser.py | 45 ++- .../app/invocations/bria_latent_sampler.py | 9 +- invokeai/app/invocations/bria_model_loader.py | 2 - invokeai/app/invocations/bria_text_encoder.py | 6 +- invokeai/backend/bria/bria_utils.py | 2 +- .../backend/bria/controlnet_aux/__init__.py | 5 +- .../bria/controlnet_aux/canny/__init__.py | 30 +- .../bria/controlnet_aux/open_pose/__init__.py | 129 +++++--- .../bria/controlnet_aux/open_pose/body.py | 123 +++++-- .../bria/controlnet_aux/open_pose/face.py | 313 +++++++----------- .../bria/controlnet_aux/open_pose/hand.py | 15 +- .../bria/controlnet_aux/open_pose/model.py | 273 ++++++++------- .../bria/controlnet_aux/open_pose/util.py | 157 ++++++--- invokeai/backend/bria/controlnet_aux/util.py | 196 ++++++++--- invokeai/backend/bria/controlnet_bria.py | 84 +++-- invokeai/backend/bria/controlnet_utils.py | 17 +- invokeai/backend/bria/pipeline_bria.py | 151 ++++----- .../backend/bria/pipeline_bria_controlnet.py | 123 ++++--- invokeai/backend/bria/transformer_bria.py | 6 +- .../model_manager/load/model_loaders/bria.py | 5 +- 22 files changed, 998 insertions(+), 753 deletions(-) diff --git a/invokeai/app/invocations/bria_controlnet.py b/invokeai/app/invocations/bria_controlnet.py index 61338d6cf47..878f3cc8b03 100644 --- a/invokeai/app/invocations/bria_controlnet.py +++ b/invokeai/app/invocations/bria_controlnet.py @@ -1,31 +1,41 @@ -from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES +import cv2 +import numpy as np +from PIL import Image from pydantic import BaseModel, Field -from invokeai.invocation_api import ImageOutput, Classification + from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, invocation, 
invocation_output, ) -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + OutputField, + UIType, + WithBoard, + WithMetadata, +) from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.shared.invocation_context import InvocationContext -import numpy as np -import cv2 -from PIL import Image - +from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector +from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline -from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector, Body, Hand, Face +from invokeai.invocation_api import Classification, ImageOutput DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf" HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + class BriaControlNetField(BaseModel): image: ImageField = Field(description="The control image") model: ModelIdentifierField = Field(description="The ControlNet model to use") mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet") conditioning_scale: float = Field(description="The weight given to the ControlNet") + @invocation_output("bria_controlnet_output") class BriaControlNetOutput(BaseInvocationOutput): """Bria ControlNet info""" @@ -49,12 +59,8 @@ class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard): control_model: ModelIdentifierField = InputField( description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel ) - control_mode: BRIA_CONTROL_MODES = InputField( - default="depth", description="The mode of the ControlNet" - ) - control_weight: float = InputField( - default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" - ) + 
control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet") + control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet") def invoke(self, context: InvocationContext) -> BriaControlNetOutput: image_in = resize_img(context.images.get_pil(self.control_image.image_name)) @@ -70,7 +76,7 @@ def invoke(self, context: InvocationContext) -> BriaControlNetOutput: control_image = convert_to_grayscale(image_in) elif self.control_mode == "tile": control_image = tile(16, image_in) - + control_image = resize_img(control_image) image_dto = context.images.save(image=control_image) image_output = ImageOutput.build(image_dto) @@ -99,6 +105,7 @@ def invoke(self, context: InvocationContext) -> BriaControlNetOutput: 1.7708333333333333: {"width": 1360, "height": 768}, } + def extract_depth(image: Image.Image, context: InvocationContext): loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model) @@ -107,6 +114,7 @@ def extract_depth(image: Image.Image, context: InvocationContext): depth_map = depth_anything_detector.generate_depth(image) return depth_map + def extract_openpose(image: Image.Image, context: InvocationContext): body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body) hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand) @@ -115,10 +123,10 @@ def extract_openpose(image: Image.Image, context: InvocationContext): with body_model as body_model, hand_model as hand_model, face_model as face_model: open_pose_model = OpenposeDetector(body_model, hand_model, face_model) processed_image_open_pose = open_pose_model(image, hand_and_face=True) - + processed_image_open_pose = processed_image_open_pose.resize(image.size) return processed_image_open_pose - + def extract_canny(input_image): image = np.array(input_image) @@ -130,13 +138,17 @@ def extract_canny(input_image): def 
convert_to_grayscale(image): - gray_image = image.convert('L').convert('RGB') + gray_image = image.convert("L").convert("RGB") return gray_image + def tile(downscale_factor, input_image): - control_image = input_image.resize((input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)).resize(input_image.size, Image.Resampling.NEAREST) + control_image = input_image.resize( + (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor) + ).resize(input_image.size, Image.Resampling.NEAREST) return control_image - + + def resize_img(control_image): image_ratio = control_image.width / control_image.height ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio)) diff --git a/invokeai/app/invocations/bria_decoder.py b/invokeai/app/invocations/bria_decoder.py index b333a8c8c7b..a3f862a884b 100644 --- a/invokeai/app/invocations/bria_decoder.py +++ b/invokeai/app/invocations/bria_decoder.py @@ -30,15 +30,15 @@ class BriaDecoderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.tensors.load(self.latents.latents_name) latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128) - + with context.models.load(self.vae.vae) as vae: assert isinstance(vae, AutoencoderKL) - latents = (latents / vae.config.scaling_factor) + latents = latents / vae.config.scaling_factor latents = latents.to(device=vae.device, dtype=vae.dtype) - + decoded_output = vae.decode(latents) image = decoded_output.sample - + # Convert to numpy with proper gradient handling image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0] img = Image.fromarray(image) diff --git a/invokeai/app/invocations/bria_denoiser.py b/invokeai/app/invocations/bria_denoiser.py index 30b61ccd351..f1a60568811 100644 --- a/invokeai/app/invocations/bria_denoiser.py +++ b/invokeai/app/invocations/bria_denoiser.py @@ -1,19 +1,20 @@ from 
typing import List, Tuple -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel -from invokeai.backend.bria.controlnet_utils import prepare_control_images -from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline -from invokeai.app.invocations.bria_controlnet import BriaControlNetField import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from invokeai.app.invocations.bria_controlnet import BriaControlNetField from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output +from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel +from invokeai.backend.bria.controlnet_utils import prepare_control_images +from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel +from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output + @invocation_output("bria_denoise_output") class BriaDenoiseInvocationOutput(BaseInvocationOutput): @@ -80,7 +81,7 @@ class BriaDenoiseInvocation(BaseInvocation): description="ControlNet", input=Input.Connection, title="ControlNet", - default = None, + default=None, ) @torch.no_grad() @@ -106,7 +107,7 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: assert isinstance(vae, AutoencoderKL) dtype = 
transformer.dtype device = transformer.device - latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds)) + latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds)) control_model, control_images, control_modes, control_scales = None, None, None, None if self.control is not None: @@ -134,7 +135,7 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: width=1024, height=1024, controlnet_conditioning_scale=control_scales, - num_inference_steps=self.num_steps, + num_inference_steps=self.num_steps, max_sequence_length=128, guidance_scale=self.guidance_scale, latents=latents, @@ -150,36 +151,30 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput: latents_output = LatentsField(latents_name=saved_input_latents_tensor) return BriaDenoiseInvocationOutput(latents=latents_output) - def _prepare_multi_control( - self, - context: InvocationContext, - vae: AutoencoderKL, - width: int, - height: int, - device: torch.device + self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]: - control = self.control if isinstance(self.control, list) else [self.control] control_images, control_models, control_modes, control_scales = [], [], [], [] for controlnet in control: if controlnet is not None: control_models.append(context.models.load(controlnet.model).model) - control_modes.append(BriaControlModes[controlnet.mode].value) + control_modes.append(BriaControlModes[controlnet.mode].value) control_scales.append(controlnet.conditioning_scale) try: control_images.append(context.images.get_pil(controlnet.image.image_name)) - except: - raise FileNotFoundError(f"Control image {controlnet.image.image_name} not found. 
Make sure not to delete the preprocessed image before finishing the pipeline.") + except Exception: + raise FileNotFoundError( + f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline." + ) control_model = BriaMultiControlNetModel(control_models).to(device) tensored_control_images, tensored_control_modes = prepare_control_images( vae=vae, - control_images=control_images, - control_modes=control_modes, + control_images=control_images, + control_modes=control_modes, width=width, height=height, - device=device, - ) + device=device, + ) return control_model, tensored_control_images, tensored_control_modes, control_scales - \ No newline at end of file diff --git a/invokeai/app/invocations/bria_latent_sampler.py b/invokeai/app/invocations/bria_latent_sampler.py index d3531e3125f..7bf8acf8c97 100644 --- a/invokeai/app/invocations/bria_latent_sampler.py +++ b/invokeai/app/invocations/bria_latent_sampler.py @@ -1,19 +1,16 @@ import torch -from invokeai.app.invocations.fields import Input, InputField +from invokeai.app.invocations.fields import Input, InputField, OutputField from invokeai.app.invocations.model import TransformerField from invokeai.app.invocations.primitives import ( BaseInvocationOutput, FieldDescriptions, - Input, LatentsField, - OutputField, ) from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents from invokeai.invocation_api import ( BaseInvocation, Classification, - InputField, InvocationContext, invocation, invocation_output, @@ -56,7 +53,7 @@ def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutpu height, width = 1024, 1024 generator = torch.Generator(device=device).manual_seed(self.seed) - + num_channels_latents = 4 latents, latent_image_ids = prepare_latents( batch_size=1, @@ -66,7 +63,7 @@ def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutpu dtype=dtype, device=device, generator=generator, - ) + ) 
saved_latents_tensor = context.tensors.save(latents) saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids) diff --git a/invokeai/app/invocations/bria_model_loader.py b/invokeai/app/invocations/bria_model_loader.py index b8b20f4f511..aa06ada1ba7 100644 --- a/invokeai/app/invocations/bria_model_loader.py +++ b/invokeai/app/invocations/bria_model_loader.py @@ -10,9 +10,7 @@ BaseInvocation, BaseInvocationOutput, Classification, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/bria_text_encoder.py b/invokeai/app/invocations/bria_text_encoder.py index 03a9a1d50b4..f574ad86aea 100644 --- a/invokeai/app/invocations/bria_text_encoder.py +++ b/invokeai/app/invocations/bria_text_encoder.py @@ -19,8 +19,6 @@ invocation_output, ) -from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, is_ng_none - @invocation_output("bria_text_encoder_output") class BriaTextEncoderInvocationOutput(BaseInvocationOutput): @@ -70,7 +68,7 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput: ): assert isinstance(tokenizer, T5TokenizerFast) assert isinstance(text_encoder, T5EncoderModel) - + (prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt( prompt=self.prompt, tokenizer=tokenizer, @@ -81,7 +79,7 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput: max_sequence_length=self.max_length, lora_scale=1.0, ) - + saved_pos_tensor = context.tensors.save(prompt_embeds) saved_neg_tensor = context.tensors.save(negative_prompt_embeds) saved_text_ids_tensor = context.tensors.save(text_ids) diff --git a/invokeai/backend/bria/bria_utils.py b/invokeai/backend/bria/bria_utils.py index a821ebe7ba1..6b47f04b096 100644 --- a/invokeai/backend/bria/bria_utils.py +++ b/invokeai/backend/bria/bria_utils.py @@ -87,7 +87,7 @@ def is_ng_none(negative_prompt): negative_prompt is None or negative_prompt == "" or (isinstance(negative_prompt, list) and 
negative_prompt[0] is None) - or (type(negative_prompt) == list and negative_prompt[0] == "") + or (isinstance(negative_prompt, list) and negative_prompt[0] == "") ) diff --git a/invokeai/backend/bria/controlnet_aux/__init__.py b/invokeai/backend/bria/controlnet_aux/__init__.py index 0536dca4bbe..cfc9c718a79 100644 --- a/invokeai/backend/bria/controlnet_aux/__init__.py +++ b/invokeai/backend/bria/controlnet_aux/__init__.py @@ -1,5 +1,6 @@ __version__ = "0.0.9" -from .canny import CannyDetector -from .open_pose import OpenposeDetector +from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector +from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector +__all__ = ["CannyDetector", "OpenposeDetector"] diff --git a/invokeai/backend/bria/controlnet_aux/canny/__init__.py b/invokeai/backend/bria/controlnet_aux/canny/__init__.py index aca9ae3a34b..968d4f0db73 100644 --- a/invokeai/backend/bria/controlnet_aux/canny/__init__.py +++ b/invokeai/backend/bria/controlnet_aux/canny/__init__.py @@ -1,15 +1,27 @@ import warnings + import cv2 import numpy as np from PIL import Image -from ..util import HWC3, resize_image + +from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image + class CannyDetector: - def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs): + def __call__( + self, + input_image=None, + low_threshold=100, + high_threshold=200, + detect_resolution=512, + image_resolution=512, + output_type=None, + **kwargs, + ): if "img" in kwargs: - warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2) input_image = kwargs.pop("img") - + if input_image is None: raise ValueError("input_image must be defined.") @@ -18,19 +30,19 @@ def __call__(self, 
input_image=None, low_threshold=100, high_threshold=200, dete output_type = output_type or "pil" else: output_type = output_type or "np" - + input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) detected_map = cv2.Canny(input_image, low_threshold, high_threshold) - detected_map = HWC3(detected_map) - + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) H, W, C = img.shape detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - + if output_type == "pil": detected_map = Image.fromarray(detected_map) - + return detected_map diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py b/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py index e463316aa60..cdd5ff95696 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/__init__.py @@ -11,9 +11,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -import json import warnings -from typing import Callable, List, NamedTuple, Tuple, Union +from typing import List, NamedTuple, Tuple, Union import cv2 import numpy as np @@ -21,21 +20,23 @@ from huggingface_hub import hf_hub_download from PIL import Image -from ..util import HWC3, resize_image -from . 
import util -from .body import Body, BodyResult, Keypoint -from .face import Face -from .hand import Hand +from invokeai.backend.bria.controlnet_aux.open_pose import util +from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint +from invokeai.backend.bria.controlnet_aux.open_pose.face import Face +from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand +from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image HandResult = List[Keypoint] FaceResult = List[Keypoint] + class PoseResult(NamedTuple): body: BodyResult left_hand: Union[HandResult, None] right_hand: Union[HandResult, None] face: Union[FaceResult, None] + def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): """ Draw the detected poses on an empty canvas. @@ -65,8 +66,8 @@ def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, dr canvas = util.draw_facepose(canvas, pose.face) return canvas - - + + class OpenposeDetector: """ A class for detecting human poses in images using the Openpose model. @@ -74,14 +75,22 @@ class OpenposeDetector: Attributes: model_dir (str): Path to the directory where the pose models are stored. 
""" + def __init__(self, body_estimation, hand_estimation=None, face_estimation=None): self.body_estimation = body_estimation self.hand_estimation = hand_estimation self.face_estimation = face_estimation @classmethod - def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False): - + def from_pretrained( + cls, + pretrained_model_or_path, + filename=None, + hand_filename=None, + face_filename=None, + cache_dir=None, + local_files_only=False, + ): if pretrained_model_or_path == "lllyasviel/ControlNet": filename = filename or "annotator/ckpts/body_pose_model.pth" hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth" @@ -100,9 +109,15 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename= hand_model_path = os.path.join(pretrained_model_or_path, hand_filename) face_model_path = os.path.join(face_pretrained_model_or_path, face_filename) else: - body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) - hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only) - face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only) + body_model_path = hf_hub_download( + pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only + ) + hand_model_path = hf_hub_download( + pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only + ) + face_model_path = hf_hub_download( + face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only + ) body_estimation = Body(body_model_path) hand_estimation = Hand(hand_model_path) @@ -121,15 +136,12 @@ def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None right_hand = None 
H, W, _ = oriImg.shape for x, y, w, is_left in util.handDetect(body, oriImg): - peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) + peaks = self.hand_estimation(oriImg[y : y + w, x : x + w, :]).astype(np.float32) if peaks.ndim == 2 and peaks.shape[1] == 2: peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) - - hand_result = [ - Keypoint(x=peak[0], y=peak[1]) - for peak in peaks - ] + + hand_result = [Keypoint(x=peak[0], y=peak[1]) for peak in peaks] if is_left: left_hand = hand_result @@ -142,19 +154,16 @@ def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]: face = util.faceDetect(body, oriImg) if face is None: return None - + x, y, w = face H, W, _ = oriImg.shape - heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) + heatmaps = self.face_estimation(oriImg[y : y + w, x : x + w, :]) peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) if peaks.ndim == 2 and peaks.shape[1] == 2: peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) - return [ - Keypoint(x=peak[0], y=peak[1]) - for peak in peaks - ] - + return [Keypoint(x=peak[0], y=peak[1]) for peak in peaks] + return None def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]: @@ -181,32 +190,56 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P left_hand, right_hand = self.detect_hands(body, oriImg) if include_face: face = self.detect_face(body, oriImg) - - results.append(PoseResult(BodyResult( - keypoints=[ - Keypoint( - x=keypoint.x / float(W), - y=keypoint.y / float(H) - ) if keypoint is not None else None - for keypoint in body.keypoints - ], - total_score=body.total_score, - total_parts=body.total_parts - ), left_hand, right_hand, face)) - + + results.append( + 
PoseResult( + BodyResult( + keypoints=[ + Keypoint(x=keypoint.x / float(W), y=keypoint.y / float(H)) + if keypoint is not None + else None + for keypoint in body.keypoints + ], + total_score=body.total_score, + total_parts=body.total_parts, + ), + left_hand, + right_hand, + face, + ) + ) + return results - - def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs): + + def __call__( + self, + input_image, + detect_resolution=512, + image_resolution=512, + include_body=True, + include_hand=False, + include_face=False, + hand_and_face=None, + output_type="pil", + **kwargs, + ): if hand_and_face is not None: - warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning) + warnings.warn( + "hand_and_face is deprecated. Use include_hand and include_face instead.", + DeprecationWarning, + stacklevel=2, + ) include_hand = hand_and_face include_face = hand_and_face if "return_pil" in kwargs: - warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning, stacklevel=2) output_type = "pil" if kwargs["return_pil"] else "np" if type(output_type) is bool: - warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + warnings.warn( + "Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions", + stacklevel=2, + ) if output_type: output_type = "pil" @@ -216,13 +249,13 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, inc input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) H, W, C = input_image.shape - + poses = self.detect_poses(input_image, include_hand, include_face) - canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face) + canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face) detected_map = canvas detected_map = HWC3(detected_map) - + img = resize_image(input_image, image_resolution) H, W, C = img.shape diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/body.py b/invokeai/backend/bria/controlnet_aux/open_pose/body.py index fa4c74e4e1e..339a1c5a233 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/body.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/body.py @@ -1,13 +1,12 @@ import math from typing import List, NamedTuple, Union -import cv2 import numpy as np import torch from scipy.ndimage.filters import gaussian_filter -from . 
import util -from .model import bodypose_model +from invokeai.backend.bria.controlnet_aux.open_pose import util +from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model class Keypoint(NamedTuple): @@ -71,17 +70,17 @@ def __call__(self, oriImg): # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) - heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :] heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs paf = util.smart_resize_k(paf, fx=stride, fy=stride) - paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = paf[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :] paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) heatmap_avg += heatmap_avg + heatmap / len(multiplier) - paf_avg += + paf / len(multiplier) + paf_avg += +paf / len(multiplier) all_peaks = [] peak_counter = 0 @@ -100,8 +99,15 @@ def __call__(self, oriImg): map_down[:, :-1] = one_heatmap[:, 1:] peaks_binary = np.logical_and.reduce( - (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) - peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + ( + one_heatmap >= map_left, + one_heatmap >= map_right, + one_heatmap >= map_up, + one_heatmap >= map_down, + one_heatmap > thre1, + ) + ) + peaks = list(zip(np.nonzero(peaks_binary)[1], 
np.nonzero(peaks_binary)[0], strict=False)) # note reverse peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] peak_id = range(peak_counter, peak_counter + len(peaks)) peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] @@ -110,13 +116,49 @@ def __call__(self, oriImg): peak_counter += len(peaks) # find connection in the specified sequence, center 29 is in the position 15 - limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ - [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ - [1, 16], [16, 18], [3, 17], [6, 18]] + limbSeq = [ + [2, 3], + [2, 6], + [3, 4], + [4, 5], + [6, 7], + [7, 8], + [2, 9], + [9, 10], + [10, 11], + [2, 12], + [12, 13], + [13, 14], + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], + [3, 17], + [6, 18], + ] # the middle joints heatmap correpondence - mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ - [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ - [55, 56], [37, 38], [45, 46]] + mapIdx = [ + [31, 32], + [39, 40], + [33, 34], + [35, 36], + [41, 42], + [43, 44], + [19, 20], + [21, 22], + [23, 24], + [25, 26], + [27, 28], + [29, 30], + [47, 48], + [49, 50], + [53, 54], + [51, 52], + [55, 56], + [37, 38], + [45, 46], + ] connection_all = [] special_k = [] @@ -129,7 +171,7 @@ def __call__(self, oriImg): nA = len(candA) nB = len(candB) indexA, indexB = limbSeq[k] - if (nA != 0 and nB != 0): + if nA != 0 and nB != 0: connection_candidate = [] for i in range(nA): for j in range(nB): @@ -138,30 +180,45 @@ def __call__(self, oriImg): norm = max(0.001, norm) vec = np.divide(vec, norm) - startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ - np.linspace(candA[i][1], candB[j][1], num=mid_num))) - - vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ - for I in range(len(startend))]) - vec_y = 
np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ - for I in range(len(startend))]) + startend = list( + zip( + np.linspace(candA[i][0], candB[j][0], num=mid_num), + np.linspace(candA[i][1], candB[j][1], num=mid_num), + strict=False, + ) + ) + + vec_x = np.array( + [ + score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0] + for i in range(len(startend)) + ] + ) + vec_y = np.array( + [ + score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1] + for i in range(len(startend)) + ] + ) score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( - 0.5 * oriImg.shape[0] / norm - 1, 0) + 0.5 * oriImg.shape[0] / norm - 1, 0 + ) criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) criterion2 = score_with_dist_prior > 0 if criterion1 and criterion2: connection_candidate.append( - [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]] + ) connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) connection = np.zeros((0, 5)) for c in range(len(connection_candidate)): i, j, s = connection_candidate[c][0:3] - if (i not in connection[:, 3] and j not in connection[:, 4]): + if i not in connection[:, 3] and j not in connection[:, 4]: connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) - if (len(connection) >= min(nA, nB)): + if len(connection) >= min(nA, nB): break connection_all.append(connection) @@ -198,7 +255,7 @@ def __call__(self, oriImg): j1, j2 = subset_idx membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] if len(np.nonzero(membership == 2)[0]) == 0: # merge - subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][:-2] += subset[j2][:-2] + 1 subset[j1][-2:] += subset[j2][-2:] subset[j1][-2] += connection_all[k][i][2] subset 
= np.delete(subset, j2, 0) @@ -225,12 +282,12 @@ def __call__(self, oriImg): # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts # candidate: x, y, score, id return candidate, subset - + @staticmethod def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]: """ Format the body results from the candidate and subset arrays into a list of BodyResult objects. - + Args: candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id for each body part. @@ -249,12 +306,14 @@ def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyRe x=candidate[candidate_index][0], y=candidate[candidate_index][1], score=candidate[candidate_index][2], - id=candidate[candidate_index][3] - ) if candidate_index != -1 else None + id=candidate[candidate_index][3], + ) + if candidate_index != -1 + else None for candidate_index in person[:18].astype(int) ], total_score=person[18], - total_parts=person[19] + total_parts=person[19], ) for person in subset ] diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/face.py b/invokeai/backend/bria/controlnet_aux/open_pose/face.py index 41c7799af10..fb1ee12b275 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/face.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/face.py @@ -6,183 +6,81 @@ from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init from torchvision.transforms import ToPILImage, ToTensor -from . import util +from invokeai.backend.bria.controlnet_aux.open_pose import util class FaceNet(Module): - """Model the cascading heatmaps. 
""" + """Model the cascading heatmaps.""" + def __init__(self): super(FaceNet, self).__init__() # cnn to make feature map self.relu = ReLU() self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) - self.conv1_1 = Conv2d(in_channels=3, out_channels=64, - kernel_size=3, stride=1, padding=1) - self.conv1_2 = Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1, - padding=1) - self.conv2_1 = Conv2d( - in_channels=64, out_channels=128, kernel_size=3, stride=1, - padding=1) - self.conv2_2 = Conv2d( - in_channels=128, out_channels=128, kernel_size=3, stride=1, - padding=1) - self.conv3_1 = Conv2d( - in_channels=128, out_channels=256, kernel_size=3, stride=1, - padding=1) - self.conv3_2 = Conv2d( - in_channels=256, out_channels=256, kernel_size=3, stride=1, - padding=1) - self.conv3_3 = Conv2d( - in_channels=256, out_channels=256, kernel_size=3, stride=1, - padding=1) - self.conv3_4 = Conv2d( - in_channels=256, out_channels=256, kernel_size=3, stride=1, - padding=1) - self.conv4_1 = Conv2d( - in_channels=256, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv4_2 = Conv2d( - in_channels=512, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv4_3 = Conv2d( - in_channels=512, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv4_4 = Conv2d( - in_channels=512, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv5_1 = Conv2d( - in_channels=512, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv5_2 = Conv2d( - in_channels=512, out_channels=512, kernel_size=3, stride=1, - padding=1) - self.conv5_3_CPM = Conv2d( - in_channels=512, out_channels=128, kernel_size=3, stride=1, - padding=1) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1) + self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = 
Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1) + self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1) + self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1) + self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv4_4 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv5_1 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) + self.conv5_3_CPM = Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1) # stage1 - self.conv6_1_CPM = Conv2d( - in_channels=128, out_channels=512, kernel_size=1, stride=1, - padding=0) - self.conv6_2_CPM = Conv2d( - in_channels=512, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0) + self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0) # stage2 - self.Mconv1_stage2 = Conv2d( - in_channels=199, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv2_stage2 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv3_stage2 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv4_stage2 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv5_stage2 = 
Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv6_stage2 = Conv2d( - in_channels=128, out_channels=128, kernel_size=1, stride=1, - padding=0) - self.Mconv7_stage2 = Conv2d( - in_channels=128, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.Mconv1_stage2 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv2_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv3_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv4_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv5_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv6_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) + self.Mconv7_stage2 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0) # stage3 - self.Mconv1_stage3 = Conv2d( - in_channels=199, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv2_stage3 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv3_stage3 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv4_stage3 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv5_stage3 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv6_stage3 = Conv2d( - in_channels=128, out_channels=128, kernel_size=1, stride=1, - padding=0) - self.Mconv7_stage3 = Conv2d( - in_channels=128, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.Mconv1_stage3 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv2_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv3_stage3 = Conv2d(in_channels=128, 
out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv4_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv5_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv6_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) + self.Mconv7_stage3 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0) # stage4 - self.Mconv1_stage4 = Conv2d( - in_channels=199, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv2_stage4 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv3_stage4 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv4_stage4 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv5_stage4 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv6_stage4 = Conv2d( - in_channels=128, out_channels=128, kernel_size=1, stride=1, - padding=0) - self.Mconv7_stage4 = Conv2d( - in_channels=128, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.Mconv1_stage4 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv2_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv3_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv4_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv5_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv6_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) + self.Mconv7_stage4 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0) # stage5 - self.Mconv1_stage5 = Conv2d( - in_channels=199, out_channels=128, 
kernel_size=7, stride=1, - padding=3) - self.Mconv2_stage5 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv3_stage5 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv4_stage5 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv5_stage5 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv6_stage5 = Conv2d( - in_channels=128, out_channels=128, kernel_size=1, stride=1, - padding=0) - self.Mconv7_stage5 = Conv2d( - in_channels=128, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.Mconv1_stage5 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv2_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv3_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv4_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv5_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv6_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) + self.Mconv7_stage5 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0) # stage6 - self.Mconv1_stage6 = Conv2d( - in_channels=199, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv2_stage6 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv3_stage6 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv4_stage6 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv5_stage6 = Conv2d( - in_channels=128, out_channels=128, kernel_size=7, stride=1, - padding=3) - self.Mconv6_stage6 = Conv2d( - in_channels=128, out_channels=128, 
kernel_size=1, stride=1, - padding=0) - self.Mconv7_stage6 = Conv2d( - in_channels=128, out_channels=71, kernel_size=1, stride=1, - padding=0) + self.Mconv1_stage6 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv2_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv3_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv4_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv5_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3) + self.Mconv6_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) + self.Mconv7_stage6 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0) for m in self.modules(): if isinstance(m, Conv2d): @@ -281,24 +179,74 @@ def forward(self, x): params = { - 'gaussian_sigma': 2.5, - 'inference_img_size': 736, # 368, 736, 1312 - 'heatmap_peak_thresh': 0.1, - 'crop_scale': 1.5, - 'line_indices': [ - [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], - [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], - [13, 14], [14, 15], [15, 16], - [17, 18], [18, 19], [19, 20], [20, 21], - [22, 23], [23, 24], [24, 25], [25, 26], - [27, 28], [28, 29], [29, 30], - [31, 32], [32, 33], [33, 34], [34, 35], - [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], - [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], - [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], - [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], - [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], - [66, 67], [67, 60] + "gaussian_sigma": 2.5, + "inference_img_size": 736, # 368, 736, 1312 + "heatmap_peak_thresh": 0.1, + "crop_scale": 1.5, + "line_indices": [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [5, 6], + [6, 7], + [7, 8], + [8, 9], + [9, 10], + [10, 11], + [11, 12], 
+ [12, 13], + [13, 14], + [14, 15], + [15, 16], + [17, 18], + [18, 19], + [19, 20], + [20, 21], + [22, 23], + [23, 24], + [24, 25], + [25, 26], + [27, 28], + [28, 29], + [29, 30], + [31, 32], + [32, 33], + [33, 34], + [34, 35], + [36, 37], + [37, 38], + [38, 39], + [39, 40], + [40, 41], + [41, 36], + [42, 43], + [43, 44], + [44, 45], + [45, 46], + [46, 47], + [47, 42], + [48, 49], + [49, 50], + [50, 51], + [51, 52], + [52, 53], + [53, 54], + [54, 55], + [55, 56], + [56, 57], + [57, 58], + [58, 59], + [59, 48], + [60, 61], + [61, 62], + [62, 63], + [63, 64], + [64, 65], + [65, 66], + [66, 67], + [67, 60], ], } @@ -314,12 +262,10 @@ class Face(object): heatmap_peak_thresh: return landmark if over threshold, default 0.1 """ - def __init__(self, face_model_path, - inference_size=None, - gaussian_sigma=None, - heatmap_peak_thresh=None): + + def __init__(self, face_model_path, inference_size=None, gaussian_sigma=None, heatmap_peak_thresh=None): self.inference_size = inference_size or params["inference_img_size"] - self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.sigma = gaussian_sigma or params["gaussian_sigma"] self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] self.model = FaceNet() self.model.load_state_dict(torch.load(face_model_path)) @@ -340,10 +286,7 @@ def __call__(self, face_img): with torch.no_grad(): hs = self.model(x_data[None, ...]) - heatmaps = F.interpolate( - hs[-1], - (H, W), - mode='bilinear', align_corners=True).cpu().numpy()[0] + heatmaps = F.interpolate(hs[-1], (H, W), mode="bilinear", align_corners=True).cpu().numpy()[0] return heatmaps def compute_peaks_from_heatmaps(self, heatmaps): @@ -361,4 +304,4 @@ def compute_peaks_from_heatmaps(self, heatmaps): y, x = positions[0][mi], positions[1][mi] all_peaks.append([x, y]) - return np.array(all_peaks) \ No newline at end of file + return np.array(all_peaks) diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/hand.py 
b/invokeai/backend/bria/controlnet_aux/open_pose/hand.py index 1387c4238c8..d595896f98a 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/hand.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/hand.py @@ -4,8 +4,8 @@ from scipy.ndimage.filters import gaussian_filter from skimage.measure import label -from . import util -from .model import handpose_model +from invokeai.backend.bria.controlnet_aux.open_pose import util +from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model class Hand(object): @@ -53,7 +53,7 @@ def __call__(self, oriImgRaw): # extract outputs, resize, and remove padding heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) - heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :] heatmap = util.smart_resize(heatmap, (wsize, wsize)) heatmap_avg += heatmap / len(multiplier) @@ -78,13 +78,14 @@ def __call__(self, oriImgRaw): all_peaks.append([x, y]) return np.array(all_peaks) + if __name__ == "__main__": - hand_estimation = Hand('../model/hand_pose_model.pth') + hand_estimation = Hand("../model/hand_pose_model.pth") # test_image = '../images/hand.jpg' - test_image = '../images/hand.jpg' + test_image = "../images/hand.jpg" oriImg = cv2.imread(test_image) # B,G,R order peaks = hand_estimation(oriImg) canvas = util.draw_handpose(oriImg, peaks, True) - cv2.imshow('', canvas) - cv2.waitKey(0) \ No newline at end of file + cv2.imshow("", canvas) + cv2.waitKey(0) diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/model.py b/invokeai/backend/bria/controlnet_aux/open_pose/model.py index 6c3d4726898..023cfd596db 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/model.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/model.py @@ -1,118 +1,133 @@ -import torch from 
collections import OrderedDict import torch import torch.nn as nn + def make_layers(block, no_relu_layers): layers = [] for layer_name, v in block.items(): - if 'pool' in layer_name: - layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], - padding=v[2]) + if "pool" in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) layers.append((layer_name, layer)) else: - conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], - kernel_size=v[2], stride=v[3], - padding=v[4]) + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4]) layers.append((layer_name, conv2d)) if layer_name not in no_relu_layers: - layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + layers.append(("relu_" + layer_name, nn.ReLU(inplace=True))) return nn.Sequential(OrderedDict(layers)) + class bodypose_model(nn.Module): def __init__(self): super(bodypose_model, self).__init__() # these layers have no relu layer - no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ - 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ - 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ - 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + no_relu_layers = [ + "conv5_5_CPM_L1", + "conv5_5_CPM_L2", + "Mconv7_stage2_L1", + "Mconv7_stage2_L2", + "Mconv7_stage3_L1", + "Mconv7_stage3_L2", + "Mconv7_stage4_L1", + "Mconv7_stage4_L2", + "Mconv7_stage5_L1", + "Mconv7_stage5_L2", + "Mconv7_stage6_L1", + "Mconv7_stage6_L1", + ] blocks = {} - block0 = OrderedDict([ - ('conv1_1', [3, 64, 3, 1, 1]), - ('conv1_2', [64, 64, 3, 1, 1]), - ('pool1_stage1', [2, 2, 0]), - ('conv2_1', [64, 128, 3, 1, 1]), - ('conv2_2', [128, 128, 3, 1, 1]), - ('pool2_stage1', [2, 2, 0]), - ('conv3_1', [128, 256, 3, 1, 1]), - ('conv3_2', [256, 256, 3, 1, 1]), - ('conv3_3', [256, 256, 3, 1, 1]), - ('conv3_4', [256, 256, 3, 1, 1]), - ('pool3_stage1', [2, 2, 0]), - ('conv4_1', [256, 512, 3, 1, 1]), - ('conv4_2', [512, 512, 3, 1, 1]), 
- ('conv4_3_CPM', [512, 256, 3, 1, 1]), - ('conv4_4_CPM', [256, 128, 3, 1, 1]) - ]) - + block0 = OrderedDict( + [ + ("conv1_1", [3, 64, 3, 1, 1]), + ("conv1_2", [64, 64, 3, 1, 1]), + ("pool1_stage1", [2, 2, 0]), + ("conv2_1", [64, 128, 3, 1, 1]), + ("conv2_2", [128, 128, 3, 1, 1]), + ("pool2_stage1", [2, 2, 0]), + ("conv3_1", [128, 256, 3, 1, 1]), + ("conv3_2", [256, 256, 3, 1, 1]), + ("conv3_3", [256, 256, 3, 1, 1]), + ("conv3_4", [256, 256, 3, 1, 1]), + ("pool3_stage1", [2, 2, 0]), + ("conv4_1", [256, 512, 3, 1, 1]), + ("conv4_2", [512, 512, 3, 1, 1]), + ("conv4_3_CPM", [512, 256, 3, 1, 1]), + ("conv4_4_CPM", [256, 128, 3, 1, 1]), + ] + ) # Stage 1 - block1_1 = OrderedDict([ - ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), - ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), - ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), - ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), - ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) - ]) - - block1_2 = OrderedDict([ - ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), - ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), - ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), - ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), - ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) - ]) - blocks['block1_1'] = block1_1 - blocks['block1_2'] = block1_2 + block1_1 = OrderedDict( + [ + ("conv5_1_CPM_L1", [128, 128, 3, 1, 1]), + ("conv5_2_CPM_L1", [128, 128, 3, 1, 1]), + ("conv5_3_CPM_L1", [128, 128, 3, 1, 1]), + ("conv5_4_CPM_L1", [128, 512, 1, 1, 0]), + ("conv5_5_CPM_L1", [512, 38, 1, 1, 0]), + ] + ) + + block1_2 = OrderedDict( + [ + ("conv5_1_CPM_L2", [128, 128, 3, 1, 1]), + ("conv5_2_CPM_L2", [128, 128, 3, 1, 1]), + ("conv5_3_CPM_L2", [128, 128, 3, 1, 1]), + ("conv5_4_CPM_L2", [128, 512, 1, 1, 0]), + ("conv5_5_CPM_L2", [512, 19, 1, 1, 0]), + ] + ) + blocks["block1_1"] = block1_1 + blocks["block1_2"] = block1_2 self.model0 = make_layers(block0, no_relu_layers) # Stages 2 - 6 for i in range(2, 7): - blocks['block%d_1' % i] = OrderedDict([ - ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), - ('Mconv2_stage%d_L1' % i, [128, 
128, 7, 1, 3]), - ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), - ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), - ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), - ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), - ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) - ]) - - blocks['block%d_2' % i] = OrderedDict([ - ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), - ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), - ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), - ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), - ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), - ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), - ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) - ]) + blocks["block%d_1" % i] = OrderedDict( + [ + ("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]), + ("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]), + ("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]), + ("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]), + ("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]), + ("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]), + ("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]), + ] + ) + + blocks["block%d_2" % i] = OrderedDict( + [ + ("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]), + ("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]), + ("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]), + ("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]), + ("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]), + ("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]), + ("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]), + ] + ) for k in blocks.keys(): blocks[k] = make_layers(blocks[k], no_relu_layers) - self.model1_1 = blocks['block1_1'] - self.model2_1 = blocks['block2_1'] - self.model3_1 = blocks['block3_1'] - self.model4_1 = blocks['block4_1'] - self.model5_1 = blocks['block5_1'] - self.model6_1 = blocks['block6_1'] - - self.model1_2 = blocks['block1_2'] - self.model2_2 = blocks['block2_2'] - self.model3_2 = blocks['block3_2'] - self.model4_2 = blocks['block4_2'] - self.model5_2 = blocks['block5_2'] - self.model6_2 = blocks['block6_2'] + 
self.model1_1 = blocks["block1_1"] + self.model2_1 = blocks["block2_1"] + self.model3_1 = blocks["block3_1"] + self.model4_1 = blocks["block4_1"] + self.model5_1 = blocks["block5_1"] + self.model6_1 = blocks["block6_1"] + self.model1_2 = blocks["block1_2"] + self.model2_2 = blocks["block2_2"] + self.model3_2 = blocks["block3_2"] + self.model4_2 = blocks["block4_2"] + self.model5_2 = blocks["block5_2"] + self.model6_2 = blocks["block6_2"] def forward(self, x): - out1 = self.model0(x) out1_1 = self.model1_1(out1) @@ -140,66 +155,74 @@ def forward(self, x): return out6_1, out6_2 + class handpose_model(nn.Module): def __init__(self): super(handpose_model, self).__init__() # these layers have no relu layer - no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ - 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + no_relu_layers = [ + "conv6_2_CPM", + "Mconv7_stage2", + "Mconv7_stage3", + "Mconv7_stage4", + "Mconv7_stage5", + "Mconv7_stage6", + ] # stage 1 - block1_0 = OrderedDict([ - ('conv1_1', [3, 64, 3, 1, 1]), - ('conv1_2', [64, 64, 3, 1, 1]), - ('pool1_stage1', [2, 2, 0]), - ('conv2_1', [64, 128, 3, 1, 1]), - ('conv2_2', [128, 128, 3, 1, 1]), - ('pool2_stage1', [2, 2, 0]), - ('conv3_1', [128, 256, 3, 1, 1]), - ('conv3_2', [256, 256, 3, 1, 1]), - ('conv3_3', [256, 256, 3, 1, 1]), - ('conv3_4', [256, 256, 3, 1, 1]), - ('pool3_stage1', [2, 2, 0]), - ('conv4_1', [256, 512, 3, 1, 1]), - ('conv4_2', [512, 512, 3, 1, 1]), - ('conv4_3', [512, 512, 3, 1, 1]), - ('conv4_4', [512, 512, 3, 1, 1]), - ('conv5_1', [512, 512, 3, 1, 1]), - ('conv5_2', [512, 512, 3, 1, 1]), - ('conv5_3_CPM', [512, 128, 3, 1, 1]) - ]) - - block1_1 = OrderedDict([ - ('conv6_1_CPM', [128, 512, 1, 1, 0]), - ('conv6_2_CPM', [512, 22, 1, 1, 0]) - ]) + block1_0 = OrderedDict( + [ + ("conv1_1", [3, 64, 3, 1, 1]), + ("conv1_2", [64, 64, 3, 1, 1]), + ("pool1_stage1", [2, 2, 0]), + ("conv2_1", [64, 128, 3, 1, 1]), + ("conv2_2", [128, 128, 3, 1, 1]), + ("pool2_stage1", [2, 2, 0]), + ("conv3_1", 
[128, 256, 3, 1, 1]), + ("conv3_2", [256, 256, 3, 1, 1]), + ("conv3_3", [256, 256, 3, 1, 1]), + ("conv3_4", [256, 256, 3, 1, 1]), + ("pool3_stage1", [2, 2, 0]), + ("conv4_1", [256, 512, 3, 1, 1]), + ("conv4_2", [512, 512, 3, 1, 1]), + ("conv4_3", [512, 512, 3, 1, 1]), + ("conv4_4", [512, 512, 3, 1, 1]), + ("conv5_1", [512, 512, 3, 1, 1]), + ("conv5_2", [512, 512, 3, 1, 1]), + ("conv5_3_CPM", [512, 128, 3, 1, 1]), + ] + ) + + block1_1 = OrderedDict([("conv6_1_CPM", [128, 512, 1, 1, 0]), ("conv6_2_CPM", [512, 22, 1, 1, 0])]) blocks = {} - blocks['block1_0'] = block1_0 - blocks['block1_1'] = block1_1 + blocks["block1_0"] = block1_0 + blocks["block1_1"] = block1_1 # stage 2-6 for i in range(2, 7): - blocks['block%d' % i] = OrderedDict([ - ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), - ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), - ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), - ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), - ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), - ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), - ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) - ]) + blocks["block%d" % i] = OrderedDict( + [ + ("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]), + ("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]), + ("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]), + ("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]), + ("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]), + ("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]), + ("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]), + ] + ) for k in blocks.keys(): blocks[k] = make_layers(blocks[k], no_relu_layers) - self.model1_0 = blocks['block1_0'] - self.model1_1 = blocks['block1_1'] - self.model2 = blocks['block2'] - self.model3 = blocks['block3'] - self.model4 = blocks['block4'] - self.model5 = blocks['block5'] - self.model6 = blocks['block6'] + self.model1_0 = blocks["block1_0"] + self.model1_1 = blocks["block1_1"] + self.model2 = blocks["block2"] + self.model3 = blocks["block3"] + self.model4 = blocks["block4"] + self.model5 = blocks["block5"] + self.model6 = 
blocks["block6"] def forward(self, x): out1_0 = self.model1_0(x) diff --git a/invokeai/backend/bria/controlnet_aux/open_pose/util.py b/invokeai/backend/bria/controlnet_aux/open_pose/util.py index f10ca2dfcbf..7922a4b3b31 100644 --- a/invokeai/backend/bria/controlnet_aux/open_pose/util.py +++ b/invokeai/backend/bria/controlnet_aux/open_pose/util.py @@ -1,9 +1,10 @@ import math -import numpy as np -import cv2 from typing import List, Tuple, Union -from .body import BodyResult, Keypoint +import cv2 +import numpy as np + +from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint eps = 0.01 @@ -41,19 +42,19 @@ def padRightDownCorner(img, stride, padValue): w = img.shape[1] pad = 4 * [None] - pad[0] = 0 # up - pad[1] = 0 # left - pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down - pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right img_padded = img - pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) img_padded = np.concatenate((pad_up, img_padded), axis=0) - pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) img_padded = np.concatenate((pad_left, img_padded), axis=1) - pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) img_padded = np.concatenate((img_padded, pad_down), axis=0) - pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) img_padded = np.concatenate((img_padded, pad_right), axis=1) return img_padded, pad @@ -62,7 +63,7 @@ def padRightDownCorner(img, 
stride, padValue): def transfer(model, model_weights): transfered_model_weights = {} for weights_name in model.state_dict().keys(): - transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + transfered_model_weights[weights_name] = model_weights[".".join(weights_name.split(".")[1:])] return transfered_model_weights @@ -84,18 +85,47 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: stickwidth = 4 limbSeq = [ - [2, 3], [2, 6], [3, 4], [4, 5], - [6, 7], [7, 8], [2, 9], [9, 10], - [10, 11], [2, 12], [12, 13], [13, 14], - [2, 1], [1, 15], [15, 17], [1, 16], + [2, 3], + [2, 6], + [3, 4], + [4, 5], + [6, 7], + [7, 8], + [2, 9], + [9, 10], + [10, 11], + [2, 12], + [12, 13], + [13, 14], + [2, 1], + [1, 15], + [15, 17], + [1, 16], [16, 18], ] - colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ - [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ - [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + ] - for (k1_index, k2_index), color in zip(limbSeq, colors): + for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False): keypoint1 = keypoints[k1_index - 1] keypoint2 = keypoints[k2_index - 1] @@ -111,7 +141,7 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) - for keypoint, color in zip(keypoints, colors): + for keypoint, color in zip(keypoints, colors, strict=False): if 
keypoint is None: continue @@ -125,6 +155,7 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: import matplotlib + """ Draw keypoints and connections representing hand pose on a given canvas. @@ -141,24 +172,50 @@ def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> """ if not keypoints: return canvas - + H, W, C = canvas.shape - edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ - [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + edges = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] for ie, (e1, e2) in enumerate(edges): k1 = keypoints[e1] k2 = keypoints[e2] if k1 is None or k2 is None: continue - + x1 = int(k1.x * W) y1 = int(k1.y * H) x2 = int(k2.x * W) y2 = int(k2.y * H) if x1 > eps and y1 > eps and x2 > eps and y2 > eps: - cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + cv2.line( + canvas, + (x1, y1), + (x2, y2), + matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, + thickness=2, + ) for keypoint in keypoints: x, y = keypoint.x, keypoint.y @@ -183,10 +240,10 @@ def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> Note: The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
- """ + """ if not keypoints: return canvas - + H, W, C = canvas.shape for keypoint in keypoints: x, y = keypoint.x, keypoint.y @@ -220,7 +277,7 @@ def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: ratioWristElbow = 0.33 detect_result = [] image_height, image_width = oriImg.shape[0:2] - + keypoints = body.keypoints # right hand: wrist 4, elbow 3, shoulder 2 # left hand: wrist 7, elbow 6, shoulder 5 @@ -236,24 +293,16 @@ def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) if not (has_left or has_right): return [] - + hands = [] - #left hand + # left hand if has_left: - hands.append([ - left_shoulder.x, left_shoulder.y, - left_elbow.x, left_elbow.y, - left_wrist.x, left_wrist.y, - True - ]) + hands.append([left_shoulder.x, left_shoulder.y, left_elbow.x, left_elbow.y, left_wrist.x, left_wrist.y, True]) # right hand if has_right: - hands.append([ - right_shoulder.x, right_shoulder.y, - right_elbow.x, right_elbow.y, - right_wrist.x, right_wrist.y, - False - ]) + hands.append( + [right_shoulder.x, right_shoulder.y, right_elbow.x, right_elbow.y, right_wrist.x, right_wrist.y, False] + ) for x1, y1, x2, y2, x3, y3, is_left in hands: # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox @@ -273,22 +322,26 @@ def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: x -= width / 2 y -= width / 2 # width = height # overflow the image - if x < 0: x = 0 - if y < 0: y = 0 + if x < 0: + x = 0 + if y < 0: + y = 0 width1 = width width2 = width - if x + width > image_width: width1 = image_width - x - if y + width > image_height: width2 = image_height - y + if x + width > image_width: + width1 = image_width - x + if y + width > image_height: + width2 = image_height - y width = min(width1, width2) # the max hand box value is 20 pixels if width >= 20: 
detect_result.append((int(x), int(y), int(width), is_left)) - ''' + """ return value: [[x, y, w, True if left hand else False]]. width=height since the network require squared input. - x, y is the coordinate of top left - ''' + x, y is the coordinate of top left. + """ return detect_result @@ -312,14 +365,14 @@ def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]: """ # left right eye ear 14 15 16 17 image_height, image_width = oriImg.shape[0:2] - + keypoints = body.keypoints head = keypoints[0] left_eye = keypoints[14] right_eye = keypoints[15] left_ear = keypoints[16] right_ear = keypoints[17] - + if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)): return None diff --git a/invokeai/backend/bria/controlnet_aux/util.py b/invokeai/backend/bria/controlnet_aux/util.py index 79ba7f120cc..21e51c643d7 100644 --- a/invokeai/backend/bria/controlnet_aux/util.py +++ b/invokeai/backend/bria/controlnet_aux/util.py @@ -5,7 +5,7 @@ import numpy as np import torch -annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts") def HWC3(x): @@ -30,7 +30,7 @@ def HWC3(x): def make_noise_disk(H, W, C, F): noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) - noise = noise[F: F + H, F: F + W] + noise = noise[F : F + H, F : F + W] noise -= np.min(noise) noise /= np.max(noise) if C == 1: @@ -55,6 +55,7 @@ def nms(x, t, s): z[y > t] = 255 return z + def min_max_norm(x): x -= np.min(x) x /= np.maximum(np.max(x), 1e-5) @@ -105,42 +106,155 @@ def torch_gc(): def ade_palette(): """ADE20K palette that maps each class to RGB values.""" - return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - 
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], 
[0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]] - + return [ + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 
112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] diff --git a/invokeai/backend/bria/controlnet_bria.py b/invokeai/backend/bria/controlnet_bria.py index a845afbcf2e..f48c1a15abf 100644 --- a/invokeai/backend/bria/controlnet_bria.py +++ b/invokeai/backend/bria/controlnet_bria.py @@ -14,28 +14,33 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -from typing import Literal from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import torch import torch.nn as nn - -from invokeai.backend.bria.transformer_bria import TimestepProjEmbeddings, FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND -from diffusers.models.controlnet import zero_module -from diffusers.utils.outputs import BaseOutput from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.controlnet import zero_module +from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from diffusers.models.modeling_outputs import Transformer2DModelOutput +from 
diffusers.utils.outputs import BaseOutput -from diffusers.models.attention_processor import AttentionProcessor +from invokeai.backend.bria.transformer_bria import ( + EmbedND, + FluxSingleTransformerBlock, + FluxTransformerBlock, + TimestepProjEmbeddings, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"] + + class BriaControlModes(Enum): depth = 0 canny = 1 @@ -66,7 +71,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], + axes_dims_rope: Optional[List[int]] = None, num_mode: int = None, rope_theta: int = 10000, time_theta: int = 10000, @@ -76,6 +81,7 @@ def __init__( self.inner_dim = num_attention_heads * attention_head_dim # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) # text_time_guidance_cls = ( @@ -84,10 +90,8 @@ def __init__( # self.time_text_embed = text_time_guidance_cls( # embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim # ) - self.time_embed = TimestepProjEmbeddings( - embedding_dim=self.inner_dim, time_theta=time_theta - ) - + self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) @@ -290,7 +294,7 @@ def forward( # Convert controlnet_cond to the same dtype as the model weights controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype) - + # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) @@ -316,28 +320,32 @@ def forward( "Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] - + if 
self.union: # union mode if controlnet_mode is None: raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") - + # Validate controlnet_mode values are within the valid range if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode): - raise ValueError(f"`controlnet_mode` values must be in range [0, {self.num_mode-1}], but got values outside this range") - + raise ValueError( + f"`controlnet_mode` values must be in range [0, {self.num_mode - 1}], but got values outside this range" + ) + # union mode emb controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) - if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch - controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]) + if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch + controlnet_mode_emb = controlnet_mode_emb.expand( + encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2] + ) encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) - - txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0) + + txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0) ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for _, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -371,7 +379,7 @@ def custom_forward(*inputs): hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) single_block_samples = () - for index_block, block in enumerate(self.single_transformer_blocks): + for _, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ 
-402,12 +410,14 @@ def custom_forward(*inputs): # controlnet block controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) controlnet_single_block_samples = () - for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + for single_block_sample, controlnet_block in zip( + single_block_samples, self.controlnet_single_blocks, strict=False + ): single_block_sample = controlnet_block(single_block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) @@ -468,7 +478,9 @@ def forward( if len(self.nets) == 1 and self.nets[0].union: controlnet = self.nets[0] - for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + for i, (image, mode, scale) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False) + ): block_samples, single_block_samples = controlnet( hidden_states=hidden_states, controlnet_cond=image, @@ -491,13 +503,15 @@ def forward( else: control_block_samples = [ control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) + for control_block_sample, block_sample in zip( + control_block_samples, block_samples, strict=False + ) ] control_single_block_samples = [ control_single_block_sample + block_sample for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples + control_single_block_samples, single_block_samples, strict=False ) ] @@ -505,8 +519,8 @@ def forward( # load all ControlNets into memories else: for i, (image, mode, scale, controlnet) in enumerate( - zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) 
- ): + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False) + ): block_samples, single_block_samples = controlnet( hidden_states=hidden_states, controlnet_cond=image, @@ -530,14 +544,16 @@ def forward( if block_samples is not None and control_block_samples is not None: control_block_samples = [ control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) + for control_block_sample, block_sample in zip( + control_block_samples, block_samples, strict=False + ) ] if single_block_samples is not None and control_single_block_samples is not None: control_single_block_samples = [ control_single_block_sample + block_sample for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples + control_single_block_samples, single_block_samples, strict=False ) ] - return control_block_samples, control_single_block_samples \ No newline at end of file + return control_block_samples, control_single_block_samples diff --git a/invokeai/backend/bria/controlnet_utils.py b/invokeai/backend/bria/controlnet_utils.py index 91dc270c846..4b0d38fc95d 100644 --- a/invokeai/backend/bria/controlnet_utils.py +++ b/invokeai/backend/bria/controlnet_utils.py @@ -1,11 +1,9 @@ from typing import List, Tuple -from PIL import Image -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL - -from diffusers.image_processor import VaeImageProcessor import torch - +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from PIL import Image @torch.no_grad() @@ -17,7 +15,6 @@ def prepare_control_images( height: int, device: torch.device, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - tensored_control_images = [] tensored_control_modes = [] for idx, control_image_ in enumerate(control_images): @@ -42,11 +39,13 @@ def prepare_control_images( width_control_image, ) 
tensored_control_images.append(tensored_control_image) - tensored_control_modes.append(torch.tensor(control_modes[idx]).expand( - tensored_control_image.shape[0]).to(device, dtype=torch.long)) + tensored_control_modes.append( + torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long) + ) return tensored_control_images, tensored_control_modes + def _prepare_image( image: Image.Image, width: int, @@ -60,10 +59,10 @@ def _prepare_image( image = image.to(device=device, dtype=dtype) return image + def _pack_latents(latents, height, width): latents = latents.view(1, 4, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(1, (height // 2) * (width // 2), 16) return latents - diff --git a/invokeai/backend/bria/pipeline_bria.py b/invokeai/backend/bria/pipeline_bria.py index 7a195a6ae75..2d1f8468bd2 100644 --- a/invokeai/backend/bria/pipeline_bria.py +++ b/invokeai/backend/bria/pipeline_bria.py @@ -1,18 +1,15 @@ -from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps, calculate_shift from typing import Any, Callable, Dict, List, Optional, Union +import diffusers +import numpy as np import torch - -from transformers import ( - T5EncoderModel, - T5TokenizerFast, -) - +from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler from diffusers.image_processor import VaeImageProcessor -from diffusers import AutoencoderKL , DDIMScheduler, EulerAncestralDiscreteScheduler -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.loaders import FluxLoraLoaderMixin +from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import 
FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers from diffusers.utils import ( USE_PEFT_BACKEND, logging, @@ -20,16 +17,14 @@ scale_lora_layers, unscale_lora_layers, ) -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput -from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel -from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none from diffusers.utils.torch_utils import randn_tensor -import diffusers -import numpy as np - -XLA_AVAILABLE = False +from transformers import ( + T5EncoderModel, + T5TokenizerFast, +) +from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none +from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -57,6 +52,8 @@ - We use zero padding for prompts - No guidance embedding since this is not a distilled version """ + + class BriaPipeline(FluxPipeline): r""" Args: @@ -78,10 +75,10 @@ class BriaPipeline(FluxPipeline): def __init__( self, transformer: BriaTransformer2DModel, - scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers], + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], vae: AutoencoderKL, text_encoder: T5EncoderModel, - tokenizer: T5TokenizerFast + tokenizer: T5TokenizerFast, ): self.register_modules( vae=vae, @@ -96,15 +93,14 @@ def __init__( 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k + self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k # T5 is senstive to precision so we use the precision used for precompute and cast as needed - + if 
self.vae.config.shift_factor is None: - self.vae.config.shift_factor=0 + self.vae.config.shift_factor = 0 self.vae.to(dtype=torch.float32) - def encode_prompt( self, prompt: Union[str, List[str]], @@ -169,7 +165,9 @@ def encode_prompt( if do_classifier_free_guidance and negative_prompt_embeds is None: if not is_ng_none(negative_prompt): - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( @@ -182,7 +180,7 @@ def encode_prompt( f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - + negative_prompt_embeds = get_t5_prompt_embeds( self.tokenizer, self.text_encoder, @@ -192,7 +190,7 @@ def encode_prompt( device=device, ).to(dtype=self.transformer.dtype) else: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_prompt_embeds = torch.zeros_like(prompt_embeds) if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -209,7 +207,6 @@ def encode_prompt( def guidance_scale(self): return self._guidance_scale - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@@ -249,10 +246,10 @@ def __call__( return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, max_sequence_length: int = 128, - clip_value:Union[None,float] = None, - normalize:bool = False + clip_value: Union[None, float] = None, + normalize: bool = False, ): r""" Function invoked when calling the pipeline for generation. @@ -331,6 +328,9 @@ def __call__( width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct + callback_on_step_end_tensor_inputs = ( + ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs + ) self.check_inputs( prompt=prompt, height=height, @@ -353,16 +353,10 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - - ( - prompt_embeds, - negative_prompt_embeds, - text_ids - ) = self.encode_prompt( + + lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + + (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, @@ -376,11 +370,9 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - - # 5. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4 + num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -392,11 +384,14 @@ def __call__( latents, ) - if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']: + if ( + isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) + and self.scheduler.config["use_dynamic_shifting"] + ): sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] # Shift by height - Why just height? + image_seq_len = latents.shape[1] # Shift by height - Why just height? print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}") - + mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -415,19 +410,26 @@ def __call__( else: # 4. 
Prepare timesteps # Sample from training sigmas - if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler): - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None) + if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, None, None + ) else: - sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps) - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas) - + sigmas = get_original_sigmas( + num_train_timesteps=self.scheduler.config.num_train_timesteps, + num_inference_steps=num_inference_steps, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # Supprot different diffusers versions - if diffusers.__version__>='0.32.0': - latent_image_ids=latent_image_ids[0] - text_ids=text_ids[0] + if diffusers.__version__ >= "0.32.0": + latent_image_ids = latent_image_ids[0] + text_ids = text_ids[0] # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -437,7 +439,7 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - if type(self.scheduler)!=FlowMatchEulerDiscreteScheduler: + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -453,7 +455,7 @@ def __call__( txt_ids=text_ids, img_ids=latent_image_ids, )[0] - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -461,16 +463,16 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if normalize: - noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred + noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred if clip_value: - assert clip_value>0 - noise_pred = noise_pred.clip(-clip_value,clip_value) - + assert clip_value > 0 + noise_pred = noise_pred.clip(-clip_value, clip_value) + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -485,14 +487,11 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - if output_type == "latent": image = latents @@ -509,7 +508,7 @@ def __call__( return (image,) return FluxPipelineOutput(images=image) - + def check_inputs( self, prompt, @@ -548,7 +547,6 @@ def check_inputs( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @@ -560,13 +558,11 @@ def to(self, *args, **kwargs): for block in self.text_encoder.encoder.block: block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) - if self.vae.config.shift_factor == 0 and self.vae.dtype!=torch.float32: + if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32: self.vae.to(dtype=torch.float32) - return self - def prepare_latents( self, batch_size, @@ -581,7 +577,7 @@ def prepare_latents( # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. 
height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor ) + width = 2 * (int(width) // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height, width) @@ -623,7 +619,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents - + @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) @@ -631,17 +627,10 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) - - - - - - - diff --git a/invokeai/backend/bria/pipeline_bria_controlnet.py b/invokeai/backend/bria/pipeline_bria_controlnet.py index fb80fce3bff..5be4dfc06af 100644 --- a/invokeai/backend/bria/pipeline_bria_controlnet.py +++ b/invokeai/backend/bria/pipeline_bria_controlnet.py @@ -13,31 +13,27 @@ # limitations under the License. 
from typing import Any, Callable, Dict, List, Optional, Union -from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers + +import diffusers +import numpy as np import torch +from diffusers import AutoencoderKL # Waiting for diffusers udpdate +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from diffusers.utils import USE_PEFT_BACKEND, logging +from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import randn_tensor from transformers import ( T5EncoderModel, T5TokenizerFast, ) -from diffusers.image_processor import PipelineImageInput -from diffusers import AutoencoderKL # Waiting for diffusers udpdate -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import logging, USE_PEFT_BACKEND -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput -from invokeai.backend.bria.controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel -from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift +from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none +from invokeai.backend.bria.controlnet_bria import BriaControlNetModel from invokeai.backend.bria.pipeline_bria import BriaPipeline from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel -from invokeai.backend.bria.bria_utils import get_original_sigmas -import numpy as np -import diffusers -from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none -from diffusers.utils.torch_utils import randn_tensor - -XLA_AVAILABLE = False - logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name @@ -148,10 +144,12 @@ def prepare_control(self, control_image, width, height, batch_size, num_images_p return control_image, control_mode - def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode): + def prepare_multi_control( + self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode + ): num_channels_latents = self.transformer.config.in_channels // 4 control_images = [] - for i, control_image_ in enumerate(control_image): + for _, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -198,13 +196,13 @@ def prepare_multi_control(self, control_image, width, height, batch_size, num_im control_mode = control_modes return control_image, control_mode - + def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end): controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) + for s, e in zip(control_guidance_start, control_guidance_end, strict=False) ] controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps) return controlnet_keep @@ -249,7 +247,7 @@ def __call__( return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, max_sequence_length: int = 128, ): r""" @@ -329,6 +327,9 @@ def __call__( ) # 1. Check inputs. 
Raise error if not correct + callback_on_step_end_tensor_inputs = ( + ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs + ) self.check_inputs( prompt, height, @@ -346,24 +347,25 @@ def __call__( device = self._execution_device - # 4. Prepare timesteps - if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']: + if ( + isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) + and self.scheduler.config["use_dynamic_shifting"] + ): sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - - + # Determine image sequence length if control_image is not None: - if type(control_image) == list: - image_seq_len = control_image[0].shape[1] + if isinstance(control_image, list): + image_seq_len = control_image[0].shape[1] else: - image_seq_len = control_image.shape[1] + image_seq_len = control_image.shape[1] else: # Use latents sequence length when no control image is provided image_seq_len = latents.shape[1] - + print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}") - + mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -387,9 +389,9 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas ) - + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) + self._num_timesteps = len(timesteps) # 6. 
Create tensor stating which controlnets to keep if control_image is not None: @@ -399,13 +401,13 @@ def __call__( control_guidance_end=control_guidance_end, ) - if diffusers.__version__>='0.32.0': - latent_image_ids=latent_image_ids[0] - text_ids=text_ids[0] - + if diffusers.__version__ >= "0.32.0": + latent_image_ids = latent_image_ids[0] + text_ids = text_ids[0] + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - + # EYAL - added the CFG loop # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -428,13 +430,15 @@ def __call__( if isinstance(controlnet_conditioning_scale, list): cond_scale = controlnet_conditioning_scale else: - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + cond_scale = [ + c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False) + ] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, @@ -492,9 +496,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - if output_type == "latent": image = latents @@ -514,17 +515,17 @@ def __call__( def encode_prompt( - prompt: Union[str, List[str]], - tokenizer: T5TokenizerFast, - text_encoder: T5EncoderModel, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - negative_prompt: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 128, - lora_scale: Optional[float] = None, - ): + 
prompt: Union[str, List[str]], + tokenizer: T5TokenizerFast, + text_encoder: T5EncoderModel, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, +): r""" Args: @@ -561,7 +562,7 @@ def encode_prompt( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - + dtype = text_encoder.dtype if text_encoder is not None else torch.float32 if prompt_embeds is None: prompt_embeds = get_t5_prompt_embeds( @@ -588,7 +589,7 @@ def encode_prompt( f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - + negative_prompt_embeds = get_t5_prompt_embeds( tokenizer, text_encoder, @@ -598,7 +599,7 @@ def encode_prompt( device=device, ).to(dtype=dtype) else: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_prompt_embeds = torch.zeros_like(prompt_embeds) if text_encoder is not None: if USE_PEFT_BACKEND: @@ -625,7 +626,7 @@ def prepare_latents( # latent height and width to be divisible by 2. 
vae_scale_factor = 16 height = 2 * (int(height) // vae_scale_factor) - width = 2 * (int(width) // vae_scale_factor ) + width = 2 * (int(width) // vae_scale_factor) shape = (batch_size, num_channels_latents, height, width) @@ -653,7 +654,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels @@ -662,11 +663,9 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) - - def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - return latents \ No newline at end of file + return latents diff --git a/invokeai/backend/bria/transformer_bria.py b/invokeai/backend/bria/transformer_bria.py index d19d11dc496..c454ddd0bd9 100644 --- a/invokeai/backend/bria/transformer_bria.py +++ b/invokeai/backend/bria/transformer_bria.py @@ -3,7 +3,6 @@ import numpy as np import torch import torch.nn as nn -from .bria_utils import FluxPosEmbed as EmbedND from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding @@ -13,6 +12,8 @@ from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, 
scale_lora_layers, unscale_lora_layers +from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -94,7 +95,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = None, guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], + axes_dims_rope: Optional[List[int]] = None, rope_theta=10000, time_theta=10000, ): @@ -102,6 +103,7 @@ def __init__( self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) diff --git a/invokeai/backend/model_manager/load/model_loaders/bria.py b/invokeai/backend/model_manager/load/model_loaders/bria.py index c5d6ec6f433..09771551139 100644 --- a/invokeai/backend/model_manager/load/model_loaders/bria.py +++ b/invokeai/backend/model_manager/load/model_loaders/bria.py @@ -4,9 +4,9 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, CheckpointConfigBase, - DiffusersConfigBase, - ControlNetDiffusersConfig, ControlNetCheckpointConfig, + ControlNetDiffusersConfig, + DiffusersConfigBase, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader @@ -56,6 +56,7 @@ def _load_model( return result + @ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers) class BriaDiffusersModel(GenericDiffusersLoader): """Class to load Bria main models.""" From 711a579945e670b3e6f948a7b4a7eb439cc8e7be Mon Sep 17 00:00:00 2001 From: Ilan Tchenak Date: Thu, 24 Jul 2025 19:19:39 +0300 Subject: [PATCH 13/14] fixed schema --- 
.../frontend/web/src/services/api/schema.ts | 437 +++++++++++++++++- 1 file changed, 429 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 2117ca782e5..a4d5b7f2e9f 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -3015,6 +3015,421 @@ export type components = { */ type: "bounding_box_output"; }; + /** BriaControlNetField */ + BriaControlNetField: { + /** @description The control image */ + image: components["schemas"]["ImageField"]; + /** @description The ControlNet model to use */ + model: components["schemas"]["ModelIdentifierField"]; + /** + * Mode + * @description The mode of the ControlNet + * @enum {string} + */ + mode: "depth" | "canny" | "colorgrid" | "recolor" | "tile" | "pose"; + /** + * Conditioning Scale + * @description The weight given to the ControlNet + */ + conditioning_scale: number; + }; + /** + * ControlNet - Bria + * @description Collect Bria ControlNet info to pass to denoiser node. + */ + BriaControlNetInvocation: { + /** + * @description The board to save the image to + * @default null + */ + board?: components["schemas"]["BoardField"] | null; + /** + * @description Optional metadata to be saved with the image + * @default null + */ + metadata?: components["schemas"]["MetadataField"] | null; + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * @description The control image + * @default null + */ + control_image?: components["schemas"]["ImageField"] | null; + /** + * @description ControlNet model to load + * @default null + */ + control_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Control Mode + * @description The mode of the ControlNet + * @default depth + * @enum {string} + */ + control_mode?: "depth" | "canny" | "colorgrid" | "recolor" | "tile" | "pose"; + /** + * Control Weight + * @description The weight given to the ControlNet + * @default 1 + */ + control_weight?: number; + /** + * type + * @default bria_controlnet + * @constant + */ + type: "bria_controlnet"; + }; + /** + * BriaControlNetOutput + * @description Bria ControlNet info + */ + BriaControlNetOutput: { + /** @description ControlNet(s) to apply */ + control: components["schemas"]["BriaControlNetField"]; + /** @description The preprocessed control image */ + preprocessed_images: components["schemas"]["ImageField"]; + /** + * type + * @default bria_controlnet_output + * @constant + */ + type: "bria_controlnet_output"; + }; + /** Decoder - Bria */ + BriaDecoderInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * @description VAE + * @default null + */ + vae?: components["schemas"]["VAEField"] | null; + /** + * @description Latents tensor + * @default null + */ + latents?: components["schemas"]["LatentsField"] | null; + /** + * type + * @default bria_decoder + * @constant + */ + type: "bria_decoder"; + }; + /** Denoise - Bria */ + BriaDenoiseInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Number of Steps + * @description The number of steps to use for the denoiser + * @default 30 + */ + num_steps?: number; + /** + * Guidance Scale + * @description The guidance scale to use for the denoiser + * @default 5 + */ + guidance_scale?: number; + /** + * Transformer + * @description Bria model (Transformer) to load + * @default null + */ + transformer?: components["schemas"]["TransformerField"] | null; + /** + * T5Encoder + * @description T5 tokenizer and text encoder + * @default null + */ + t5_encoder?: components["schemas"]["T5EncoderField"] | null; + /** + * VAE + * @description VAE + * @default null + */ + vae?: components["schemas"]["VAEField"] | null; + /** + * Latents + * @description Latents to denoise + * @default null + */ + latents?: components["schemas"]["LatentsField"] | null; + /** + * Latent Image IDs + * @description Latent Image IDs to denoise + * @default null + */ + latent_image_ids?: components["schemas"]["LatentsField"] | null; + /** + * Positive Prompt Embeds + * @description Positive Prompt Embeds + * @default 
null + */ + pos_embeds?: components["schemas"]["LatentsField"] | null; + /** + * Negative Prompt Embeds + * @description Negative Prompt Embeds + * @default null + */ + neg_embeds?: components["schemas"]["LatentsField"] | null; + /** + * Text IDs + * @description Text IDs + * @default null + */ + text_ids?: components["schemas"]["LatentsField"] | null; + /** + * ControlNet + * @description ControlNet + * @default null + */ + control?: components["schemas"]["BriaControlNetField"] | components["schemas"]["BriaControlNetField"][] | null; + /** + * type + * @default bria_denoise + * @constant + */ + type: "bria_denoise"; + }; + /** BriaDenoiseInvocationOutput */ + BriaDenoiseInvocationOutput: { + /** @description Latents tensor */ + latents: components["schemas"]["LatentsField"]; + /** + * type + * @default bria_denoise_output + * @constant + */ + type: "bria_denoise_output"; + }; + /** Latent Sampler - Bria */ + BriaLatentSamplerInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Seed + * @description The seed to use for the latent sampler + * @default 42 + */ + seed?: number; + /** + * Transformer + * @description Bria model (Transformer) to load + * @default null + */ + transformer?: components["schemas"]["TransformerField"] | null; + /** + * type + * @default bria_latent_sampler + * @constant + */ + type: "bria_latent_sampler"; + }; + /** + * BriaLatentSamplerInvocationOutput + * @description Base class for nodes that output a CogView text conditioning tensor. 
+ */ + BriaLatentSamplerInvocationOutput: { + /** @description Conditioning tensor */ + latents: components["schemas"]["LatentsField"]; + /** @description Conditioning tensor */ + latent_image_ids: components["schemas"]["LatentsField"]; + /** + * type + * @default bria_latent_sampler_output + * @constant + */ + type: "bria_latent_sampler_output"; + }; + /** + * Main Model - Bria + * @description Loads a bria base model, outputting its submodels. + */ + BriaModelLoaderInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description Bria model (Transformer) to load */ + model: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default bria_model_loader + * @constant + */ + type: "bria_model_loader"; + }; + /** + * BriaModelLoaderOutput + * @description Bria base model loader output + */ + BriaModelLoaderOutput: { + /** + * Transformer + * @description Transformer + */ + transformer: components["schemas"]["TransformerField"]; + /** + * T5 Encoder + * @description T5 tokenizer and text encoder + */ + t5_encoder: components["schemas"]["T5EncoderField"]; + /** + * VAE + * @description VAE + */ + vae: components["schemas"]["VAEField"]; + /** + * type + * @default bria_model_loader_output + * @constant + */ + type: "bria_model_loader_output"; + }; + /** Prompt - Bria */ + BriaTextEncoderInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Prompt + * @description The prompt to encode + * @default null + */ + prompt?: string | null; + /** + * Negative Prompt + * @description The negative prompt to encode + * @default Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate + */ + negative_prompt?: string | null; + /** + * Max Length + * @description The maximum length of the prompt + * @default 128 + */ + max_length?: number; + /** + * T5Encoder + * @description T5 tokenizer and text encoder + * @default null + */ + t5_encoder?: components["schemas"]["T5EncoderField"] | null; + /** + * type + * @default bria_text_encoder + * @constant + */ + type: "bria_text_encoder"; + }; + /** + * BriaTextEncoderInvocationOutput + * @description Base class for nodes that output a CogView text conditioning tensor. 
+ */ + BriaTextEncoderInvocationOutput: { + /** @description Conditioning tensor */ + pos_embeds: components["schemas"]["LatentsField"]; + /** @description Conditioning tensor */ + neg_embeds: components["schemas"]["LatentsField"]; + /** @description Conditioning tensor */ + text_ids: components["schemas"]["LatentsField"]; + /** + * type + * @default bria_text_encoder_output + * @constant + */ + type: "bria_text_encoder_output"; + }; /** * BulkDownloadCompleteEvent * @description Event model for bulk_download_complete @@ -8943,7 +9358,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | 
components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | 
components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | 
components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | 
components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | 
components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; + [key: string]: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | 
components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["BriaControlNetInvocation"] | components["schemas"]["BriaDecoderInvocation"] | components["schemas"]["BriaDenoiseInvocation"] | components["schemas"]["BriaLatentSamplerInvocation"] | components["schemas"]["BriaModelLoaderInvocation"] | components["schemas"]["BriaTextEncoderInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | 
components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | 
components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | 
components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | 
components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | 
components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; }; /** * Edges @@ -8980,7 +9395,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CogView4ConditioningOutput"] | components["schemas"]["CogView4ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] 
| components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatGeneratorOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningCollectionOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxControlLoRALoaderOutput"] | components["schemas"]["FluxControlNetOutput"] | components["schemas"]["FluxFillOutput"] | components["schemas"]["FluxKontextOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["FluxReduxOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageGeneratorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImagePanelCoordinateOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerGeneratorOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsMetaOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MDControlListOutput"] | components["schemas"]["MDIPAdapterListOutput"] | components["schemas"]["MDT2IAdapterListOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["MetadataToLorasCollectionOutput"] | components["schemas"]["MetadataToModelOutput"] | components["schemas"]["MetadataToSDXLModelOutput"] | 
components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SD3ConditioningOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["Sd3ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringGeneratorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["BriaControlNetOutput"] | components["schemas"]["BriaDenoiseInvocationOutput"] | components["schemas"]["BriaLatentSamplerInvocationOutput"] | components["schemas"]["BriaModelLoaderOutput"] | components["schemas"]["BriaTextEncoderInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CogView4ConditioningOutput"] | components["schemas"]["CogView4ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | 
components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatGeneratorOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningCollectionOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxControlLoRALoaderOutput"] | components["schemas"]["FluxControlNetOutput"] | components["schemas"]["FluxFillOutput"] | components["schemas"]["FluxKontextOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["FluxReduxOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageGeneratorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImagePanelCoordinateOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerGeneratorOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsMetaOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MDControlListOutput"] | components["schemas"]["MDIPAdapterListOutput"] | components["schemas"]["MDT2IAdapterListOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["MetadataToLorasCollectionOutput"] | components["schemas"]["MetadataToModelOutput"] | components["schemas"]["MetadataToSDXLModelOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | 
components["schemas"]["PairTileImageOutput"] | components["schemas"]["SD3ConditioningOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["Sd3ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringGeneratorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * Errors @@ -11756,7 +12171,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | 
components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | 
components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | 
components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | 
components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | 
components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | 
components["schemas"]["VAELoaderInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["BriaControlNetInvocation"] | components["schemas"]["BriaDecoderInvocation"] | components["schemas"]["BriaDenoiseInvocation"] | components["schemas"]["BriaLatentSamplerInvocation"] | components["schemas"]["BriaModelLoaderInvocation"] | components["schemas"]["BriaTextEncoderInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | 
components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | 
components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | 
components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] 
| components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | 
components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -11766,7 +12181,7 @@ export type components = { * Result * @description The result of the invocation */ - result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | 
components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CogView4ConditioningOutput"] | components["schemas"]["CogView4ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatGeneratorOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningCollectionOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxControlLoRALoaderOutput"] | components["schemas"]["FluxControlNetOutput"] | components["schemas"]["FluxFillOutput"] | components["schemas"]["FluxKontextOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["FluxReduxOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageGeneratorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImagePanelCoordinateOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerGeneratorOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsMetaOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MDControlListOutput"] | components["schemas"]["MDIPAdapterListOutput"] | 
components["schemas"]["MDT2IAdapterListOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["MetadataToLorasCollectionOutput"] | components["schemas"]["MetadataToModelOutput"] | components["schemas"]["MetadataToSDXLModelOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SD3ConditioningOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["Sd3ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringGeneratorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["BriaControlNetOutput"] | components["schemas"]["BriaDenoiseInvocationOutput"] | components["schemas"]["BriaLatentSamplerInvocationOutput"] | components["schemas"]["BriaModelLoaderOutput"] | components["schemas"]["BriaTextEncoderInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CogView4ConditioningOutput"] | components["schemas"]["CogView4ModelLoaderOutput"] | 
components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatGeneratorOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningCollectionOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxControlLoRALoaderOutput"] | components["schemas"]["FluxControlNetOutput"] | components["schemas"]["FluxFillOutput"] | components["schemas"]["FluxKontextOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["FluxReduxOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageGeneratorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImagePanelCoordinateOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerGeneratorOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsMetaOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MDControlListOutput"] | components["schemas"]["MDIPAdapterListOutput"] | components["schemas"]["MDT2IAdapterListOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | 
components["schemas"]["MetadataToLorasCollectionOutput"] | components["schemas"]["MetadataToModelOutput"] | components["schemas"]["MetadataToSDXLModelOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SD3ConditioningOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["Sd3ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringGeneratorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * InvocationErrorEvent @@ -11814,7 +12229,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | 
components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | 
components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | 
components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | 
components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | 
components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | 
components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["BriaControlNetInvocation"] | components["schemas"]["BriaDecoderInvocation"] | components["schemas"]["BriaDenoiseInvocation"] | components["schemas"]["BriaLatentSamplerInvocation"] | components["schemas"]["BriaModelLoaderInvocation"] | components["schemas"]["BriaTextEncoderInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | 
components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | 
components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | 
components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | 
components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | 
components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -11857,6 +12272,12 @@ export type components = { boolean: components["schemas"]["BooleanOutput"]; 
boolean_collection: components["schemas"]["BooleanCollectionOutput"]; bounding_box: components["schemas"]["BoundingBoxOutput"]; + bria_controlnet: components["schemas"]["BriaControlNetOutput"]; + bria_decoder: components["schemas"]["ImageOutput"]; + bria_denoise: components["schemas"]["BriaDenoiseInvocationOutput"]; + bria_latent_sampler: components["schemas"]["BriaLatentSamplerInvocationOutput"]; + bria_model_loader: components["schemas"]["BriaModelLoaderOutput"]; + bria_text_encoder: components["schemas"]["BriaTextEncoderInvocationOutput"]; calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; @@ -12106,7 +12527,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | 
components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | 
components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | 
components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | 
components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | 
components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | 
components["schemas"]["VAELoaderInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["BriaControlNetInvocation"] | components["schemas"]["BriaDecoderInvocation"] | components["schemas"]["BriaDenoiseInvocation"] | components["schemas"]["BriaLatentSamplerInvocation"] | components["schemas"]["BriaModelLoaderInvocation"] | components["schemas"]["BriaTextEncoderInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | 
components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | 
components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | 
components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] 
| components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | 
components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -12175,7 +12596,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | 
components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | 
components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | 
components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | 
components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | 
components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | 
components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["ApplyMaskTensorToImageInvocation"] | components["schemas"]["ApplyMaskToImageInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["BriaControlNetInvocation"] | components["schemas"]["BriaDecoderInvocation"] | components["schemas"]["BriaDenoiseInvocation"] | components["schemas"]["BriaLatentSamplerInvocation"] | components["schemas"]["BriaModelLoaderInvocation"] | components["schemas"]["BriaTextEncoderInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyEdgeDetectionInvocation"] | 
components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CanvasV2MaskAndCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CogView4DenoiseInvocation"] | components["schemas"]["CogView4ImageToLatentsInvocation"] | components["schemas"]["CogView4LatentsToImageInvocation"] | components["schemas"]["CogView4ModelLoaderInvocation"] | components["schemas"]["CogView4TextEncoderInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropImageToBoundingBoxInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeDetectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DenoiseLatentsMetaInvocation"] | components["schemas"]["DepthAnythingDepthEstimationInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ExpandMaskWithFadeInvocation"] | components["schemas"]["FLUXLoRACollectionLoader"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatBatchInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatGenerator"] | 
components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxControlLoRALoaderInvocation"] | components["schemas"]["FluxControlNetInvocation"] | components["schemas"]["FluxDenoiseInvocation"] | components["schemas"]["FluxDenoiseLatentsMetaInvocation"] | components["schemas"]["FluxFillInvocation"] | components["schemas"]["FluxIPAdapterInvocation"] | components["schemas"]["FluxKontextInvocation"] | components["schemas"]["FluxLoRALoaderInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxReduxInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxVaeDecodeInvocation"] | components["schemas"]["FluxVaeEncodeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GetMaskBoundingBoxInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HEDEdgeDetectionInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBatchInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageGenerator"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | 
components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageNoiseInvocation"] | components["schemas"]["ImagePanelLayoutInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerBatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerGenerator"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["InvokeAdjustImageHuePlusInvocation"] | components["schemas"]["InvokeEquivalentAchromaticLightnessInvocation"] | components["schemas"]["InvokeImageBlendInvocation"] | components["schemas"]["InvokeImageCompositorInvocation"] | components["schemas"]["InvokeImageDilateOrErodeInvocation"] | components["schemas"]["InvokeImageEnhanceInvocation"] | components["schemas"]["InvokeImageValueThresholdsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LineartAnimeEdgeDetectionInvocation"] | components["schemas"]["LineartEdgeDetectionInvocation"] | components["schemas"]["LlavaOnevisionVllmInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MLSDDetectionInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] 
| components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediaPipeFaceDetectionInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataFieldExtractorInvocation"] | components["schemas"]["MetadataFromImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataItemLinkedInvocation"] | components["schemas"]["MetadataToBoolCollectionInvocation"] | components["schemas"]["MetadataToBoolInvocation"] | components["schemas"]["MetadataToControlnetsInvocation"] | components["schemas"]["MetadataToFloatCollectionInvocation"] | components["schemas"]["MetadataToFloatInvocation"] | components["schemas"]["MetadataToIPAdaptersInvocation"] | components["schemas"]["MetadataToIntegerCollectionInvocation"] | components["schemas"]["MetadataToIntegerInvocation"] | components["schemas"]["MetadataToLorasCollectionInvocation"] | components["schemas"]["MetadataToLorasInvocation"] | components["schemas"]["MetadataToModelInvocation"] | components["schemas"]["MetadataToSDXLLorasInvocation"] | components["schemas"]["MetadataToSDXLModelInvocation"] | components["schemas"]["MetadataToSchedulerInvocation"] | components["schemas"]["MetadataToStringCollectionInvocation"] | components["schemas"]["MetadataToStringInvocation"] | components["schemas"]["MetadataToT2IAdaptersInvocation"] | components["schemas"]["MetadataToVAEInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalMapInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PasteImageIntoBoundingBoxInvocation"] | 
components["schemas"]["PiDiNetEdgeDetectionInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SD3DenoiseInvocation"] | components["schemas"]["SD3ImageToLatentsInvocation"] | components["schemas"]["SD3LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["Sd3ModelLoaderInvocation"] | components["schemas"]["Sd3TextEncoderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StringBatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringGenerator"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | 
components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -21337,7 +21758,7 @@ export type components = { * used, and the type will be ignored. They are included here for backwards compatibility. * @enum {string} */ - UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "BriaControlNetModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | 
"DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; + UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "BriaControlNetModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "Imagen3ModelField" | "Imagen4ModelField" | "ChatGPT4oModelField" | "FluxKontextModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | 
"DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** UNetField */ UNetField: { /** @description Info to load unet submodel */ From 3a14791da3181cf6cd7391b613a238cd0fda0039 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:58:27 -0400 Subject: [PATCH 14/14] bria-ui-updates-wip --- invokeai/frontend/web/public/locales/en.json | 10 +- .../ControlLayerControlAdapter.tsx | 5 +- .../ControlLayerControlAdapterControlMode.tsx | 36 ++-- .../controlLayers/store/canvasSlice.ts | 20 +- .../controlLayers/store/paramsSlice.ts | 1 + .../src/features/controlLayers/store/types.ts | 10 +- .../graph/generation/addControlAdapters.ts | 39 ++-- .../util/graph/generation/buildBriaGraph.ts | 184 ++++++++++++++++++ .../parameters/util/optimalDimension.ts | 6 +- .../web/src/features/queue/store/readiness.ts | 35 ++++ 10 files changed, 316 insertions(+), 30 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/buildBriaGraph.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6f7b756e0e7..2551dbd3406 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1234,6 +1234,8 @@ "modelIncompatibleBboxHeight": "Bbox height is {{height}} but {{model}} requires multiple of {{multiple}}", "modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}", "modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}", + "briaRequiresExactDimensions": "Bria requires exact {{size}}x{{size}} dimensions", + "briaRequiresExactScaledDimensions": "Bria requires exact {{size}}x{{size}} scaled 
dimensions", "fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time", "fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext", "canvasIsFiltering": "Canvas is busy (filtering)", @@ -2185,7 +2187,13 @@ "balanced": "Balanced (recommended)", "prompt": "Prompt", "control": "Control", - "megaControl": "Mega Control" + "megaControl": "Mega Control", + "depth": "Depth", + "canny": "Canny", + "colorgrid": "Color Grid", + "recolor": "Recolor", + "tile": "Tile", + "pose": "Pose" }, "ipAdapterMethod": { "ipAdapterMethod": "Mode", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx index 953638fad4c..1fd3b0a252c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx @@ -20,7 +20,7 @@ import { import { getFilterForModel } from 'features/controlLayers/store/filters'; import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors'; -import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types'; +import type { CanvasEntityIdentifier, ControlMode } from 'features/controlLayers/store/types'; import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -56,7 +56,7 @@ export const ControlLayerControlAdapter = memo(() => { ); const onChangeControlMode = useCallback( - (controlMode: ControlModeV2) => { + (controlMode: ControlMode) => { dispatch(controlLayerControlModeChanged({ entityIdentifier, controlMode })); }, [dispatch, 
entityIdentifier] @@ -169,6 +169,7 @@ export const ControlLayerControlAdapter = memo(() => { )} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode.tsx index c80a6ef037d..1193ed43125 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode.tsx @@ -1,32 +1,46 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library'; import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import type { ControlModeV2 } from 'features/controlLayers/store/types'; -import { isControlModeV2 } from 'features/controlLayers/store/types'; +import type { ControlMode } from 'features/controlLayers/store/types'; +import { isControlMode } from 'features/controlLayers/store/types'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import type { ControlNetModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; type Props = { - controlMode: ControlModeV2; - onChange: (controlMode: ControlModeV2) => void; + controlMode: ControlMode; + onChange: (controlMode: ControlMode) => void; + model: ControlNetModelConfig | null; }; -export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChange }: Props) => { +export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChange, model }: Props) => { const { t } = useTranslation(); - const CONTROL_MODE_DATA = useMemo( - () => [ + + const CONTROL_MODE_DATA = useMemo(() => { + // Show BRIA-specific control modes if a BRIA model is selected + if (model?.base === 'bria') 
{ + return [ + { label: t('controlLayers.controlMode.depth'), value: 'depth' }, + { label: t('controlLayers.controlMode.canny'), value: 'canny' }, + { label: t('controlLayers.controlMode.colorgrid'), value: 'colorgrid' }, + { label: t('controlLayers.controlMode.recolor'), value: 'recolor' }, + { label: t('controlLayers.controlMode.tile'), value: 'tile' }, + { label: t('controlLayers.controlMode.pose'), value: 'pose' }, + ]; + } + // Show standard control modes for other models + return [ { label: t('controlLayers.controlMode.balanced'), value: 'balanced' }, { label: t('controlLayers.controlMode.prompt'), value: 'more_prompt' }, { label: t('controlLayers.controlMode.control'), value: 'more_control' }, { label: t('controlLayers.controlMode.megaControl'), value: 'unbalanced' }, - ], - [t] - ); + ]; + }, [t, model?.base]); const handleControlModeChange = useCallback( (v) => { - assert(isControlModeV2(v?.value)); + assert(isControlMode(v?.value)); onChange(v.value); }, [onChange] diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index f304f3a1eef..05ebc3d9f2c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -20,6 +20,7 @@ import type { CanvasInpaintMaskState, CanvasMetadata, ControlLoRAConfig, + ControlMode, EntityMovedByPayload, FillStyle, FLUXReduxImageInfluence, @@ -514,10 +515,27 @@ export const canvasSlice = createSlice({ default: break; } + + // When switching to a BRIA controlnet model, set appropriate default control mode + if (layer.controlAdapter.type === 'controlnet' && modelConfig.base === 'bria') { + const currentMode = layer.controlAdapter.controlMode; + // Check if current mode is not a valid BRIA mode + if (!['depth', 'canny', 'colorgrid', 'recolor', 'tile', 'pose'].includes(currentMode)) { + layer.controlAdapter.controlMode = 'depth'; // 
Default BRIA mode + } + } + // When switching from BRIA to other controlnet models, set appropriate default control mode + else if (layer.controlAdapter.type === 'controlnet' && modelConfig.base !== 'bria') { + const currentMode = layer.controlAdapter.controlMode; + // Check if current mode is a BRIA-specific mode + if (['depth', 'canny', 'colorgrid', 'recolor', 'tile', 'pose'].includes(currentMode)) { + layer.controlAdapter.controlMode = 'balanced'; // Default standard mode + } + } }, controlLayerControlModeChanged: ( state, - action: PayloadAction> + action: PayloadAction> ) => { const { entityIdentifier, controlMode } = action.payload; const layer = selectEntity(state, entityIdentifier); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index dc29f16fa1d..aef73b5a0ef 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -421,6 +421,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3'); export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4'); export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3'); +export const selectIsBria = createParamsSelector((params) => params.model?.base === 'bria'); export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4'); export const selectIsFluxKontextApi = createParamsSelector((params) => params.model?.base === 'flux-kontext'); export const selectIsFluxKontext = createParamsSelector((params) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 85dc44e38f1..770e207538a 100644 --- 
a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -74,6 +74,14 @@ const zControlModeV2 = z.enum(['balanced', 'more_prompt', 'more_control', 'unbal export type ControlModeV2 = z.infer; export const isControlModeV2 = (v: unknown): v is ControlModeV2 => zControlModeV2.safeParse(v).success; +const zBriaControlMode = z.enum(['depth', 'canny', 'colorgrid', 'recolor', 'tile', 'pose']); +export type BriaControlMode = z.infer; +export const isBriaControlMode = (v: unknown): v is BriaControlMode => zBriaControlMode.safeParse(v).success; + +const zControlMode = z.union([zControlModeV2, zBriaControlMode]); +export type ControlMode = z.infer; +export const isControlMode = (v: unknown): v is ControlMode => zControlMode.safeParse(v).success; + const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G', 'ViT-L']); export type CLIPVisionModelV2 = z.infer; export const isCLIPVisionModelV2 = (v: unknown): v is CLIPVisionModelV2 => zCLIPVisionModelV2.safeParse(v).success; @@ -363,7 +371,7 @@ const zControlNetConfig = z.object({ model: zServerValidatedModelIdentifierField.nullable(), weight: z.number().gte(-1).lte(2), beginEndStepPct: zBeginEndStepPct, - controlMode: zControlModeV2, + controlMode: zControlMode, }); export type ControlNetConfig = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index 5fcd13ba4fd..89115f69c96 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -158,18 +158,33 @@ const addControlNetToGraph = ( assert(model !== null); const { image_name } = imageDTO; - const controlNet = g.addNode({ - id: `control_net_${id}`, - type: model.base === 'flux' ? 
'flux_controlnet' : 'controlnet', - begin_step_percent: beginEndStepPct[0], - end_step_percent: beginEndStepPct[1], - control_mode: model.base === 'flux' ? undefined : controlMode, - resize_mode: 'just_resize', - control_model: model, - control_weight: weight, - image: { image_name }, - }); - g.addEdge(controlNet, 'control', collector, 'item'); + if (model.base === 'bria') { + // BRIA uses a different node type and parameters + const controlNet = g.addNode({ + id: `control_net_${id}`, + type: 'bria_controlnet', + control_model: model, + control_weight: weight, + control_image: { image_name }, + // BRIA uses control_mode instead of controlMode + control_mode: controlMode || 'pose', // Default to 'pose' if not specified + }); + g.addEdge(controlNet, 'control', collector, 'item'); + } else { + // Standard controlnet for other models + const controlNet = g.addNode({ + id: `control_net_${id}`, + type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet', + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + control_mode: model.base === 'flux' ? 
undefined : controlMode, + resize_mode: 'just_resize', + control_model: model, + control_weight: weight, + image: { image_name }, + }); + g.addEdge(controlNet, 'control', collector, 'item'); + } }; const addT2IAdapterToGraph = (
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildBriaGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildBriaGraph.ts
new file mode 100644
index 00000000000..e9cb66cf4e8
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildBriaGraph.ts
@@ -0,0 +1,196 @@
+import { logger } from 'app/logging/logger';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
+import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
+import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
+import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
+import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
+import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
+import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
+import { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
+import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
+import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
+import { selectActiveTab } from 'features/ui/store/uiSelectors';
+import { t } from 'i18next';
+import type { Invocation } from 'services/api/types';
+import type { Equals } from 'tsafe';
+import { assert } from 'tsafe';
+
+import { addControlNets } from './addControlAdapters';
+
+const log = logger('system');
+
+/**
+ * Builds the text-to-image generation graph for a Bria main model.
+ *
+ * The graph loads the model, encodes the prompts, samples initial latents, denoises and decodes. When a
+ * canvas manager is provided, control layers are collected and wired into the denoise node.
+ *
+ * @param arg The graph builder argument; `generationMode`, `state` and `manager` are read from it.
+ * @returns The graph, plus the nodes that carry the seed and the positive prompt.
+ * @throws UnsupportedGenerationModeError if `generationMode` is not `txt2img`.
+ */
+export const buildBriaGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
+  const { generationMode, state, manager } = arg;
+  log.debug({ generationMode, manager: manager?.id }, 'Building Bria graph');
+
+  const model = selectMainModelConfig(state);
+  assert(model, 'No model selected');
+  assert(model.base === 'bria', 'Selected model is not a Bria model');
+
+  const params = selectParamsSlice(state);
+  const canvas = selectCanvasSlice(state);
+
+  const { guidance, steps, seed } = params;
+
+  // Bria only supports txt2img for now
+  if (generationMode !== 'txt2img') {
+    throw new UnsupportedGenerationModeError(t('toast.briaIncompatibleGenerationMode'));
+  }
+
+  const g = new Graph(getPrefixedId('bria_graph'));
+
+  // Add model loader
+  const modelLoader = g.addNode({
+    type: 'bria_model_loader',
+    id: getPrefixedId('bria_model_loader'),
+    model,
+  } as Invocation<'bria_model_loader'>);
+
+  // Add positive prompt
+  const positivePrompt = g.addNode({
+    id: getPrefixedId('positive_prompt'),
+    type: 'string',
+  });
+
+  // Add text encoder. The positive prompt is wired in with an edge below rather than set as a field here:
+  // `positivePrompt` is a graph node, not a string value.
+  const textEncoder = g.addNode({
+    type: 'bria_text_encoder',
+    id: getPrefixedId('bria_text_encoder'),
+    negative_prompt: params.negativePrompt,
+    max_length: 128,
+  } as Invocation<'bria_text_encoder'>);
+
+  // Add latent sampler for initial noise
+  const latentSampler = g.addNode({
+    type: 'bria_latent_sampler',
+    id: getPrefixedId('bria_latent_sampler'),
+    width: params.width,
+    height: params.height,
+    seed: seed,
+  } as Invocation<'bria_latent_sampler'>);
+
+  // Add denoise node
+  const denoise = g.addNode({
+    type: 'bria_denoise',
+    id: getPrefixedId('bria_denoise'),
+    num_steps: steps,
+    guidance_scale: guidance,
+  } as Invocation<'bria_denoise'>);
+
+  // Add decoder
+  const decoder = g.addNode({
+    type: 'bria_decoder',
+    id: getPrefixedId('bria_decoder'),
+  } as Invocation<'bria_decoder'>);
+
+  // Connect positive prompt to text encoder
+  g.addEdge(positivePrompt, 'value', textEncoder, 'prompt');
+
+  // Connect model components to text encoder
+  g.addEdge(modelLoader, 't5_encoder', textEncoder, 't5_encoder');
+
+  // Connect model components to latent sampler
+  g.addEdge(modelLoader, 'transformer', latentSampler, 'transformer');
+
+  // Connect model components to denoise
+  g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
+  g.addEdge(modelLoader, 't5_encoder', denoise, 't5_encoder');
+  g.addEdge(modelLoader, 'vae', denoise, 'vae');
+
+  // Connect text encoder to denoise
+  g.addEdge(textEncoder, 'pos_embeds', denoise, 'pos_embeds');
+  g.addEdge(textEncoder, 'neg_embeds', denoise, 'neg_embeds');
+  g.addEdge(textEncoder, 'text_ids', denoise, 'text_ids');
+
+  // Connect latent sampler to denoise
+  g.addEdge(latentSampler, 'latents', denoise, 'latents');
+  g.addEdge(latentSampler, 'latent_image_ids', denoise, 'latent_image_ids');
+
+  // Connect model components to decoder
+  g.addEdge(modelLoader, 'vae', decoder, 'vae');
+
+  // Connect denoise to decoder
+  g.addEdge(denoise, 'latents', decoder, 'latents');
+
+  // Add ControlNet support
+  if (manager !== null) {
+    const controlNetCollector = g.addNode({
+      type: 'collect',
+      id: getPrefixedId('control_net_collector'),
+    });
+
+    const controlNetResult = await addControlNets({
+      manager,
+      entities: canvas.controlLayers.entities,
+      g,
+      rect: canvas.bbox.rect,
+      collector: controlNetCollector,
+      model,
+    });
+
+    if (controlNetResult.addedControlNets > 0) {
+      // Connect the collector to the denoise node's control input
+      g.addEdge(controlNetCollector, 'collection', denoise, 'control');
+    } else {
+      // Remove the collector if no control nets were added
+      g.deleteNode(controlNetCollector.id);
+    }
+  }
+
+  // Add metadata
+  g.upsertMetadata({
+    guidance_scale: guidance,
+    model: Graph.getModelMetadataField(model),
+    steps,
+    generation_mode: 'bria_txt2img',
+  });
+  g.addEdgeToMetadata(latentSampler, 'seed', 'seed');
+  g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
+
+  let canvasOutput: Invocation<ImageOutputNodes> = decoder;
+
+  // Add text to image handling
+  canvasOutput = addTextToImage({
+    g,
+    state,
+    denoise: decoder, // Use decoder as the denoise equivalent
+    l2i: decoder,
+  });
+
+  // Add NSFW checker
+  if (state.system.shouldUseNSFWChecker) {
+    canvasOutput = addNSFWChecker(g, canvasOutput);
+  }
+
+  // Add watermarker
+  if (state.system.shouldUseWatermarker) {
+    canvasOutput = addWatermarker(g, canvasOutput);
+  }
+
+  g.updateNode(canvasOutput, selectCanvasOutputFields(state));
+
+  if (selectActiveTab(state) === 'canvas') {
+    g.upsertMetadata(selectCanvasMetadata(state));
+  }
+
+  g.setMetadataReceivingNode(canvasOutput);
+
+  return {
+    g,
+    seed: latentSampler,
+    positivePrompt,
+  };
+};
diff --git a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts index 9041a872139..54c6abf582f 100644 --- a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts +++ b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts @@ -3,7 +3,7 @@ import type { BaseModelType } from 'services/api/types'; /** * Gets the optimal dimension for a given base model: * - sd-1, sd-2: 512 - * - sdxl, flux, sd-3, cogview4: 1024 + * - sdxl, flux, sd-3, cogview4, bria: 1024 * - default: 1024 * @param base The base model * @returns The optimal dimension for the model, defaulting to 1024 @@ -21,6 +21,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => { case 'imagen4': case 'chatgpt-4o': case 'flux-kontext': + case 'bria': default: return 1024; } @@ -63,7 +64,7 @@ export const isInSDXLTrainingDimensions = (width: number, height: number): boole /** * Gets the grid size for a given base model. For Flux, the grid size is 16, otherwise it is 8.
* - sd-1, sd-2, sdxl: 8 - * - flux, sd-3: 16 + * - flux, sd-3, bria: 16 * - cogview4: 32 * - default: 8 * @param base The base model @@ -75,6 +76,7 @@ export const getGridSize = (base?: BaseModelType | null): number => { return 32; case 'flux': case 'sd-3': + case 'bria': return 16; case 'sd-1': case 'sd-2': diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 0a6a4009b17..712525027f9 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -267,6 +267,13 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: { } } + if (model?.base === 'bria') { + if (!params.t5EncoderModel) { + reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') }); + } + // Bria uses fixed 1024x1024 dimensions, no need to validate dimensions + } + if (model && isChatGPT4oHighModelDisabled(model)) { reasons.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) }); } @@ -601,6 +608,34 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { } } + if (model?.base === 'bria') { + if (!params.t5EncoderModel) { + reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') }); + } + + // Bria requires fixed 1024x1024 dimensions + const { bbox } = canvas; + const requiredSize = 1024; + + if (bbox.scaleMethod === 'none') { + if (bbox.rect.width !== requiredSize || bbox.rect.height !== requiredSize) { + reasons.push({ + content: i18n.t('parameters.invoke.briaRequiresExactDimensions', { + size: requiredSize, + }), + }); + } + } else { + if (bbox.scaledSize.width !== requiredSize || bbox.scaledSize.height !== requiredSize) { + reasons.push({ + content: i18n.t('parameters.invoke.briaRequiresExactScaledDimensions', { + size: requiredSize, + }), + }); + } + } + } + if (model && isChatGPT4oHighModelDisabled(model)) { reasons.push({ content: 
i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) }); }