import gradio as gr
from ultralytics import YOLO
import spaces
import torch
import cv2
import numpy as np
import os
import requests
# Per-class box colors for the 11 DocLayNet layout classes (BGR, used by the DLA drawing branch)
ENTITIES_COLORS = {
    "Caption": (191, 100, 21),
    "Footnote": (2, 62, 115),
    "Formula": (140, 80, 58),
    "List-item": (168, 181, 69),
    "Page-footer": (2, 69, 84),
    "Page-header": (83, 115, 106),
    "Picture": (255, 72, 88),
    "Section-header": (0, 204, 192),
    "Table": (116, 127, 127),
    "Text": (0, 153, 221),
    "Title": (196, 51, 2)
}
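# Pixel padding around the label text drawn in the DLA branch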
BOX_PADDING = 2
# Load pre-trained YOLOv8 models
model_paths = {
"YOLOv8x Model": "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt",
"YOLOv8m Model": "yolov8m-doclaynet.pt",
"YOLOv8n Model": "yolov8n-doclaynet.pt",
"YOLOv8s Model": "yolov8s-doclaynet.pt",
"DLA Model": "models/dla-model.pt"
}
# Ensure the model files are in the correct location
for model_name, model_path in model_paths.items():
    if not os.path.exists(model_path):
        # For demonstration, we only download the YOLOv8x model
        if model_name == "YOLOv8x Model":
            model_url = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
            response = requests.get(model_url)
            response.raise_for_status()  # fail loudly if the download did not succeed
            with open(model_path, "wb") as f:
                f.write(response.content)
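# NOTE: the remaining checkpoints (YOLOv8n/s/m and the DLA model) are not downloaded here;
# they are expected to already exist at the paths listed in `model_paths`.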
# Load only the models whose weight files are actually present on disk
models = {name: YOLO(path) for name, path in model_paths.items() if os.path.exists(path)}
# Class names (taken from ENTITIES_COLORS; assumed to match the checkpoints' class indices)
class_names = list(ENTITIES_COLORS.keys())
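# `@spaces.GPU` requests a GPU for each call (up to `duration` seconds) when running on a Hugging Face ZeroGPU Space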
@spaces.GPU(duration=60)
def process_image(image, model_choice):
    try:
        if "YOLOv8" in model_choice:
            # Use the selected YOLOv8 model
            model = models[model_choice]
            results = model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
            result = results[0]
            # Extract the annotated image (plot() returns BGR, so flip to RGB for display) and the labels
            annotated_image = result.plot()[:, :, ::-1]
            detected_areas_labels = "\n".join([
                f"{class_names[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
            ])
            return annotated_image, detected_areas_labels
        elif model_choice == "DLA Model":
            # Use the DLA model and draw the boxes manually with OpenCV
            image_path = "input_image.jpg"  # Temporarily save the uploaded image
            cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
            image = cv2.imread(image_path)
            results = models[model_choice].predict(source=image, conf=0.2, iou=0.8)
            boxes = results[0].boxes
            if len(boxes) == 0:
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "No regions detected"
            detected_areas_labels = []
            for box in boxes:
                detection_class_conf = round(box.conf.item(), 2)
                cls = class_names[int(box.cls)]
                detected_areas_labels.append(f"{cls.upper()}: {detection_class_conf}")
                start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
                end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))
                line_thickness = round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
                image = cv2.rectangle(img=image,
                                      pt1=start_box,
                                      pt2=end_box,
                                      color=ENTITIES_COLORS[cls],
                                      thickness=line_thickness)
                text = cls + " " + str(detection_class_conf)
                font_thickness = max(line_thickness - 1, 1)
                (text_w, text_h), _ = cv2.getTextSize(text=text, fontFace=2, fontScale=line_thickness/3, thickness=font_thickness)
                # Filled background rectangle behind the label text
                image = cv2.rectangle(img=image,
                                      pt1=(start_box[0], start_box[1] - text_h - BOX_PADDING*2),
                                      pt2=(start_box[0] + text_w + BOX_PADDING * 2, start_box[1]),
                                      color=ENTITIES_COLORS[cls],
                                      thickness=-1)
                start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
                image = cv2.putText(img=image, text=text, org=start_text, fontFace=0, color=(255, 255, 255), fontScale=line_thickness/3, thickness=font_thickness)
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "\n".join(detected_areas_labels)
        else:
            return None, "Invalid model choice"
    except Exception as e:
        return None, f"Error processing image: {e}"
# Create the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Layout Segmentation Comparison (ZeroGPU)")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_image = gr.Image(type="pil", label="Annotated Image")
    model_choice = gr.Dropdown(list(model_paths.keys()), label="Select Model", value="YOLOv8x Model", scale=0.5)
    output_text = gr.Textbox(label="Detected Areas and Labels")
    btn = gr.Button("Run Document Segmentation")
    btn.click(fn=process_image, inputs=[input_image, model_choice], outputs=[output_image, output_text])
# Launch the demo with queuing
demo.queue(max_size=1).launch()
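# Rough local-run sketch (assumed environment, not the exact Space setup): install
# gradio, ultralytics, opencv-python, requests, and the `spaces` package, place the
# weight files from `model_paths` next to this script, then run `python app.py`.
# Outside a ZeroGPU Space, `@spaces.GPU` should simply pass the function through.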