From 82f3992135999bb492091c86ee875216232bd5a7 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 00:47:40 -0400 Subject: [PATCH 01/11] visual adjustment filters --- invokeai/frontend/web/public/locales/en.json | 13 ++ .../components/RasterLayer/RasterLayer.tsx | 3 + .../RasterLayerAdjustmentsPanel.tsx | 175 +++++++++++++++++ .../RasterLayer/RasterLayerMenuItems.tsx | 33 +++- .../CanvasEntity/CanvasEntityAdapterBase.ts | 4 +- .../CanvasEntityAdapterRasterLayer.ts | 83 +++++++- .../features/controlLayers/konva/filters.ts | 182 ++++++++++++++++++ .../controlLayers/store/canvasSlice.ts | 170 ++++++++++++++++ .../src/features/controlLayers/store/types.ts | 26 +++ .../src/features/controlLayers/store/util.ts | 1 + 10 files changed, 685 insertions(+), 5 deletions(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ad245191b44..995034f13dd 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2019,6 +2019,19 @@ "pullBboxIntoLayerError": "Problem Pulling BBox Into Layer", "pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage", "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", + "addAdjustments": "Add Adjustments", + "removeAdjustments": "Remove Adjustments", + "adjustments": { + "heading": "Adjustments", + "expand": "Expand adjustments", + "collapse": "Collapse adjustments", + "brightness": "Brightness", + "contrast": "Contrast", + "saturation": "Saturation", + "temperature": "Temperature", + "tint": "Tint", + "sharpness": "Sharpness" + }, "regionIsEmpty": "Selected region is empty", "mergeVisible": "Merge Visible", "mergeDown": "Merge Down", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx index ddaefb1073e..c0133687816 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx @@ -4,6 +4,7 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions'; import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage'; import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit'; +import { RasterLayerAdjustmentsPanel } from 'features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel'; import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate'; import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext'; import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; @@ -39,6 +40,8 @@ export const RasterLayer = memo(({ id }: Props) => { + {/* Show adjustments UI only when adjustments exist */} + { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)); + const { t } = useTranslation(); + + const hasAdjustments = Boolean(layer?.adjustments); + const enabled = Boolean(layer?.adjustments?.enabled); + const collapsed = Boolean(layer?.adjustments?.collapsed); + const simple = layer?.adjustments?.simple ?? { + brightness: 0, + contrast: 0, + saturation: 0, + temperature: 0, + tint: 0, + sharpness: 0, + }; + + const onToggleEnabled = useCallback( + (v: boolean) => { + dispatch( + rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: { enabled: v, collapsed: false, mode: 'simple' } }) + ); + }, + [dispatch, entityIdentifier] + ); + + const onReset = useCallback(() => { + // Reset values to defaults but keep adjustments present; preserve enabled/collapsed/mode + dispatch( + rasterLayerAdjustmentsSimpleUpdated({ + entityIdentifier, + simple: { + brightness: 0, + contrast: 0, + saturation: 0, + temperature: 0, + tint: 0, + sharpness: 0, + }, + }) + ); + const defaultPoints: Array<[number, number]> = [ + [0, 0], + [255, 255], + ]; + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'master', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'r', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'g', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'b', points: defaultPoints })); + }, [dispatch, entityIdentifier]); + + const onToggleCollapsed = useCallback(() => { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { collapsed: !collapsed }, + }) + ); + }, [dispatch, entityIdentifier, collapsed]); + + const slider = useMemo( + () => + ({ + row: (label: string, value: number, onChange: (v: number) => void, min = -1, max = 1, step = 0.01) => ( + + + + {label} + + + + + + ), + }) as const, + [] + ); + + const onBrightness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { brightness: v } })), + [dispatch, entityIdentifier] + ); + const onContrast = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { contrast: v } })), + [dispatch, entityIdentifier] + ); + const onSaturation = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { saturation: v } })), + [dispatch, entityIdentifier] + ); + const onTemperature = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { temperature: v } })), + [dispatch, entityIdentifier] + ); + const onTint = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { tint: v } })), + [dispatch, entityIdentifier] + ); + const onSharpness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { sharpness: v } })), + [dispatch, entityIdentifier] + ); + + const handleToggleEnabled = useCallback( + (e: React.ChangeEvent) => onToggleEnabled(e.target.checked), + [onToggleEnabled] + ); + + // Hide the panel entirely until adjustments are added via context menu + if (!hasAdjustments) { + return null; + } + + return ( + <> + + + } + /> + + Adjustments + + + + + + {!collapsed && ( + <> + {slider.row(t('controlLayers.adjustments.brightness'), simple.brightness, onBrightness)} + {slider.row(t('controlLayers.adjustments.contrast'), simple.contrast, onContrast)} + {slider.row(t('controlLayers.adjustments.saturation'), simple.saturation, onSaturation)} + {slider.row(t('controlLayers.adjustments.temperature'), simple.temperature, onTemperature)} + {slider.row(t('controlLayers.adjustments.tint'), simple.tint, onTint)} + {slider.row(t('controlLayers.adjustments.sharpness'), simple.sharpness, onSharpness, 0, 1, 0.01)} + + )} + + ); +}); + +RasterLayerAdjustmentsPanel.displayName = 'RasterLayerAdjustmentsPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx index 65a16a7b4f9..a3b7b0d48dc 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx @@ -1,4 +1,5 @@ -import { MenuDivider } from '@invoke-ai/ui-library'; +import { MenuDivider, MenuItem } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { IconMenuItemGroup } from 'common/components/IconMenuItem'; import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange'; import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox'; @@ -11,9 +12,33 @@ import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/compon import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform'; import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu'; import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu'; -import { memo } from 'react'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsReset, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; export const RasterLayerMenuItems = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const { t } = useTranslation(); + const layer = useAppSelector((s) => + s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id) + ); + const hasAdjustments = Boolean(layer?.adjustments); + const onToggleAdjustmentsPresence = useCallback(() => { + if (hasAdjustments) { + dispatch(rasterLayerAdjustmentsReset({ entityIdentifier })); + } else { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { enabled: true, collapsed: false, mode: 'simple' }, + }) + ); + } + }, [dispatch, entityIdentifier, hasAdjustments]); + return ( <> @@ -22,6 +47,10 @@ export const RasterLayerMenuItems = memo(() => { + + {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')} + + diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts index 6c55e949377..2b45f61b291 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts @@ -475,7 +475,7 @@ export abstract class CanvasEntityAdapterBase { @@ -74,4 +80,79 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase< const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked']; return omit(this.state, keysToOmit); }; + + private syncAdjustmentsFilter = () => { + const a = this.state.adjustments; + const apply = !!a && a.enabled; + // The filter operates on the renderer's object group; we can set filters at the group level via renderer + const group = this.renderer.konva.objectGroup; + if (apply) { + const filters = group.filters() ?? []; + let nextFilters = filters.filter((f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter); + if (a.mode === 'simple') { + group.setAttr('adjustmentsSimple', a.simple); + group.setAttr('adjustmentsCurves', null); + nextFilters = [...nextFilters, AdjustmentsSimpleFilter]; + } else { + // Build LUTs and set curves attr + const master = buildCurveLUT(a.curves.master); + const r = buildCurveLUT(a.curves.r); + const g = buildCurveLUT(a.curves.g); + const b = buildCurveLUT(a.curves.b); + group.setAttr('adjustmentsCurves', { master, r, g, b }); + group.setAttr('adjustmentsSimple', null); + nextFilters = [...nextFilters, AdjustmentsCurvesFilter]; + } + group.filters(nextFilters); + this._throttledCacheRefresh(); + } else { + // Remove our filter if present + const filters = (group.filters() ?? []).filter( + (f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter + ); + group.filters(filters); + group.setAttr('adjustmentsSimple', null); + group.setAttr('adjustmentsCurves', null); + this._throttledCacheRefresh(); + } + }; + + private _throttledCacheRefresh = throttle(() => this.renderer.syncKonvaCache(true), 50); + + private haveAdjustmentsChanged = (prevState: CanvasRasterLayerState, currState: CanvasRasterLayerState): boolean => { + const pa = prevState.adjustments; + const ca = currState.adjustments; + if (pa === ca) { + return false; + } + if (!pa || !ca) { + return true; + } + if (pa.enabled !== ca.enabled) { + return true; + } + if (pa.mode !== ca.mode) { + return true; + } + // simple params + const ps = pa.simple; + const cs = ca.simple; + if ( + ps.brightness !== cs.brightness || + ps.contrast !== cs.contrast || + ps.saturation !== cs.saturation || + ps.temperature !== cs.temperature || + ps.tint !== cs.tint || + ps.sharpness !== cs.sharpness + ) { + return true; + } + // curves reference (UI not implemented yet) - if arrays differ by ref, consider changed + const pc = pa.curves; + const cc = ca.curves; + if (pc !== cc) { + return true; + } + return false; + }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts index 34c5c9ac5de..159d2c6da3d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts @@ -20,3 +20,185 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => { imageData.data[i * 4 + 3] = Math.min(a, (cMin + cMax) / 2); } }; + +// Utility clamp +const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v); + +type SimpleAdjustParams = { + brightness: number; // -1..1 (additive) + contrast: number; // -1..1 (scale around 128) + saturation: number; // -1..1 + temperature: number; // -1..1 (blue<->yellow approx) + tint: number; // -1..1 (green<->magenta approx) + sharpness: number; // -1..1 (light unsharp mask) +}; + +/** + * Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness). + * + * Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter. + */ +type KonvaFilterThis = { getAttr?: (key: string) => unknown }; +export const AdjustmentsSimpleFilter = function (this: KonvaFilterThis, imageData: ImageData): void { + const params = (this?.getAttr?.('adjustmentsSimple') as SimpleAdjustParams | undefined) ?? null; + if (!params) { + return; + } + + const { brightness, contrast, saturation, temperature, tint, sharpness } = params; + + const data = imageData.data; + const len = data.length / 4; + const width = (imageData as ImageData & { width: number }).width ?? 0; + const height = (imageData as ImageData & { height: number }).height ?? 0; + + // Precompute factors + const brightnessShift = brightness * 255; // additive shift + const contrastFactor = 1 + contrast; // scale around 128 + + // Temperature/Tint multipliers + const tempK = 0.5; + const tintK = 0.5; + const rTempMul = 1 + temperature * tempK; + const bTempMul = 1 - temperature * tempK; + const rTintMul = 1 + (tint > 0 ? tint * tintK : -tint * 0); + const gTintMul = 1 - Math.abs(tint) * tintK; + const bTintMul = 1 + (tint > 0 ? tint * tintK : -tint * 0); + + // Saturation matrix (HSL-based approximation via luma coefficients) + const lumaR = 0.2126; + const lumaG = 0.7152; + const lumaB = 0.0722; + const S = 1 + saturation; // 0..2 + const m00 = lumaR * (1 - S) + S; + const m01 = lumaG * (1 - S); + const m02 = lumaB * (1 - S); + const m10 = lumaR * (1 - S); + const m11 = lumaG * (1 - S) + S; + const m12 = lumaB * (1 - S); + const m20 = lumaR * (1 - S); + const m21 = lumaG * (1 - S); + const m22 = lumaB * (1 - S) + S; + + // First pass: apply per-pixel color adjustments (excluding sharpness) + for (let i = 0; i < len; i++) { + const idx = i * 4; + let r = data[idx + 0] as number; + let g = data[idx + 1] as number; + let b = data[idx + 2] as number; + const a = data[idx + 3] as number; + + // Brightness (additive) + r = r + brightnessShift; + g = g + brightnessShift; + b = b + brightnessShift; + + // Contrast around mid-point 128 + r = (r - 128) * contrastFactor + 128; + g = (g - 128) * contrastFactor + 128; + b = (b - 128) * contrastFactor + 128; + + // Temperature (R/B axis) and Tint (G vs Magenta) + r = r * rTempMul * rTintMul; + g = g * gTintMul; + b = b * bTempMul * bTintMul; + + // Saturation via matrix + const r2 = r * m00 + g * m01 + b * m02; + const g2 = r * m10 + g * m11 + b * m12; + const b2 = r * m20 + g * m21 + b * m22; + + data[idx + 0] = clamp(r2, 0, 255); + data[idx + 1] = clamp(g2, 0, 255); + data[idx + 2] = clamp(b2, 0, 255); + data[idx + 3] = a; + } + + // Optional sharpen (simple unsharp mask with 3x3 kernel) + if (Math.abs(sharpness) > 1e-3 && width > 2 && height > 2) { + const src = new Uint8ClampedArray(data); // copy of modified data + const a = Math.max(-1, Math.min(1, sharpness)) * 0.5; // amount + const center = 1 + 4 * a; + const neighbor = -a; + for (let y = 1; y < height - 1; y++) { + for (let x = 1; x < width - 1; x++) { + const idx = (y * width + x) * 4; + for (let c = 0; c < 3; c++) { + const centerPx = src[idx + c] ?? 0; + const leftPx = src[idx - 4 + c] ?? 0; + const rightPx = src[idx + 4 + c] ?? 0; + const topPx = src[idx - width * 4 + c] ?? 0; + const bottomPx = src[idx + width * 4 + c] ?? 0; + const v = centerPx * center + leftPx * neighbor + rightPx * neighbor + topPx * neighbor + bottomPx * neighbor; + data[idx + c] = clamp(v, 0, 255); + } + // preserve alpha + } + } + } +}; + +// Build a 256-length LUT from 0..255 control points (linear interpolation for v1) +export const buildCurveLUT = (points: Array<[number, number]>): number[] => { + if (!points || points.length === 0) { + return Array.from({ length: 256 }, (_, i) => i); + } + const pts = points + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]) + .sort((a, b) => a[0] - b[0]); + if ((pts[0]?.[0] ?? 0) !== 0) { + pts.unshift([0, pts[0]?.[1] ?? 0]); + } + const last = pts[pts.length - 1]; + if ((last?.[0] ?? 255) !== 255) { + pts.push([255, last?.[1] ?? 255]); + } + const lut = new Array(256); + let j = 0; + for (let x = 0; x <= 255; x++) { + while (j < pts.length - 2 && x > (pts[j + 1]?.[0] ?? 255)) { + j++; + } + const p0 = pts[j] ?? [0, 0]; + const p1 = pts[j + 1] ?? [255, 255]; + const [x0, y0] = p0; + const [x1, y1] = p1; + const t = x1 === x0 ? 0 : (x - x0) / (x1 - x0); + const y = y0 + (y1 - y0) * t; + lut[x] = clamp(Math.round(y), 0, 255); + } + return lut; +}; + +type CurvesAdjustParams = { + master: number[]; + r: number[]; + g: number[]; + b: number[]; +}; + +// Curves filter: apply master curve, then per-channel curves +export const AdjustmentsCurvesFilter = function (this: KonvaFilterThis, imageData: ImageData): void { + const params = (this?.getAttr?.('adjustmentsCurves') as CurvesAdjustParams | undefined) ?? null; + if (!params) { + return; + } + const { master, r, g, b } = params; + if (!master || !r || !g || !b) { + return; + } + const data = imageData.data; + const len = data.length / 4; + for (let i = 0; i < len; i++) { + const idx = i * 4; + const r0 = data[idx + 0] as number; + const g0 = data[idx + 1] as number; + const b0 = data[idx + 2] as number; + const rm = master[r0] ?? r0; + const gm = master[g0] ?? g0; + const bm = master[b0] ?? b0; + data[idx + 0] = clamp(r[rm] ?? rm, 0, 255); + data[idx + 1] = clamp(g[gm] ?? gm, 0, 255); + data[idx + 2] = clamp(b[bm] ?? bm, 0, 255); + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 61168a0ec5a..a00a9c588d0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -102,6 +102,165 @@ const slice = createSlice({ reducers: { // undoable canvas state //#region Raster layers + rasterLayerAdjustmentsSet: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + adjustments: + | NonNullable + | { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' } + | null; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, adjustments } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (adjustments === null) { + layer.adjustments = null; + return; + } + if (layer.adjustments === null) { + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'simple', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + if (typeof adjustments === 'object' && adjustments !== null && 'version' in adjustments) { + layer.adjustments = merge(layer.adjustments, adjustments as NonNullable); + } else { + // Shallow toggles only + const partial = adjustments as { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' }; + layer.adjustments = merge(layer.adjustments, partial); + } + }, + rasterLayerAdjustmentsReset: (state, action: PayloadAction>) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + layer.adjustments = null; + }, + rasterLayerAdjustmentsSimpleUpdated: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + simple: Partial['simple']>>; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, simple } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (!layer.adjustments) { + // initialize baseline + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'simple', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + layer.adjustments.simple = merge(layer.adjustments.simple, simple); + }, + rasterLayerAdjustmentsCurvesUpdated: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + channel: 'master' | 'r' | 'g' | 'b'; + points: Array<[number, number]>; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, channel, points } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (!layer.adjustments) { + // initialize baseline + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'curves', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + layer.adjustments.curves[channel] = points; + }, rasterLayerAdded: { reducer: ( state, @@ -1251,6 +1410,12 @@ const slice = createSlice({ switch (newEntity.type) { case 'raster_layer': newEntity.id = getPrefixedId('raster_layer'); + // Bake-on-copy semantics: adjustments should not carry over as live settings. + // TODO: Actually bake adjustments into pixels before cloning; for now, drop them on the copy. + if ('adjustments' in newEntity) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (newEntity as any).adjustments = null; + } state.rasterLayers.entities.push(newEntity); break; case 'control_layer': @@ -1621,6 +1786,11 @@ export const { entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, + // Raster layer adjustments + rasterLayerAdjustmentsSet, + rasterLayerAdjustmentsReset, + rasterLayerAdjustmentsSimpleUpdated, + rasterLayerAdjustmentsCurvesUpdated, entityDeleted, entityArrangedForwardOne, entityArrangedToFront, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index f771f8c7469..fdba9fa3f03 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -368,6 +368,32 @@ const zCanvasRasterLayerState = zCanvasEntityBase.extend({ position: zCoordinate, opacity: zOpacity, objects: z.array(zCanvasObjectState), + // Optional per-layer color adjustments (simple + curves). When null/undefined, no adjustments are applied. + adjustments: z + .object({ + version: z.literal(1), + enabled: z.boolean(), + collapsed: z.boolean(), + mode: z.enum(['simple', 'curves']), + simple: z.object({ + // All simple params normalized to [-1, 1] except sharpness [0, 1] + brightness: z.number().gte(-1).lte(1), + contrast: z.number().gte(-1).lte(1), + saturation: z.number().gte(-1).lte(1), + temperature: z.number().gte(-1).lte(1), + tint: z.number().gte(-1).lte(1), + sharpness: z.number().gte(0).lte(1), + }), + curves: z.object({ + // Curves are arrays of [x, y] control points in 0..255 space (no strict monotonic checks here) + master: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + r: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + g: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + b: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + }), + }) + .optional() + .nullable(), }); export type CanvasRasterLayerState = z.infer; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts index 2d40cf17793..e14cfd546f4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts @@ -198,6 +198,7 @@ export const getRasterLayerState = ( objects: [], opacity: 1, position: { x: 0, y: 0 }, + adjustments: null, }; merge(entityState, overrides); return entityState; From 8259268f673cfb7d7d98a2a038e89d5cdf3d64ce Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 01:43:11 -0400 Subject: [PATCH 02/11] apply filters to operations --- .../controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts | 2 +- .../konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts | 2 +- .../konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts index 2b45f61b291..5995e80663c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts @@ -571,7 +571,7 @@ export abstract class CanvasEntityAdapterBase => { const { rect } = this.manager.stateApi.getBbox(); const rasterizeResult = await withResultAsync(() => - this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } }) + this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } }) ); if (rasterizeResult.isErr()) { toast({ status: 'error', title: 'Failed to crop to bbox' }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts index 7e31b594fac..06620584fc5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts @@ -72,7 +72,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase< this.log.trace({ rect }, 'Getting canvas'); // The opacity may have been changed in response to user selecting a different entity category, so we must restore // the original opacity before rendering the canvas - const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] }; + const attrs: GroupConfig = { opacity: this.state.opacity }; const canvas = this.renderer.getCanvas({ rect, attrs }); return canvas; }; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts index d2fdd12448f..cd8dee6d2f3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts @@ -71,7 +71,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase< this.log.trace({ rect }, 'Getting canvas'); // The opacity may have been changed in response to user selecting a different entity category, so we must restore // the original opacity before rendering the canvas - const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] }; + const attrs: GroupConfig = { opacity: this.state.opacity }; const canvas = this.renderer.getCanvas({ rect, attrs }); return canvas; }; From 47ccd33c9c7636910318ad0ee325da092d85fcd7 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 01:51:57 -0400 Subject: [PATCH 03/11] keep adjustments on duplicate --- .../web/src/features/controlLayers/store/canvasSlice.ts | 6 ------ 1 file changed, 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index a00a9c588d0..bcc8fb9ffc5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1410,12 +1410,6 @@ const slice = createSlice({ switch (newEntity.type) { case 'raster_layer': newEntity.id = getPrefixedId('raster_layer'); - // Bake-on-copy semantics: adjustments should not carry over as live settings. - // TODO: Actually bake adjustments into pixels before cloning; for now, drop them on the copy. - if ('adjustments' in newEntity) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (newEntity as any).adjustments = null; - } state.rasterLayers.entities.push(newEntity); break; case 'control_layer': From eb5648ae0b6e311a9cb6c452950c1bda9d449b61 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 02:32:46 -0400 Subject: [PATCH 04/11] curves editor --- .../RasterLayerAdjustmentsPanel.tsx | 37 +- .../RasterLayer/RasterLayerCurvesEditor.tsx | 368 ++++++++++++++++++ 2 files changed, 404 insertions(+), 1 deletion(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx index 5a159a58f6c..40d6a262da0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx @@ -1,5 +1,6 @@ import { Button, + ButtonGroup, CompositeNumberInput, CompositeSlider, Flex, @@ -10,6 +11,7 @@ import { Text, } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { RasterLayerCurvesEditor } from 'features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { rasterLayerAdjustmentsCurvesUpdated, @@ -30,6 +32,7 @@ export const RasterLayerAdjustmentsPanel = memo(() => { const hasAdjustments = Boolean(layer?.adjustments); const enabled = Boolean(layer?.adjustments?.enabled); const collapsed = Boolean(layer?.adjustments?.collapsed); + const mode = layer?.adjustments?.mode ?? 'simple'; const simple = layer?.adjustments?.simple ?? { brightness: 0, contrast: 0, @@ -82,6 +85,28 @@ export const RasterLayerAdjustmentsPanel = memo(() => { ); }, [dispatch, entityIdentifier, collapsed]); + const onSetMode = useCallback( + (nextMode: 'simple' | 'curves') => { + if (!layer?.adjustments) { + return; + } + if (nextMode === mode) { + return; + } + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { mode: nextMode }, + }) + ); + }, + [dispatch, entityIdentifier, layer?.adjustments, mode] + ); + + // Memoized click handlers to avoid inline arrow functions in JSX + const onClickModeSimple = useCallback(() => onSetMode('simple'), [onSetMode]); + const onClickModeCurves = useCallback(() => onSetMode('curves'), [onSetMode]); + const slider = useMemo( () => ({ @@ -152,13 +177,21 @@ export const RasterLayerAdjustmentsPanel = memo(() => { Adjustments + + + + - {!collapsed && ( + {!collapsed && mode === 'simple' && ( <> {slider.row(t('controlLayers.adjustments.brightness'), simple.brightness, onBrightness)} {slider.row(t('controlLayers.adjustments.contrast'), simple.contrast, onContrast)} @@ -168,6 +201,8 @@ export const RasterLayerAdjustmentsPanel = memo(() => { {slider.row(t('controlLayers.adjustments.sharpness'), simple.sharpness, onSharpness, 0, 1, 0.01)} )} + + {!collapsed && mode === 'curves' && } ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx new file mode 100644 index 00000000000..5123860c331 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx @@ -0,0 +1,368 @@ +import { Flex, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice'; +import { selectEntity } from 'features/controlLayers/store/selectors'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; + +const DEFAULT_POINTS: Array<[number, number]> = [ + [0, 0], + [255, 255], +]; + +type Channel = 'master' | 'r' | 'g' | 'b'; + +const channelColor: Record = { + master: '#888', + r: '#e53e3e', + g: '#38a169', + b: '#3182ce', +}; + +const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v); + +const sortPoints = (pts: Array<[number, number]>) => + [...pts] + .sort((a, b) => a[0] - b[0]) + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]); + +type CurveGraphProps = { + title: string; + channel: Channel; + points: Array<[number, number]> | undefined; + histogram: number[] | null; + onChange: (pts: Array<[number, number]>) => void; +}; + +const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { + const { title, channel, points, histogram, onChange } = props; + const canvasRef = useRef(null); + const [localPoints, setLocalPoints] = useState>(sortPoints(points ?? DEFAULT_POINTS)); + const [dragIndex, setDragIndex] = useState(null); + + useEffect(() => { + setLocalPoints(sortPoints(points ?? DEFAULT_POINTS)); + }, [points]); + + const width = 256; + const height = 160; + + const draw = useCallback(() => { + const c = canvasRef.current; + if (!c) { + return; + } + c.width = width; + c.height = height; + const ctx = c.getContext('2d'); + if (!ctx) { + return; + } + + // background + ctx.clearRect(0, 0, width, height); + ctx.fillStyle = '#111'; + ctx.fillRect(0, 0, width, height); + + // grid + ctx.strokeStyle = '#2a2a2a'; + ctx.lineWidth = 1; + for (let i = 0; i <= 4; i++) { + const y = (i * height) / 4; + ctx.beginPath(); + ctx.moveTo(0, y + 0.5); + ctx.lineTo(width, y + 0.5); + ctx.stroke(); + } + for (let i = 0; i <= 4; i++) { + const x = (i * width) / 4; + ctx.beginPath(); + ctx.moveTo(x + 0.5, 0); + ctx.lineTo(x + 0.5, height); + ctx.stroke(); + } + + // histogram + if (histogram) { + const max = Math.max(1, ...histogram); + ctx.fillStyle = '#5557'; + for (let x = 0; x < 256; x++) { + const v = histogram[x] ?? 0; + const h = Math.round((v / max) * (height - 4)); + ctx.fillRect(x, height - h, 1, h); + } + } + + // curve + const pts = sortPoints(localPoints); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 2; + ctx.beginPath(); + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = x; + const cy = height - (y / 255) * height; + if (i === 0) { + ctx.moveTo(cx, cy); + } else { + ctx.lineTo(cx, cy); + } + } + ctx.stroke(); + + // control points + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = x; + const cy = height - (y / 255) * height; + ctx.fillStyle = '#000'; + ctx.beginPath(); + ctx.arc(cx, cy, 3.5, 0, Math.PI * 2); + ctx.fill(); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 1.5; + ctx.stroke(); + } + + // title + ctx.fillStyle = '#bbb'; + ctx.font = '12px sans-serif'; + ctx.fillText(title, 6, 14); + }, [channel, height, histogram, localPoints, title, width]); + + useEffect(() => { + draw(); + }, [draw]); + + const getNearestPointIndex = useCallback( + (mx: number, my: number) => { + // map canvas y to [0..255] + const yVal = clamp(Math.round(255 - (my / height) * 255), 0, 255); + const xVal = clamp(Math.round(mx), 0, 255); + let best = -1; + let bestDist = 9999; + for (let i = 0; i < localPoints.length; i++) { + const [px, py] = localPoints[i]!; + const dx = px - xVal; + const dy = py - yVal; + const d = dx * dx + dy * dy; + if (d < bestDist) { + best = i; + bestDist = d; + } + } + if (best !== -1 && bestDist <= 20 * 20) { + return best; + } + return -1; + }, + [height, localPoints] + ); + + const handlePointerDown = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); + const mx = e.clientX - rect.left; + const my = e.clientY - rect.top; + const idx = getNearestPointIndex(mx, my); + if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) { + setDragIndex(idx); + return; + } + // add new point + const xVal = clamp(Math.round(mx), 0, 255); + const yVal = clamp(Math.round(255 - (my / height) * 255), 0, 255); + const next = sortPoints([...localPoints, [xVal, yVal]]); + setLocalPoints(next); + setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal)); + }, + [getNearestPointIndex, height, localPoints] + ); + + const handlePointerMove = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + if (dragIndex === null) { + return; + } + const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); + const mx = clamp(Math.round(e.clientX - rect.left), 0, 255); + const myPx = clamp(Math.round(255 - ((e.clientY - rect.top) / height) * 255), 0, 255); + setLocalPoints((prev) => { + const next = [...prev]; + // clamp endpoints to ends and keep them immutable + if (dragIndex === 0) { + return prev; + } + if (dragIndex === prev.length - 1) { + return prev; + } + next[dragIndex] = [mx, myPx]; + return sortPoints(next); + }); + }, + [dragIndex, height] + ); + + const commit = useCallback( + (pts: Array<[number, number]>) => { + onChange(sortPoints(pts)); + }, + [onChange] + ); + + const handlePointerUp = useCallback(() => { + setDragIndex(null); + commit(localPoints); + }, [commit, localPoints]); + + const handleDoubleClick = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); + const mx = e.clientX - rect.left; + const my = e.clientY - rect.top; + const idx = getNearestPointIndex(mx, my); + if (idx > 0 && idx < localPoints.length - 1) { + const next = localPoints.filter((_, i) => i !== idx); + setLocalPoints(next); + commit(next); + } + }, + [commit, getNearestPointIndex, localPoints] + ); + + const canvasStyle = useMemo( + () => ({ width: '100%', height: height, touchAction: 'none', borderRadius: 4, background: '#111' }), + [height] + ); + + return ( + + ); +}); + +export const RasterLayerCurvesEditor = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer'); + const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)) as + | CanvasRasterLayerState + | undefined; + + const [histMaster, setHistMaster] = useState(null); + const [histR, setHistR] = useState(null); + const [histG, setHistG] = useState(null); + const [histB, setHistB] = useState(null); + + const pointsMaster = layer?.adjustments?.curves.master ?? DEFAULT_POINTS; + const pointsR = layer?.adjustments?.curves.r ?? DEFAULT_POINTS; + const pointsG = layer?.adjustments?.curves.g ?? DEFAULT_POINTS; + const pointsB = layer?.adjustments?.curves.b ?? DEFAULT_POINTS; + + const recalcHistogram = useCallback(() => { + try { + const rect = adapter.transformer.getRelativeRect(); + if (rect.width === 0 || rect.height === 0) { + setHistMaster(Array(256).fill(0)); + setHistR(Array(256).fill(0)); + setHistG(Array(256).fill(0)); + setHistB(Array(256).fill(0)); + return; + } + const imageData = adapter.renderer.getImageData({ rect }); + const data = imageData.data; + const len = data.length / 4; + const master = new Array(256).fill(0); + const r = new Array(256).fill(0); + const g = new Array(256).fill(0); + const b = new Array(256).fill(0); + // sample every 4th pixel to lighten work + for (let i = 0; i < len; i += 4) { + const idx = i * 4; + const rv = data[idx] as number; + const gv = data[idx + 1] as number; + const bv = data[idx + 2] as number; + const m = Math.round(0.2126 * rv + 0.7152 * gv + 0.0722 * bv); + if (m >= 0 && m < 256) { + master[m] = (master[m] ?? 0) + 1; + } + if (rv >= 0 && rv < 256) { + r[rv] = (r[rv] ?? 0) + 1; + } + if (gv >= 0 && gv < 256) { + g[gv] = (g[gv] ?? 0) + 1; + } + if (bv >= 0 && bv < 256) { + b[bv] = (b[bv] ?? 0) + 1; + } + } + setHistMaster(master); + setHistR(r); + setHistG(g); + setHistB(b); + } catch { + // ignore + } + }, [adapter]); + + useEffect(() => { + recalcHistogram(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [layer?.objects, layer?.adjustments]); + + const onChangePoints = useCallback( + (channel: Channel, pts: Array<[number, number]>) => { + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts })); + }, + [dispatch, entityIdentifier] + ); + + // Memoize per-channel change handlers to avoid inline lambdas in JSX + const onChangeMaster = useCallback((pts: Array<[number, number]>) => onChangePoints('master', pts), [onChangePoints]); + const onChangeR = useCallback((pts: Array<[number, number]>) => onChangePoints('r', pts), [onChangePoints]); + const onChangeG = useCallback((pts: Array<[number, number]>) => onChangePoints('g', pts), [onChangePoints]); + const onChangeB = useCallback((pts: Array<[number, number]>) => onChangePoints('b', pts), [onChangePoints]); + + const gridStyles: React.CSSProperties = useMemo( + () => ({ display: 'grid', gridTemplateColumns: 'repeat(2, minmax(0, 1fr))', gap: 8 }), + [] + ); + + return ( + + + Curves + +
+ + + + +
+
+ ); +}); + +RasterLayerCurvesEditor.displayName = 'RasterLayerCurvesEditor'; From 0ec08f0fc6bb6e84a6c68c630bc7700b383bbb84 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 02:44:06 -0400 Subject: [PATCH 05/11] log scale and panel width compatibility --- .../RasterLayer/RasterLayerCurvesEditor.tsx | 141 +++++++++++++----- 1 file changed, 102 insertions(+), 39 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx index 5123860c331..7b1af3c5d9b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx @@ -48,6 +48,31 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { const width = 256; const height = 160; + // inner margins to keep a small buffer from edges (left/right/bottom) and space for title at top + const MARGIN_LEFT = 8; + const MARGIN_RIGHT = 8; + const MARGIN_TOP = 14; + const MARGIN_BOTTOM = 10; + const INNER_WIDTH = width - MARGIN_LEFT - MARGIN_RIGHT; + const INNER_HEIGHT = height - MARGIN_TOP - MARGIN_BOTTOM; + + // helpers to map value-space [0..255] to canvas pixels (respecting inner margins) + const valueToCanvasX = useCallback( + (x: number) => MARGIN_LEFT + (clamp(x, 0, 255) / 255) * INNER_WIDTH, + [INNER_WIDTH] + ); + const valueToCanvasY = useCallback( + (y: number) => MARGIN_TOP + INNER_HEIGHT - (clamp(y, 0, 255) / 255) * INNER_HEIGHT, + [INNER_HEIGHT] + ); + const canvasToValueX = useCallback( + (cx: number) => clamp(Math.round(((cx - MARGIN_LEFT) / INNER_WIDTH) * 255), 0, 255), + [INNER_WIDTH] + ); + const canvasToValueY = useCallback( + (cy: number) => clamp(Math.round(255 - ((cy - MARGIN_TOP) / INNER_HEIGHT) * 255), 0, 255), + [INNER_HEIGHT] + ); const draw = useCallback(() => { const c = canvasRef.current; @@ -66,32 +91,37 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { ctx.fillStyle = '#111'; ctx.fillRect(0, 0, width, height); - // grid + // grid inside inner rect ctx.strokeStyle = '#2a2a2a'; ctx.lineWidth = 1; for (let i = 0; i <= 4; i++) { - const y = (i * height) / 4; + const y = MARGIN_TOP + (i * INNER_HEIGHT) / 4; ctx.beginPath(); - ctx.moveTo(0, y + 0.5); - ctx.lineTo(width, y + 0.5); + ctx.moveTo(MARGIN_LEFT + 0.5, y + 0.5); + ctx.lineTo(MARGIN_LEFT + INNER_WIDTH - 0.5, y + 0.5); ctx.stroke(); } for (let i = 0; i <= 4; i++) { - const x = (i * width) / 4; + const x = MARGIN_LEFT + (i * INNER_WIDTH) / 4; ctx.beginPath(); - ctx.moveTo(x + 0.5, 0); - ctx.lineTo(x + 0.5, height); + ctx.moveTo(x + 0.5, MARGIN_TOP + 0.5); + ctx.lineTo(x + 0.5, MARGIN_TOP + INNER_HEIGHT - 0.5); ctx.stroke(); } // histogram if (histogram) { - const max = Math.max(1, ...histogram); + // logarithmic histogram for readability when values vary widely + const logHist = histogram.map((v) => Math.log10((v ?? 0) + 1)); + const max = Math.max(1e-6, ...logHist); ctx.fillStyle = '#5557'; - for (let x = 0; x < 256; x++) { - const v = histogram[x] ?? 0; - const h = Math.round((v / max) * (height - 4)); - ctx.fillRect(x, height - h, 1, h); + const binW = Math.max(1, INNER_WIDTH / 256); + for (let i = 0; i < 256; i++) { + const v = logHist[i] ?? 0; + const h = Math.round((v / max) * (INNER_HEIGHT - 2)); + const x = MARGIN_LEFT + Math.floor(i * binW); + const y = MARGIN_TOP + INNER_HEIGHT - h; + ctx.fillRect(x, y, Math.ceil(binW), h); } } @@ -102,8 +132,8 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { ctx.beginPath(); for (let i = 0; i < pts.length; i++) { const [x, y] = pts[i]!; - const cx = x; - const cy = height - (y / 255) * height; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); if (i === 0) { ctx.moveTo(cx, cy); } else { @@ -115,8 +145,8 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { // control points for (let i = 0; i < pts.length; i++) { const [x, y] = pts[i]!; - const cx = x; - const cy = height - (y / 255) * height; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); ctx.fillStyle = '#000'; ctx.beginPath(); ctx.arc(cx, cy, 3.5, 0, Math.PI * 2); @@ -129,18 +159,31 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { // title ctx.fillStyle = '#bbb'; ctx.font = '12px sans-serif'; - ctx.fillText(title, 6, 14); - }, [channel, height, histogram, localPoints, title, width]); + ctx.fillText(title, MARGIN_LEFT + 2, Math.max(12, MARGIN_TOP - 2)); + }, [ + MARGIN_LEFT, + MARGIN_TOP, + INNER_HEIGHT, + INNER_WIDTH, + channel, + height, + histogram, + localPoints, + title, + valueToCanvasX, + valueToCanvasY, + width, + ]); useEffect(() => { draw(); }, [draw]); const getNearestPointIndex = useCallback( - (mx: number, my: number) => { - // map canvas y to [0..255] - const yVal = clamp(Math.round(255 - (my / height) * 255), 0, 255); - const xVal = clamp(Math.round(mx), 0, 255); + (mxCanvas: number, myCanvas: number) => { + // convert canvas px to value-space [0..255] + const xVal = canvasToValueX(mxCanvas); + const yVal = canvasToValueY(myCanvas); let best = -1; let bestDist = 9999; for (let i = 0; i < localPoints.length; i++) { @@ -158,29 +201,35 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { } return -1; }, - [height, localPoints] + [canvasToValueX, canvasToValueY, localPoints] ); const handlePointerDown = useCallback( (e: React.PointerEvent) => { e.preventDefault(); e.stopPropagation(); - const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); - const mx = e.clientX - rect.left; - const my = e.clientY - rect.top; - const idx = getNearestPointIndex(mx, my); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const idx = getNearestPointIndex(mxCanvas, myCanvas); if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) { setDragIndex(idx); return; } // add new point - const xVal = clamp(Math.round(mx), 0, 255); - const yVal = clamp(Math.round(255 - (my / height) * 255), 0, 255); + const xVal = canvasToValueX(mxCanvas); + const yVal = canvasToValueY(myCanvas); const next = sortPoints([...localPoints, [xVal, yVal]]); setLocalPoints(next); setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal)); }, - [getNearestPointIndex, height, localPoints] + [canvasToValueX, canvasToValueY, getNearestPointIndex, localPoints] ); const handlePointerMove = useCallback( @@ -190,9 +239,17 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { if (dragIndex === null) { return; } - const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); - const mx = clamp(Math.round(e.clientX - rect.left), 0, 255); - const myPx = clamp(Math.round(255 - ((e.clientY - rect.top) / height) * 255), 0, 255); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const mxVal = canvasToValueX(mxCanvas); + const myVal = canvasToValueY(myCanvas); setLocalPoints((prev) => { const next = [...prev]; // clamp endpoints to ends and keep them immutable @@ -202,11 +259,11 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { if (dragIndex === prev.length - 1) { return prev; } - next[dragIndex] = [mx, myPx]; + next[dragIndex] = [mxVal, myVal]; return sortPoints(next); }); }, - [dragIndex, height] + [canvasToValueX, canvasToValueY, dragIndex] ); const commit = useCallback( @@ -225,10 +282,16 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { (e: React.MouseEvent) => { e.preventDefault(); e.stopPropagation(); - const rect = (e.target as HTMLCanvasElement).getBoundingClientRect(); - const mx = e.clientX - rect.left; - const my = e.clientY - rect.top; - const idx = getNearestPointIndex(mx, my); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const idx = getNearestPointIndex(mxCanvas, myCanvas); if (idx > 0 && idx < localPoints.length - 1) { const next = localPoints.filter((_, i) => i !== idx); setLocalPoints(next); From 8a5906b35cdfcadfaad2a2d0fd94ff85f41b8bb7 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 18:42:48 -0400 Subject: [PATCH 06/11] fix disable toggle reverts to simple view --- .../components/RasterLayer/RasterLayerAdjustmentsPanel.tsx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx index 40d6a262da0..b0b1f3dfc4e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel.tsx @@ -44,9 +44,8 @@ export const RasterLayerAdjustmentsPanel = memo(() => { const onToggleEnabled = useCallback( (v: boolean) => { - dispatch( - rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: { enabled: v, collapsed: false, mode: 'simple' } }) - ); + // Only toggle the enabled state; preserve current mode/collapsed so users can A/B compare + dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: { enabled: v } })); }, [dispatch, entityIdentifier] ); From 1df81e9234fa2098b974a405375fb9d0b3d804c5 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 10 Aug 2025 18:55:47 -0400 Subject: [PATCH 07/11] Fix tint not shifting green in negative direction --- .../web/src/features/controlLayers/konva/filters.ts | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts index 159d2c6da3d..6a8a704e136 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts @@ -61,9 +61,12 @@ export const AdjustmentsSimpleFilter = function (this: KonvaFilterThis, imageDat const tintK = 0.5; const rTempMul = 1 + temperature * tempK; const bTempMul = 1 - temperature * tempK; - const rTintMul = 1 + (tint > 0 ? tint * tintK : -tint * 0); - const gTintMul = 1 - Math.abs(tint) * tintK; - const bTintMul = 1 + (tint > 0 ? tint * tintK : -tint * 0); + // Tint: green <-> magenta. Positive = magenta (R/B up, G down). Negative = green (G up, R/B down). + const t = clamp(tint, -1, 1) * tintK; + const mag = Math.abs(t); + const rTintMul = t >= 0 ? 1 + mag : 1 - mag; + const gTintMul = t >= 0 ? 1 - mag : 1 + mag; + const bTintMul = t >= 0 ? 1 + mag : 1 - mag; // Saturation matrix (HSL-based approximation via luma coefficients) const lumaR = 0.2126; From dba956b6b46d3637429e49d2865f5cfad4f84a44 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Wed, 13 Aug 2025 01:35:38 -0400 Subject: [PATCH 08/11] Finish button on adjustments --- invokeai/frontend/web/public/locales/en.json | 10 ++- .../components/RasterLayer/RasterLayer.tsx | 1 - .../RasterLayerAdjustmentsPanel.tsx | 29 +++++++-- .../RasterLayer/RasterLayerCurvesEditor.tsx | 65 ++++++++++++------- 4 files changed, 76 insertions(+), 29 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 995034f13dd..9ddf6c0f9e0 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2022,6 +2022,8 @@ "addAdjustments": "Add Adjustments", "removeAdjustments": "Remove Adjustments", "adjustments": { + "simple": "Simple", + "curves": "Curves", "heading": "Adjustments", "expand": "Expand adjustments", "collapse": "Collapse adjustments", @@ -2030,7 +2032,13 @@ "saturation": "Saturation", "temperature": "Temperature", "tint": "Tint", - "sharpness": "Sharpness" + "sharpness": "Sharpness", + "finish": "Finish", + "reset": "Reset", + "master": "Master", + "red": "Red", + "green": "Green", + "blue": "Blue" }, "regionIsEmpty": "Selected region is empty", "mergeVisible": "Merge Visible", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx index c0133687816..13dc30dea20 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx @@ -40,7 +40,6 @@ export const RasterLayer = memo(({ id }: Props) => { - {/* Show adjustments UI only when adjustments exist */} { const dispatch = useAppDispatch(); const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const canvasManager = useCanvasManager(); const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)); const { t } = useTranslation(); @@ -154,6 +156,22 @@ export const RasterLayerAdjustmentsPanel = memo(() => { [onToggleEnabled] ); + const onFinish = useCallback(async () => { + // Bake current visual into layer pixels, then clear adjustments + const adapter = canvasManager.getAdapter(entityIdentifier); + if (!adapter || adapter.type !== 'raster_layer_adapter') { + return; + } + const rect = adapter.transformer.getRelativeRect(); + try { + await adapter.renderer.rasterize({ rect, replaceObjects: true }); + // Clear adjustments after baking + dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null })); + } catch { + // no-op; leave state unchanged on failure + } + }, [canvasManager, entityIdentifier, dispatch]); + // Hide the panel entirely until adjustments are added via context menu if (!hasAdjustments) { return null; @@ -178,15 +196,18 @@ export const RasterLayerAdjustmentsPanel = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx index 7b1af3c5d9b..cc0cb082253 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx @@ -6,6 +6,7 @@ import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/stor import { selectEntity } from 'features/controlLayers/store/selectors'; import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; const DEFAULT_POINTS: Array<[number, number]> = [ [0, 0], @@ -48,10 +49,10 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { const width = 256; const height = 160; - // inner margins to keep a small buffer from edges (left/right/bottom) and space for title at top + // inner margins to keep a small buffer from edges (left/right/bottom) const MARGIN_LEFT = 8; const MARGIN_RIGHT = 8; - const MARGIN_TOP = 14; + const MARGIN_TOP = 8; const MARGIN_BOTTOM = 10; const INNER_WIDTH = width - MARGIN_LEFT - MARGIN_RIGHT; const INNER_HEIGHT = height - MARGIN_TOP - MARGIN_BOTTOM; @@ -155,11 +156,6 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { ctx.lineWidth = 1.5; ctx.stroke(); } - - // title - ctx.fillStyle = '#bbb'; - ctx.font = '12px sans-serif'; - ctx.fillText(title, MARGIN_LEFT + 2, Math.max(12, MARGIN_TOP - 2)); }, [ MARGIN_LEFT, MARGIN_TOP, @@ -169,7 +165,6 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { height, histogram, localPoints, - title, valueToCanvasX, valueToCanvasY, width, @@ -307,16 +302,21 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { ); return ( - +
+ + {title} + + +
); }); @@ -324,6 +324,7 @@ export const RasterLayerCurvesEditor = memo(() => { const dispatch = useAppDispatch(); const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer'); + const { t } = useTranslation(); const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)) as | CanvasRasterLayerState | undefined; @@ -410,19 +411,37 @@ export const RasterLayerCurvesEditor = memo(() => { return ( - Curves + {t('controlLayers.adjustments.curves')}
- - - + + +
); From ea2f953d333be75b6e9c2e5c4b183e9d1e2c4c27 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Wed, 13 Aug 2025 02:25:50 -0400 Subject: [PATCH 09/11] remove extra title --- .../RasterLayer/RasterLayerCurvesEditor.tsx | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx index cc0cb082253..930afe05873 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx @@ -410,9 +410,6 @@ export const RasterLayerCurvesEditor = memo(() => { return ( - - {t('controlLayers.adjustments.curves')} -
{ histogram={histMaster} onChange={onChangeMaster} /> - - - + + +
); From 613c9159ac43e4d9bc72d685c407e63f1bfd41b8 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Wed, 13 Aug 2025 03:14:42 -0400 Subject: [PATCH 10/11] remove redundant en.json colors --- invokeai/frontend/web/public/locales/en.json | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9ddf6c0f9e0..3e2663023b4 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2035,10 +2035,7 @@ "sharpness": "Sharpness", "finish": "Finish", "reset": "Reset", - "master": "Master", - "red": "Red", - "green": "Green", - "blue": "Blue" + "master": "Master" }, "regionIsEmpty": "Selected region is empty", "mergeVisible": "Merge Visible", From 23e5dca375223678f6b1acd1260e17b5724b56e7 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Wed, 13 Aug 2025 03:30:55 -0400 Subject: [PATCH 11/11] clean up right click menu --- .../RasterLayer/RasterLayerMenuItems.tsx | 36 ++---------------- .../RasterLayerMenuItemsAdjustments.tsx | 38 +++++++++++++++++++ 2 files changed, 42 insertions(+), 32 deletions(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx index a3b7b0d48dc..708f7f29cd6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx @@ -1,5 +1,4 @@ -import { MenuDivider, MenuItem } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { MenuDivider } from '@invoke-ai/ui-library'; import { IconMenuItemGroup } from 'common/components/IconMenuItem'; import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange'; import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox'; @@ -10,35 +9,12 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave'; import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject'; import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform'; +import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments'; import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu'; import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu'; -import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; -import { rasterLayerAdjustmentsReset, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice'; -import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; +import { memo } from 'react'; export const RasterLayerMenuItems = memo(() => { - const dispatch = useAppDispatch(); - const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); - const { t } = useTranslation(); - const layer = useAppSelector((s) => - s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id) - ); - const hasAdjustments = Boolean(layer?.adjustments); - const onToggleAdjustmentsPresence = useCallback(() => { - if (hasAdjustments) { - dispatch(rasterLayerAdjustmentsReset({ entityIdentifier })); - } else { - dispatch( - rasterLayerAdjustmentsSet({ - entityIdentifier, - adjustments: { enabled: true, collapsed: false, mode: 'simple' }, - }) - ); - } - }, [dispatch, entityIdentifier, hasAdjustments]); - return ( <> @@ -46,14 +22,10 @@ export const RasterLayerMenuItems = memo(() => { - - - {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')} - - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx new file mode 100644 index 00000000000..77a939b7bf9 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx @@ -0,0 +1,38 @@ +import { MenuItem } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsReset, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiSlidersHorizontalBold } from 'react-icons/pi'; + +export const RasterLayerMenuItemsAdjustments = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const { t } = useTranslation(); + const layer = useAppSelector((s) => + s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id) + ); + const hasAdjustments = Boolean(layer?.adjustments); + const onToggleAdjustmentsPresence = useCallback(() => { + if (hasAdjustments) { + dispatch(rasterLayerAdjustmentsReset({ entityIdentifier })); + } else { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { enabled: true, collapsed: false, mode: 'simple' }, + }) + ); + } + }, [dispatch, entityIdentifier, hasAdjustments]); + + return ( + }> + {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')} + + ); +}); + +RasterLayerMenuItemsAdjustments.displayName = 'RasterLayerMenuItemsAdjustments';