Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 108 additions & 5 deletions wd14tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
config = get_extension_config()

defaults = {
"model": "wd-v1-4-moat-tagger-v2",
"threshold": 0.35,
"model": "wd-v1-4-vit-tagger",
"threshold": 0.5,
"character_threshold": 0.85,
"replace_underscore": False,
"replace_underscore": True,
"trailing_comma": False,
"exclude_tags": "",
"exclude_tags": "navel, jewelry, breasts, lips",
"ortProviders": ["CUDAExecutionProvider", "CPUExecutionProvider"],
"HF_ENDPOINT": "https://huggingface.co"
}
Expand All @@ -46,8 +46,111 @@ def get_installed_models():
models = [m for m in models if os.path.exists(os.path.join(models_dir, os.path.splitext(m)[0] + ".csv"))]
return models

async def tag(image, model_name, threshold=defaults["threshold"], character_threshold=defaults["character_threshold"], exclude_tags=defaults["exclude_tags"], replace_underscore=True, trailing_comma=False, client_id=None, node=None):
if model_name.endswith(".onnx"):
model_name = model_name[0:-5]
installed = list(get_installed_models())
if not any(model_name + ".onnx" in s for s in installed):
await download_model(model_name, client_id, node)

name = os.path.join(models_dir, model_name + ".onnx")
model = InferenceSession(name, providers=defaults["ortProviders"])

input = model.get_inputs()[0]
height = input.shape[1]

# Resize and pad
ratio = float(height)/max(image.size)
new_size = tuple([int(x*ratio) for x in image.size])
image = image.resize(new_size, Image.LANCZOS)
square = Image.new("RGB", (height, height), (255, 255, 255))
square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))

image = np.array(square).astype(np.float32)
image = image[:, :, ::-1] # RGB -> BGR
image = np.expand_dims(image, 0)

# Load tags
tags = []
general_index = None
character_index = None
with open(os.path.join(models_dir, model_name + ".csv")) as f:
reader = csv.reader(f)
next(reader)
for row in reader:
if general_index is None and row[2] == "0":
general_index = reader.line_num - 2
elif character_index is None and row[2] == "4":
character_index = reader.line_num - 2
tag_name = row[1].replace("_", " ") if replace_underscore else row[1]
tags.append(tag_name)

label_name = model.get_outputs()[0].name
probs = model.run([label_name], {input.name: image})[0]
result = list(zip(tags, probs[0]))

general = [item for item in result[general_index:character_index] if item[1] > threshold]
character = [item for item in result[character_index:] if item[1] > character_threshold]

all_tags = character + general

# Step 1: Remove excluded tags
remove = [s.strip() for s in exclude_tags.lower().split(",")]
filtered = [(tag, score) for tag, score in all_tags if tag not in remove]

# Step 2: Deduplicate exact tags (keep highest score)
unique_tags = {}
for tag, score in filtered:
if tag not in unique_tags or score > unique_tags[tag]:
unique_tags[tag] = score

deduped = list(unique_tags.items())

# Step 3: Substring-based specificity filter
specific_tags = []
for tag, score in deduped:
is_subsumed = False
for other, _ in deduped:
if tag != other and f" {tag} " in f" {other} " and len(other) > len(tag):
is_subsumed = True
break
if not is_subsumed:
specific_tags.append((tag, score))

# Step 4: Synonym suppression
synonym_groups = [
{"phone", "cellphone", "smartphone"},
{"breasts", "boobs", "tits"},
{"underwear", "panties", "lingerie"},
{"bra", "brassiere"},
]
final_tags = []
seen_synonyms = set()

for group in synonym_groups:
present = [item for item in specific_tags if item[0] in group]
if present:
# Keep the longest tag
best = max(present, key=lambda x: (len(x[0]), x[1]))
final_tags.append(best)
seen_synonyms.update(group)

# Add all other tags not in synonym sets
for tag, score in specific_tags:
if all(tag not in group for group in synonym_groups):
final_tags.append((tag, score))

# Format result
res = ("" if trailing_comma else ", ").join(
tag.replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "")
for tag, _ in final_tags
)

print(res)
return res


async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None):
async def tag_old(image, model_name, threshold=defaults["threshold"], character_threshold=defaults["character_threshold"], exclude_tags=defaults["exclude_tags"], replace_underscore=True, trailing_comma=False, client_id=None, node=None):
if model_name.endswith(".onnx"):
model_name = model_name[0:-5]
installed = list(get_installed_models())
Expand Down