Files
video_detect/service/model_service.py
2025-09-30 17:17:20 +08:00

131 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from http.client import HTTPException
import numpy as np
import torch
from MySQLdb import MySQLError
from ultralytics import YOLO
import os
from ds.db import db
from service.file_service import get_absolute_path
# 全局变量
current_yolo_model = None
current_model_absolute_path = None # 存储模型绝对路径不依赖model实例
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
def load_yolo_model():
"""加载模型并存储绝对路径"""
global current_yolo_model, current_model_absolute_path
model_rel_path = get_enabled_model_rel_path()
print(f"[模型初始化] 加载模型:{model_rel_path}")
# 计算并存储绝对路径
current_model_absolute_path = get_absolute_path(model_rel_path)
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
# 检查模型文件
if not os.path.exists(current_model_absolute_path):
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
try:
new_model = YOLO(current_model_absolute_path)
if torch.cuda.is_available():
new_model.to('cuda')
print("模型已移动到GPU")
else:
print("使用CPU进行推理")
current_yolo_model = new_model
print(f"成功加载模型: {current_model_absolute_path}")
return current_yolo_model
except Exception as e:
print(f"模型加载失败:{str(e)}")
raise
def get_current_model():
"""获取当前模型实例"""
if current_yolo_model is None:
raise ValueError("尚未加载任何YOLO模型请先调用load_yolo_model加载模型")
return current_yolo_model
def detect(image_np, conf_threshold=0.8):
# 1. 输入格式验证
if not isinstance(image_np, np.ndarray):
raise ValueError("输入必须是numpy数组BGR图像")
if image_np.ndim != 3 or image_np.shape[-1] != 3:
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组当前shape: {image_np.shape}")
detection_results = []
try:
model = get_current_model()
if not current_model_absolute_path:
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
# 图像尺寸信息
img_height, img_width = image_np.shape[:2]
print(f"输入图像尺寸:{img_width}x{img_height}")
# YOLO检测
print("执行YOLO检测")
results = model.predict(
image_np,
conf=conf_threshold,
device=device,
show=False,
)
# 4. 整理检测结果仅保留Chest类别ID=2
for box in results[0].boxes:
class_id = int(box.cls[0]) # 类别ID
class_name = model.names[class_id]
confidence = float(box.conf[0])
bbox = tuple(map(int, box.xyxy[0]))
# 过滤条件:置信度达标 + 类别为Chestclass_id=2
# and class_id == 2
if confidence >= conf_threshold:
detection_results.append({
"class": class_name,
"confidence": confidence,
"bbox": bbox
})
# 判断是否有目标
has_content = len(detection_results) > 0
return has_content, detection_results
except Exception as e:
error_msg = f"检测过程出错:{str(e)}"
print(error_msg)
return False, None
def get_enabled_model_rel_path():
"""获取数据库中启用的模型相对路径"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
cursor.execute(query)
result = cursor.fetchone()
if not result or not result.get('path'):
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
return result['path']
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
except Exception as e:
if isinstance(e, HTTPException):
raise e
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)