№10812 [Quote]
Here's my script, btw
import torch
from PIL import Image, PngImagePlugin
from transformers import AutoProcessor, AutoModelForCausalLM
import tkinter as tk
from tkinterdnd2 import DND_FILES, TkinterDnD
from tkinter import messagebox
import os
import threading
import re
import time
import piexif
from torch.utils.data import Dataset, DataLoader
from functools import partial
from tqdm import tqdm
# Held by the worker thread around each file's rename/metadata step.
file_lock = threading.Lock()

MODEL_NAME = "yayayaaa/florence-2-large-ft-moredetailed"

# Use the first CUDA device when one is present; otherwise fall back to the
# CPU, where float32 is used instead of float16.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

# NOTE(review): trust_remote_code=True allows the checkpoint repository to
# supply Python code that runs locally; only use checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    cache_dir="./model_cache",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
class ImageDataset(Dataset):
    """Dataset yielding one dict per image path, paired with a task prompt.

    Each item is ``{"image", "path", "task"}``. When an image cannot be
    loaded the item carries a blank 224x224 placeholder image plus an
    ``"error"`` key holding the exception text, so one unreadable file does
    not break the batch it belongs to.
    """

    def __init__(self, image_paths, task_prompt):
        """Store the list of image paths and the prompt shared by all items."""
        self.image_paths = image_paths
        self.task_prompt = task_prompt

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position *idx* as RGB and return its item dict."""
        path = self.image_paths[idx]
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted, so the file can be renamed later.
            with Image.open(path) as source:
                img = source.convert("RGB")
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            # Return a placeholder to avoid breaking the batch
            return {
                "image": Image.new('RGB', (224, 224)),
                "path": path,
                "task": self.task_prompt,
                "error": str(e),
            }
        return {"image": img, "path": path, "task": self.task_prompt}
def collate_fn(batch):
    """Regroup a list of dataset items into a dict of parallel lists.

    Args:
        batch: list of dicts with "image", "path" and "task" keys and an
            optional "error" key.

    Returns:
        Dict with "images", "paths", "tasks" and "errors" lists, all in the
        same order as *batch*. "errors" holds ``None`` for items that have no
        "error" key.
    """
    return {
        "images": [item["image"] for item in batch],
        "paths": [item["path"] for item in batch],
        "tasks": [item["task"] for item in batch],
        "errors": [item.get("error") for item in batch],
    }
def process_batch(batch):
    """Generate captions for one collated batch.

    Args:
        batch: dict with parallel "images", "paths", "tasks" and "errors"
            lists. The first entry of "tasks" is used as the prompt for every
            image in the batch.

    Returns:
        List of ``(path, caption, error)`` tuples: successfully captioned
        images first (``error`` is None), followed by the items that already
        carried a load error (``caption`` is None). If generation itself
        raises, every path in the batch is returned with that exception's
        text as its error. An empty batch yields an empty list.
    """
    images = batch["images"]
    paths = batch["paths"]
    errors = batch["errors"]

    failed = [(path, None, err) for path, err in zip(paths, errors) if err is not None]
    valid = [(img, path) for img, path, err in zip(images, paths, errors) if err is None]
    if not valid:
        # Nothing to run through the model (all items failed to load, or the
        # batch is empty), so the prompt is never looked up.
        return failed

    task = batch["tasks"][0]
    valid_images = [img for img, _ in valid]
    try:
        inputs = processor(
            text=[task] * len(valid_images),
            images=valid_images,
            return_tensors="pt",
        ).to(device, torch_dtype)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                do_sample=False,
            )
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    except Exception as e:
        # If batch processing fails, return error for all
        return [(path, None, str(e)) for path in paths]

    captioned = [(path, text, None) for (_, path), text in zip(valid, generated_texts)]
    # Add back the error entries
    return captioned + failed
def clean_filename(text, max_length=200):
    """Clean the caption to be used as a filename with length limit.

    Steps, in order: remove task-prompt placeholders, drop one leading
    article or stock lead-in phrase, delete characters that are invalid in
    Windows filenames, collapse whitespace, strip trailing dots and spaces,
    capitalise the first character, and truncate to *max_length* characters
    (a truncated name ends in "…").

    Args:
        text: caption text; non-string values are converted with ``str()``.
        max_length: maximum length of the returned name, ellipsis included.

    Returns:
        The cleaned name, or "Untitled" when nothing is left after cleaning.
    """
    if not isinstance(text, str):
        text = str(text)

    # Remove placeholders (every prompt offered in the caption-type menu)
    for placeholder in ("<MORE_DETAILED_CAPTION>", "<DETAILED_CAPTION>", "<CAPTION>"):
        text = text.replace(placeholder, "")
    text = text.strip()

    # Remove leading articles
    text = re.sub(
        r'^(a|an|the|in this image we can see a|in this image we can see|in this picture i can see)\b[\s\-_]+',
        '',
        text,
        count=1,
        flags=re.IGNORECASE,
    ).strip()

    invalid_chars = r'[<>:"/\\|?*]'
    text = re.sub(invalid_chars, '', text)
    text = re.sub(r'\s+', ' ', text).strip().rstrip('. ')

    # Capitalise after sanitising so a removed leading character (such as a
    # quote) does not leave the first letter in lower case.
    if text:
        text = text[0].upper() + text[1:]

    # Trim to maximum length; the budget reserves room for the ellipsis.
    ellipsis = "…"
    if len(text) > max_length:
        keep = max(max_length - len(ellipsis), 0)
        text = text[:keep].rstrip() + ellipsis

    return text if text else "Untitled"
def _truncate_name(name, limit):
    """Return *name* cut to at most *limit* characters, ending in "…" if cut."""
    if len(name) <= limit:
        return name
    return name[:max(limit - 1, 0)].rstrip() + "…"


def rename_file(file_path, caption, replace_completely):
    """Rename the file with the caption with path length considerations.

    Args:
        file_path: path of the existing file.
        caption: raw caption; it is passed through ``clean_filename``.
        replace_completely: when true the new base name is the caption alone,
            otherwise it is "<original base name> - <caption>".

    Returns:
        The path the file has after the call. This is *file_path* itself when
        the file already carries the target name.

    Raises:
        OSError: propagated from ``os.rename``.
    """
    dirname = os.path.dirname(file_path)
    basename, ext = os.path.splitext(os.path.basename(file_path))
    clean_caption = clean_filename(caption) or "image_description"

    if replace_completely:
        stem = clean_caption
    else:
        stem = f"{basename.strip()} - {clean_caption}"

    # Calculate maximum allowed basename length (Windows MAX_PATH is 260)
    # Account for dirname length, separators, extension, and potential counter
    max_path_length = 250  # Leave some buffer for "(1)", "(2)", etc.
    # Floor at 1 so a very long directory cannot produce a negative budget,
    # which would turn the slices below into from-the-end slices.
    max_basename_length = max(max_path_length - len(dirname) - len(ext) - 5, 1)

    source_key = os.path.normcase(os.path.abspath(file_path))

    # Handle duplicate filenames: try the plain name, then " (1)", " (2)", ...
    counter = 0
    while True:
        counter_str = f" ({counter})" if counter else ""
        # Make sure there's enough room for the counter
        room = max(max_basename_length - len(counter_str), 1)
        new_path = os.path.join(dirname, f"{_truncate_name(stem, room)}{counter_str}{ext}")
        if os.path.normcase(os.path.abspath(new_path)) == source_key:
            # The file already has this name; it must not count as a
            # collision with itself.
            return file_path
        if not os.path.exists(new_path):
            break
        counter += 1

    os.rename(file_path, new_path)
    return new_path
def _encode_user_comment(caption):
    """Build an EXIF UserComment value: 8-byte character-code id plus text.

    Pure-ASCII captions are tagged "ASCII"; anything else is tagged
    "UNICODE" and stored as UTF-16 big-endian, so the id matches the bytes
    that follow it.
    """
    try:
        return b'ASCII\x00\x00\x00' + caption.encode('ascii')
    except UnicodeEncodeError:
        return b'UNICODE\x00' + caption.encode('utf-16-be')


def write_metadata(file_path, caption):
    """Write caption to image metadata without changing file modification time.

    Only paths ending in .png, .jpg or .jpeg are considered. PNG images get a
    "Comment" text chunk (existing string-valued text entries are kept, a
    previous "Comment" is replaced). JPEG images get the EXIF UserComment tag
    set via ``piexif.insert``, which rewrites the EXIF segment in place
    instead of re-encoding the picture. Files whose decoded format is neither
    PNG nor JPEG are left untouched.

    Any exception is caught and reported on stdout; the function always
    returns None.
    """
    if not file_path.lower().endswith(('.png', '.jpg', '.jpeg')):
        return

    try:
        # Get file's current access and modification times
        file_stats = os.stat(file_path)
        access_time = file_stats.st_atime
        mod_time = file_stats.st_mtime

        with Image.open(file_path) as img:
            image_format = img.format
            existing_exif = img.info.get('exif')
            if image_format == 'PNG':
                pnginfo = PngImagePlugin.PngInfo()
                # Preserve existing metadata, except an earlier "Comment",
                # which would otherwise be duplicated on every run.
                for key, value in img.info.items():
                    if key == "Comment":
                        continue
                    if isinstance(key, str) and isinstance(value, str):
                        pnginfo.add_text(key, value)
                pnginfo.add_text("Comment", caption)
                img.save(file_path, pnginfo=pnginfo)

        if image_format == 'JPEG':
            # Done after the image is closed: only the EXIF block is replaced.
            exif_dict = piexif.load(existing_exif) if existing_exif else {}
            exif_dict.setdefault('Exif', {})
            exif_dict['Exif'][piexif.ExifIFD.UserComment] = _encode_user_comment(caption)
            piexif.insert(piexif.dump(exif_dict), file_path)

        # Restore original file timestamps
        os.utime(file_path, (access_time, mod_time))
    except Exception as e:
        print(f"Metadata write error: {str(e)}")
def process_files(file_paths, task_prompt, rename_mode, metadata_var, message_label, result_label):
    """Caption *file_paths* in batches on a background thread.

    Args:
        file_paths: image paths to caption.
        task_prompt: prompt string handed to the dataset for every image.
        rename_mode: Tk variable holding "append", "replace" or anything
            else (treated as "do not rename").
        metadata_var: Tk variable; truthy means the caption is also written
            into the file's metadata.
        message_label: widget that receives status and error text.
        result_label: widget that receives the latest caption.

    The function returns as soon as the worker thread has been started. The
    two Tk variables are read once, here, so the whole run uses the options
    that were selected when it began.
    """
    batch_size = 14
    dataset = ImageDataset(file_paths, task_prompt)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    mode = rename_mode.get()
    write_meta = bool(metadata_var.get())

    def set_text(label, text):
        # Queue the update on the Tk event loop rather than configuring the
        # widget directly from the worker thread.
        try:
            label.after(0, lambda: label.config(text=text))
        except (RuntimeError, tk.TclError):
            # The window is gone or the event loop has stopped; there is
            # nowhere left to show the message.
            pass

    def apply_caption(path, caption):
        # Handle renaming based on the mode
        new_path = path
        if mode == "append":
            new_path = rename_file(path, caption, False)
            set_text(message_label, f"Renamed to: {os.path.basename(new_path)}")
        elif mode == "replace":
            new_path = rename_file(path, caption, True)
            set_text(message_label, f"Renamed to: {os.path.basename(new_path)}")
        else:  # "none" - no rename
            set_text(message_label, "Caption generated (file not renamed)")
        # Write metadata if enabled
        if write_meta:
            write_metadata(new_path, caption)

    def processing_thread():
        for batch in tqdm(dataloader, desc="Processing batches"):
            for path, caption, error in process_batch(batch):
                with file_lock:
                    if error:
                        set_text(message_label, f"Error processing {os.path.basename(path)}: {error}")
                        continue
                    set_text(result_label, f"Caption: {caption}")
                    try:
                        apply_caption(path, caption)
                    except OSError as e:
                        # A failed rename must not end the thread and strand
                        # the files that are still waiting.
                        set_text(message_label, f"Error processing {os.path.basename(path)}: {e}")

    threading.Thread(target=processing_thread).start()
def drop(event, rename_mode, metadata_var, message_label, result_label):
    """Handle a drop event: extract the dropped file paths and start captioning.

    Args:
        event: drop event; ``event.data`` holds the dropped paths as one
            string, with paths containing spaces enclosed in braces.
        rename_mode, metadata_var, message_label, result_label: passed
            through unchanged to ``process_files``.

    Shows an error dialog and returns when no dropped entry is an existing
    regular file.
    """
    raw_data = event.data.strip()

    try:
        # The payload is formatted as a Tcl list, so let Tcl split it; this
        # also copes with backslash-escaped characters inside paths.
        candidates = list(message_label.tk.splitlist(raw_data))
    except tk.TclError:
        # Malformed list: fall back to the brace-aware pattern.
        # Parse file paths with possible spaces enclosed in braces
        candidates = [
            match[0] or match[1]
            for match in re.findall(r'\{([^}]+)\}|([^ ]+)', raw_data)
        ]

    # Directories and paths that do not exist cannot be captioned.
    file_paths = [path for path in candidates if path and os.path.isfile(path)]
    if not file_paths:
        messagebox.showerror("Error", "No valid files were dropped")
        return

    task_prompt = caption_type_var.get()
    message_label.config(text=f"Processing {len(file_paths)} files…")
    process_files(file_paths, task_prompt, rename_mode, metadata_var, message_label, result_label)
# --- GUI construction -------------------------------------------------------
# `root` and `caption_type_var` stay module-level names because handlers
# defined elsewhere in this file look them up as globals.
root = TkinterDnD.Tk()
root.title("Florence-2 Batch Image Captioning")
root.geometry("500x500")

frame = tk.Frame(root)
frame.pack(padx=20, pady=20, fill="both", expand=True)

# Status line for progress and error messages.
message_label = tk.Label(frame, text="Ready", wraplength=400, justify="left", fg="blue")
message_label.pack(fill="x")

# Drop target for image files.
label = tk.Label(frame, text="Drag & drop images here", width=50, height=5, bg="lightgrey")
label.pack(pady=(0, 10))
label.drop_target_register(DND_FILES)

options_frame = tk.Frame(frame)
options_frame.pack(fill="x", pady=(0, 10))

# Radio buttons for renaming mode
rename_mode = tk.StringVar(value="append")
rename_frame = tk.LabelFrame(options_frame, text="Rename Options")
rename_frame.pack(fill="x", pady=(0, 10))
tk.Radiobutton(rename_frame, text="Append caption to original filename",
               variable=rename_mode, value="append").pack(anchor="w")
tk.Radiobutton(rename_frame, text="Replace entire filename with caption",
               variable=rename_mode, value="replace").pack(anchor="w")
tk.Radiobutton(rename_frame, text="Don't rename files (caption only)",
               variable=rename_mode, value="none").pack(anchor="w")

metadata_var = tk.BooleanVar(value=False)
metadata_check = tk.Checkbutton(options_frame, text="Write to metadata (PNG Comment/JPG EXIF)", variable=metadata_var)
metadata_check.pack(anchor="w")

# Prompt selector; the chosen value is used as the task prompt.
caption_type_var = tk.StringVar(root)
caption_type_var.set("<CAPTION>")
caption_frame = tk.Frame(frame)
caption_frame.pack(fill="x", pady=(0, 10))
tk.Label(caption_frame, text="Caption type:").pack(side="left", padx=(0, 10))
caption_type_menu = tk.OptionMenu(caption_frame, caption_type_var, "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>")
caption_type_menu.pack(side="left")

# Reuse the parent frame's background instead of the Windows-only colour name
# "SystemButtonFace", which Tk rejects on other platforms.
result_label = tk.Label(frame, text="Caption will be here", wraplength=450, justify="left", anchor="w", padx=10, pady=5, bg=frame.cget("bg"))
result_label.pack(fill="x", pady=(0, 10))

label.dnd_bind('<<Drop>>',
               partial(drop,
                       rename_mode=rename_mode,
                       metadata_var=metadata_var,
                       message_label=message_label,
                       result_label=result_label))

root.mainloop()