Compare commits
30 Commits
bae7785a97
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| de6d1b957a | |||
| 396505d8c2 | |||
| b280cc6f99 | |||
| cac28e990e | |||
| 680e25e3a8 | |||
| 4549e67a68 | |||
| eac8d21c20 | |||
| 3cb83b292e | |||
| ced84e49bc | |||
| a95074feca | |||
| 5959f9994c | |||
| d9192bd964 | |||
| 5ecbac0f9c | |||
| 206652d6bb | |||
| 4be7f7bf14 | |||
| 435b2a0e6c | |||
| ae177ca14a | |||
| d3c4820b73 | |||
| 532a9e75e9 | |||
| 0fe49bf829 | |||
| 2571da3c2d | |||
| 1dd832e18d | |||
| 8ceb92c572 | |||
| 9b3d20511a | |||
| 30bf7c9fcb | |||
| ec6dbfde90 | |||
| 3ed73bd9eb | |||
| 08f8a0e44e | |||
| b5d870a19c | |||
| ea82a33a8f |
2
.idea/Video.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="video" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
@ -13,7 +13,3 @@ charset = utf8mb4
|
||||
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
||||
algorithm = HS256
|
||||
access_token_expire_minutes = 30
|
||||
|
||||
[live]
|
||||
rtmp_url = rtmp://192.168.110.65:1935/live/
|
||||
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
|
||||
|
||||
97
core/all.py
Normal file
@ -0,0 +1,97 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from core.establish import get_image_save_path
|
||||
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||
# 导入保存路径函数(根据实际文件位置调整导入路径)
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from ds.db import db
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
# 模型加载状态标记(避免重复加载)
|
||||
|
||||
|
||||
_model_loaded = False
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载所有检测模型(仅首次调用时执行)"""
|
||||
global _model_loaded
|
||||
if _model_loaded:
|
||||
print("模型已加载,无需重复执行")
|
||||
return
|
||||
|
||||
# 依次加载OCR、人脸、YOLO模型
|
||||
ocrLoadModel()
|
||||
faceLoadModel()
|
||||
yoloLoadModel()
|
||||
|
||||
_model_loaded = True
|
||||
print("所有检测模型加载完成")
|
||||
|
||||
|
||||
def save_db(model_type, client_ip, result):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 连接数据库
|
||||
conn = db.get_connection()
|
||||
# 往表插入数据
|
||||
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
||||
insert_query = """
|
||||
INSERT INTO device_danger (client_ip, type, result)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (client_ip, model_type, result))
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
|
||||
# 修正后的 detect 函数关键部分
|
||||
def detect(client_ip, frame):
|
||||
# 1. YOLO检测
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
if yolo_flag:
|
||||
# model_type 传入 "yolo"(正确)
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ yolo违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=str(display_path))
|
||||
return (True, yolo_result, "yolo")
|
||||
|
||||
# 2. 人脸检测
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
if face_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ face违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="face", client_ip=client_ip, result=str(display_path))
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 3. OCR检测
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
if ocr_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)
|
||||
print(f"✅ ocr违规图片已保存:{display_path}")
|
||||
# 这里改了
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ ocr违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=str(display_path))
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 4. 无违规内容(不保存图片)
|
||||
print(f"❌ 未检测到任何违规内容,不保存图片")
|
||||
return (False, "未检测到任何内容", "none")
|
||||
133
core/establish.py
Normal file
@ -0,0 +1,133 @@
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from service.device_service import get_unique_client_ips
|
||||
|
||||
|
||||
def create_directory_structure():
|
||||
"""创建项目所需的目录结构,为所有客户端IP预创建基础目录"""
|
||||
try:
|
||||
# 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容)
|
||||
resource_dir = Path("resource")
|
||||
resource_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 2. 在resource下创建dect文件夹
|
||||
dect_dir = resource_dir / "dect"
|
||||
dect_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 3. 在dect下创建三个模型文件夹
|
||||
model_dirs = ["ocr", "face", "yolo"]
|
||||
for model in model_dirs:
|
||||
model_dir = dect_dir / model
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 4. 调用外部方法获取所有客户端IP地址
|
||||
try:
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
|
||||
# 确保返回的是列表类型
|
||||
if not isinstance(all_ip_addresses, list):
|
||||
all_ip_addresses = [all_ip_addresses]
|
||||
|
||||
# 过滤有效IP(去除空字符串和空格)
|
||||
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
|
||||
|
||||
if not valid_ips:
|
||||
print("警告: 未获取到有效的客户端IP地址")
|
||||
return
|
||||
|
||||
print(f"获取到的所有客户端IP地址: {valid_ips}")
|
||||
|
||||
# 5. 获取当前日期(年、月)
|
||||
now = datetime.datetime.now()
|
||||
current_year = str(now.year)
|
||||
current_month = str(now.month)
|
||||
|
||||
# 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构
|
||||
for ip in valid_ips:
|
||||
# 处理IP地址中的特殊字符(将.替换为_,避免路径问题)
|
||||
safe_ip = ip.replace(".", "_")
|
||||
|
||||
for model in model_dirs:
|
||||
# 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month}
|
||||
ip_dir = dect_dir / model / safe_ip
|
||||
year_dir = ip_dir / current_year
|
||||
month_dir = year_dir / current_month
|
||||
|
||||
# 递归创建目录(存在则跳过,不覆盖)
|
||||
month_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建基础目录结构时发生错误: {str(e)}")
|
||||
|
||||
|
||||
def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]:
|
||||
"""
|
||||
获取图片保存的「本地完整路径」和「带路由前缀的显示路径」
|
||||
|
||||
参数:
|
||||
model_type: 模型类型,应为"ocr"、"face"或"yolo"
|
||||
client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101)
|
||||
|
||||
返回:
|
||||
元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "")
|
||||
"""
|
||||
try:
|
||||
# 验证模型类型有效性
|
||||
valid_models = ["ocr", "face", "yolo"]
|
||||
if model_type not in valid_models:
|
||||
raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一")
|
||||
|
||||
# 1. 验证客户端IP有效性(检查是否在已知IP列表中)
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
if not isinstance(all_ip_addresses, list):
|
||||
all_ip_addresses = [all_ip_addresses]
|
||||
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
|
||||
|
||||
client_ip_stripped = client_ip.strip()
|
||||
if client_ip_stripped not in valid_ips:
|
||||
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件")
|
||||
|
||||
# 2. 处理IP地址(将.替换为_,避免路径问题)
|
||||
safe_ip = client_ip_stripped.replace(".", "_")
|
||||
|
||||
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
|
||||
now = datetime.datetime.now()
|
||||
current_year = str(now.year)
|
||||
current_month = str(now.month).zfill(2) # 确保月份为两位数
|
||||
current_day = str(now.day).zfill(2) # 确保日期为两位数
|
||||
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 取毫秒级时间戳
|
||||
|
||||
# 4. 定义基础目录(用于生成相对路径)
|
||||
base_dir = Path("resource") / "dect"
|
||||
# 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日})
|
||||
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
|
||||
day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 5. 构建唯一文件名
|
||||
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
|
||||
|
||||
# 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠)
|
||||
local_full_path = day_dir / image_filename
|
||||
# 转换为字符串并统一使用正斜杠
|
||||
local_full_path_str = str(local_full_path).replace("\\", "/")
|
||||
|
||||
# 7. 生成带路由前缀的显示路径(核心修改部分)
|
||||
# 获取项目根目录(base_dir是resource/dect,向上两级即为项目根目录)
|
||||
project_root = base_dir.parents[1]
|
||||
# 计算相对于项目根目录的路径(包含resource/dect层级)
|
||||
relative_path = local_full_path.relative_to(project_root)
|
||||
# 转换为字符串并统一使用正斜杠
|
||||
relative_path_str = str(relative_path).replace("\\", "/")
|
||||
# 拼接路由前缀
|
||||
routed_display_path = f"/api/file/{relative_path_str}"
|
||||
|
||||
return local_full_path_str, routed_display_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取图片保存路径时发生错误: {str(e)}")
|
||||
return "", ""
|
||||
330
core/face.py
Normal file
@ -0,0 +1,330 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import gc
|
||||
import time
|
||||
import threading
|
||||
from insightface.app import FaceAnalysis
|
||||
from service.face_service import get_all_face_name_with_eigenvalue
|
||||
|
||||
# GPU状态检查支持
|
||||
try:
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
_nvml_available = True
|
||||
except ImportError:
|
||||
print("警告: pynvml库未安装,无法检测GPU状态,默认尝试使用GPU")
|
||||
_nvml_available = False
|
||||
|
||||
# 全局人脸引擎与特征库
|
||||
_face_app = None
|
||||
_known_faces_embeddings = {} # 姓名 -> 归一化特征值的映射
|
||||
_known_faces_names = [] # 已知人脸姓名列表
|
||||
|
||||
# GPU使用状态标记
|
||||
_using_gpu = False # 是否使用GPU
|
||||
_used_gpu_id = -1 # 使用的GPU ID(-1表示CPU)
|
||||
|
||||
# 资源管理变量
|
||||
_ref_count = 0 # 引擎引用计数(记录当前使用次数)
|
||||
# 修复点1:初始值设为当前时间,避免未加载引擎时用0计算超时
|
||||
_last_used_time = time.time()
|
||||
_lock = threading.Lock() # 线程安全锁
|
||||
_release_timeout = 8 # 闲置超时时间(秒)
|
||||
_is_releasing = False # 资源释放中标记
|
||||
_monitor_thread_running = False # 监控线程运行标记
|
||||
|
||||
# 调试计数器
|
||||
_debug_counter = {
|
||||
"engine_created": 0, # 引擎创建次数
|
||||
"engine_released": 0, # 引擎释放次数
|
||||
"detection_calls": 0 # 检测函数调用次数
|
||||
}
|
||||
|
||||
|
||||
def check_gpu_availability(gpu_id, memory_threshold=0.7):
|
||||
"""检查指定GPU的内存使用率是否低于阈值(判定为“可用”)"""
|
||||
if not _nvml_available:
|
||||
return True
|
||||
try:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
memory_usage = mem_info.used / mem_info.total
|
||||
return memory_usage < memory_threshold
|
||||
except Exception as e:
|
||||
print(f"检查GPU {gpu_id} 状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def select_best_gpu(preferred_gpus=[0, 1]):
|
||||
"""按优先级选择可用GPU,优先0号;均不可用则返回-1(CPU)"""
|
||||
for gpu_id in preferred_gpus:
|
||||
try:
|
||||
# 验证GPU是否存在
|
||||
if _nvml_available:
|
||||
pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
||||
# 验证GPU内存是否充足
|
||||
if check_gpu_availability(gpu_id):
|
||||
print(f"GPU {gpu_id} 可用,将使用该GPU")
|
||||
return gpu_id
|
||||
else:
|
||||
if gpu_id == 0:
|
||||
print("GPU 0 内存使用率过高,尝试其他GPU")
|
||||
except Exception as e:
|
||||
print(f"GPU {gpu_id} 不可用或访问失败: {e}")
|
||||
print("所有指定GPU均不可用,将使用CPU计算")
|
||||
return -1
|
||||
|
||||
|
||||
def _release_engine_resources():
|
||||
"""释放人脸引擎的所有资源(模型、特征库、GPU缓存等)"""
|
||||
global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names, _last_used_time
|
||||
if not _face_app or _is_releasing:
|
||||
return
|
||||
|
||||
try:
|
||||
_is_releasing = True
|
||||
print("开始释放人脸引擎资源...")
|
||||
|
||||
# 释放InsightFace模型资源
|
||||
if hasattr(_face_app, "model"):
|
||||
_face_app.model = None # 显式置空模型引用
|
||||
_face_app = None # 释放引擎实例
|
||||
|
||||
# 清空人脸特征库
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
_debug_counter["engine_released"] += 1
|
||||
print(f"人脸引擎已释放,调试统计: {_debug_counter}")
|
||||
|
||||
# 强制垃圾回收
|
||||
gc.collect()
|
||||
|
||||
# 清理各深度学习框架的GPU缓存
|
||||
# Torch 缓存清理
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
print("Torch GPU缓存已清理")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# TensorFlow 缓存清理
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf.keras.backend.clear_session()
|
||||
print("TensorFlow会话已清理")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# MXNet 缓存清理(InsightFace底层常用MXNet)
|
||||
try:
|
||||
import mxnet as mx
|
||||
mx.nd.waitall() # 等待所有计算完成并释放资源
|
||||
print("MXNet资源已等待释放")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print(f"释放资源过程中出错: {e}")
|
||||
finally:
|
||||
_is_releasing = False
|
||||
# 修复点2:释放完成后重置最后使用时间,避免下次加载时复用旧值
|
||||
_last_used_time = time.time()
|
||||
|
||||
|
||||
def _resource_monitor_thread():
|
||||
"""后台监控线程:检测引擎闲置超时,触发资源释放"""
|
||||
global _ref_count, _last_used_time, _face_app, _monitor_thread_running
|
||||
_monitor_thread_running = True
|
||||
while _monitor_thread_running:
|
||||
time.sleep(2) # 缩短检查间隔,加快闲置检测响应
|
||||
with _lock:
|
||||
# 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间
|
||||
if _face_app and _ref_count == 0 and not _is_releasing:
|
||||
idle_time = time.time() - _last_used_time
|
||||
if idle_time > _release_timeout:
|
||||
print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s),释放资源")
|
||||
_release_engine_resources()
|
||||
|
||||
|
||||
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
||||
"""加载人脸识别引擎及已知人脸特征库(默认优先用0号GPU)"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id, _last_used_time
|
||||
|
||||
# 启动后台监控线程(确保仅启动一次)
|
||||
if not _monitor_thread_running:
|
||||
threading.Thread(
|
||||
target=_resource_monitor_thread,
|
||||
daemon=True,
|
||||
name="FaceEngineMonitor"
|
||||
).start()
|
||||
print("人脸引擎监控线程已启动")
|
||||
|
||||
# 若正在释放资源,等待释放完成
|
||||
while _is_releasing:
|
||||
time.sleep(0.1)
|
||||
|
||||
# 若引擎已初始化,直接返回
|
||||
if _face_app:
|
||||
return True
|
||||
|
||||
# 初始化InsightFace引擎
|
||||
try:
|
||||
print("正在初始化InsightFace人脸识别引擎...")
|
||||
_face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface"))
|
||||
|
||||
# 选择GPU(优先用0号)
|
||||
ctx_id = 0
|
||||
if prefer_gpu:
|
||||
ctx_id = select_best_gpu(preferred_gpus)
|
||||
_using_gpu = ctx_id != -1
|
||||
_used_gpu_id = ctx_id if _using_gpu else -1
|
||||
|
||||
if _using_gpu:
|
||||
print(f"引擎初始化成功,将使用GPU {ctx_id} 计算")
|
||||
else:
|
||||
print("引擎初始化成功,将使用CPU计算")
|
||||
|
||||
# 准备模型(加载到指定设备)
|
||||
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
|
||||
# 修复点3:引擎初始化成功后,立即更新“最后使用时间”(核心修复)
|
||||
_last_used_time = time.time()
|
||||
|
||||
_debug_counter["engine_created"] += 1
|
||||
print(f"引擎调试统计: {_debug_counter}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"引擎初始化失败: {e}")
|
||||
return False
|
||||
|
||||
# 从服务加载已知人脸的姓名和特征值
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 兼容“numpy数组”和“字符串”格式的特征值
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 清理字符串中的括号、换行等干扰符
|
||||
cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip()
|
||||
# 分割并转换为浮点数数组
|
||||
values = [v for v in cleaned.split() if v] # 兼容空格/逗号分隔
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(保证后续相似度计算的一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm != 0:
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载人脸特征库失败: {e}")
|
||||
|
||||
return _face_app is not None
|
||||
|
||||
|
||||
def detect(frame, similarity_threshold=0.4):
|
||||
"""
|
||||
检测并识别人脸
|
||||
返回:(是否匹配到已知人脸, 结果描述字符串)
|
||||
"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time
|
||||
|
||||
# 校验输入帧有效性
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效的输入帧数据")
|
||||
|
||||
# 加锁并更新引用计数、最后使用时间
|
||||
engine = None
|
||||
with _lock:
|
||||
_ref_count += 1
|
||||
_last_used_time = time.time()
|
||||
_debug_counter["detection_calls"] += 1
|
||||
|
||||
# 若引擎未初始化且未在释放中,尝试初始化
|
||||
if not _face_app and not _is_releasing:
|
||||
if not load_model(prefer_gpu=True):
|
||||
# 初始化失败,恢复引用计数
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
return (False, "人脸引擎初始化失败")
|
||||
|
||||
engine = _face_app # 获取引擎引用
|
||||
|
||||
# 校验引擎可用性
|
||||
if not engine or len(_known_faces_names) == 0:
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
return (False, "人脸引擎不可用或特征库为空")
|
||||
|
||||
try:
|
||||
# GPU计算时,确保帧数据是连续内存(避免CUDA错误)
|
||||
if _using_gpu and engine is not None and not frame.flags.contiguous:
|
||||
frame = np.ascontiguousarray(frame)
|
||||
|
||||
# 执行人脸检测与特征提取
|
||||
faces = engine.get(frame)
|
||||
except Exception as e:
|
||||
print(f"人脸检测过程出错: {e}")
|
||||
# 出错时尝试重新初始化引擎(可能是GPU状态变化导致)
|
||||
print("尝试重新初始化人脸引擎...")
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
load_model(prefer_gpu=True)
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched_known_face = False # 是否有任意人脸匹配到已知库
|
||||
|
||||
for face in faces:
|
||||
# 归一化当前检测到的人脸特征
|
||||
face_embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(face_embedding)
|
||||
if norm == 0:
|
||||
continue
|
||||
face_embedding = face_embedding / norm
|
||||
|
||||
# 与已知人脸特征逐一比对
|
||||
max_similarity, best_match_name = -1.0, "Unknown"
|
||||
for name in _known_faces_names:
|
||||
known_emb = _known_faces_embeddings[name]
|
||||
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_match_name = name
|
||||
|
||||
# 判断是否匹配成功
|
||||
is_matched = max_similarity >= similarity_threshold
|
||||
if is_matched:
|
||||
has_matched_known_face = True
|
||||
|
||||
# 记录该人脸的检测结果
|
||||
bbox = face.bbox # 人脸边界框
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
|
||||
f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})"
|
||||
)
|
||||
|
||||
# 构建最终结果字符串
|
||||
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
|
||||
|
||||
# 释放引用计数(线程安全)
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
# 若仍有引用,更新最后使用时间;若引用为0,也立即标记(加快闲置检测)
|
||||
_last_used_time = time.time()
|
||||
|
||||
return (has_matched_known_face, result_str)
|
||||
BIN
core/models/best.pt
Normal file
253
core/ocr.py
Normal file
@ -0,0 +1,253 @@
|
||||
import os
|
||||
import cv2
|
||||
import gc
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from paddleocr import PaddleOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
# 解决NumPy 1.20+版本中np.int已移除的兼容性问题
|
||||
try:
|
||||
if not hasattr(np, 'int'):
|
||||
np.int = int
|
||||
except Exception as e:
|
||||
print(f"处理NumPy兼容性时出错: {e}")
|
||||
|
||||
# 全局变量
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
# 资源管理变量
|
||||
_ref_count = 0
|
||||
_last_used_time = 0
|
||||
_lock = threading.Lock()
|
||||
_release_timeout = 5 # 30秒无使用则释放
|
||||
_is_releasing = False # 标记是否正在释放
|
||||
|
||||
# 并行处理配置
|
||||
_max_workers = 4 # 并行处理的线程数
|
||||
|
||||
# 调试用计数器
|
||||
_debug_counter = {
|
||||
"created": 0,
|
||||
"released": 0,
|
||||
"detected": 0
|
||||
}
|
||||
|
||||
|
||||
def _release_engine():
|
||||
"""释放OCR引擎资源"""
|
||||
global _ocr_engine, _is_releasing
|
||||
if not _ocr_engine or _is_releasing:
|
||||
return
|
||||
|
||||
try:
|
||||
_is_releasing = True
|
||||
_ocr_engine = None
|
||||
_debug_counter["released"] += 1
|
||||
print(f"OCR engine released. Stats: {_debug_counter}")
|
||||
|
||||
# 清理GPU缓存
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import paddle
|
||||
if paddle.is_compiled_with_cuda():
|
||||
paddle.device.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
finally:
|
||||
_is_releasing = False
|
||||
|
||||
|
||||
def _monitor_thread():
|
||||
"""监控线程,优化检查逻辑"""
|
||||
global _ref_count, _last_used_time, _ocr_engine
|
||||
while True:
|
||||
time.sleep(5) # 每5秒检查一次
|
||||
with _lock:
|
||||
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
||||
elapsed = time.time() - _last_used_time
|
||||
if elapsed > _release_timeout:
|
||||
print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
|
||||
_release_engine()
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载违禁词列表和初始化监控线程"""
|
||||
global _forbidden_words
|
||||
|
||||
# 确保监控线程只启动一次
|
||||
if not any(t.name == "OCRMonitor" for t in threading.enumerate()):
|
||||
threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start()
|
||||
print("OCR monitor thread started")
|
||||
|
||||
# 加载违禁词
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
print(f"Loaded {len(_forbidden_words)} forbidden words")
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def detect(frame):
|
||||
"""OCR检测,支持并行处理"""
|
||||
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers
|
||||
|
||||
# 验证前置条件
|
||||
if not _forbidden_words:
|
||||
return (False, "违禁词未初始化")
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效帧数据")
|
||||
|
||||
# 增加引用计数并获取引擎实例
|
||||
engine = None
|
||||
with _lock:
|
||||
_ref_count += 1
|
||||
_last_used_time = time.time()
|
||||
_debug_counter["detected"] += 1
|
||||
|
||||
# 初始化引擎(如果未初始化且不在释放中)
|
||||
if not _ocr_engine and not _is_releasing:
|
||||
try:
|
||||
# 初始化PaddleOCR,设置并行处理参数
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=True,
|
||||
max_text_length=1024,
|
||||
threads=_max_workers
|
||||
)
|
||||
_debug_counter["created"] += 1
|
||||
print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
|
||||
except Exception as e:
|
||||
print(f"OCR model load failed: {e}")
|
||||
_ref_count -= 1
|
||||
return (False, f"引擎初始化失败: {str(e)}")
|
||||
|
||||
engine = _ocr_engine
|
||||
|
||||
# 检查引擎是否可用
|
||||
if not engine:
|
||||
with _lock:
|
||||
_ref_count -= 1
|
||||
return (False, "OCR引擎不可用")
|
||||
|
||||
try:
|
||||
# 执行OCR检测
|
||||
ocr_res = engine.ocr(frame, cls=True)
|
||||
|
||||
# 验证OCR结果格式
|
||||
if not ocr_res or not isinstance(ocr_res, list):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
# 处理OCR结果 - 兼容多种格式
|
||||
texts = []
|
||||
confs = []
|
||||
for line in ocr_res:
|
||||
if line is None:
|
||||
continue
|
||||
|
||||
# 处理line可能是列表或直接是文本信息的情况
|
||||
if isinstance(line, list):
|
||||
items_to_process = line
|
||||
else:
|
||||
items_to_process = [line]
|
||||
|
||||
for item in items_to_process:
|
||||
# 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
if isinstance(item, list) and len(item) == 4: # 四边形有4个顶点
|
||||
is_coordinate = True
|
||||
for point in item:
|
||||
# 每个顶点应该是包含2个数字的列表
|
||||
if not (isinstance(point, list) and len(point) == 2 and
|
||||
all(isinstance(coord, (int, float)) for coord in point)):
|
||||
is_coordinate = False
|
||||
break
|
||||
if is_coordinate:
|
||||
continue # 是坐标信息,直接忽略
|
||||
|
||||
# 跳过纯数字列表(其他可能的坐标形式)
|
||||
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
||||
continue
|
||||
|
||||
# 处理元组形式的文本和置信度 (text, confidence)
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
text, conf = item
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
|
||||
# 处理列表形式的[坐标信息, (text, confidence)]
|
||||
if isinstance(item, list) and len(item) >= 2:
|
||||
# 尝试从列表中提取文本和置信度
|
||||
text_data = item[1]
|
||||
if isinstance(text_data, tuple) and len(text_data) == 2:
|
||||
text, conf = text_data
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
elif isinstance(text_data, str):
|
||||
# 只有文本没有置信度的情况
|
||||
texts.append(text_data.strip())
|
||||
confs.append(1.0) # 默认最高置信度
|
||||
continue
|
||||
|
||||
# 无法识别的格式,记录日志
|
||||
print(f"无法解析的OCR结果格式: {item}")
|
||||
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 筛选违禁词
|
||||
vio_info = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold:
|
||||
continue
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
if matched:
|
||||
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
|
||||
|
||||
# 构建结果
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_info) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
return (True, "; ".join(vio_info))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
||||
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
finally:
|
||||
# 减少引用计数,确保线程安全
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
if _ref_count > 0:
|
||||
_last_used_time = time.time()
|
||||
|
||||
|
||||
def batch_detect(frames):
|
||||
"""批量检测接口,充分利用并行能力"""
|
||||
results = []
|
||||
for frame in frames:
|
||||
results.append(detect(frame))
|
||||
return results
|
||||
137
core/rtc.py
@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
|
||||
|
||||
async def whep_pull_video_stream(ip,whep_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
Args:
|
||||
whep_url: WHEP端点的URL
|
||||
"""
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# 添加连接状态变化监听
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print(f"连接状态: {pc.connectionState}")
|
||||
|
||||
# 添加ICE连接状态变化监听
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
||||
|
||||
# 添加视频接收器
|
||||
pc.addTransceiver("video", direction="recvonly")
|
||||
|
||||
# 处理接收到的视频轨道
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print(f"接收到轨道: {track.kind}")
|
||||
if track.kind == "video":
|
||||
print(f"轨道ID: {track.id}")
|
||||
print(f"轨道就绪状态: {track.readyState}")
|
||||
# 创建异步任务来处理视频帧
|
||||
asyncio.ensure_future(handle_video_track(track))
|
||||
|
||||
async def handle_video_track(track):
|
||||
"""处理视频轨道,接收并打印每一帧"""
|
||||
frame_count = 0
|
||||
print("开始处理视频轨道...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 尝试接收帧
|
||||
frame = await track.recv()
|
||||
frame_count += 1
|
||||
print(f"收到原始帧 (第{frame_count}帧)")
|
||||
|
||||
# 打印帧的基本信息
|
||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
||||
if hasattr(frame, 'time_base'):
|
||||
print(f" 时间基准: {frame.time_base}")
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 创建offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
print(f"本地SDP信息:\n{offer.sdp}")
|
||||
|
||||
# 通过HTTP POST发送offer到WHEP端点
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
whep_url,
|
||||
data=offer.sdp,
|
||||
headers={"Content-Type": "application/sdp"}
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
print(f"WHEP服务器返回错误: {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
||||
|
||||
# 获取answer SDP
|
||||
answer_sdp = await response.text()
|
||||
|
||||
# 创建RTCSessionDescription对象
|
||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
||||
|
||||
print(f"收到远程SDP:\n{answer_sdp}")
|
||||
|
||||
# 设置远程描述
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
print("连接已建立,开始接收视频流...")
|
||||
|
||||
# 保持连接,直到用户中断
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 检查连接状态
|
||||
print(f"当前连接状态: {pc.connectionState}")
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断,关闭连接...")
|
||||
finally:
|
||||
await pc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的WHEP端点URL
|
||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行拉流任务
|
||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
||||
112
core/rtmp.py
@ -1,112 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 异步打开RTMP流
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
||||
)
|
||||
|
||||
# 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 获取RTMP流基础信息
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况
|
||||
fps = fps if fps > 0 else 30.0
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 初始化帧统计参数
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# 循环读取视频帧
|
||||
while True:
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 打印当前帧信息
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
if frame is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
else:
|
||||
print("未检测到任何违规内容")
|
||||
else:
|
||||
print(f"无法读取测试图像")
|
||||
|
||||
# 每100帧统计一次实际接收帧率
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
||||
52
core/yolo.py
Normal file
@ -0,0 +1,52 @@
|
||||
from ultralytics import YOLO
|
||||
from service.model_service import get_current_yolo_model, get_current_conf_threshold # 新增置信度获取函数
|
||||
|
||||
|
||||
def load_model(model_path=None):
|
||||
"""加载YOLO模型(优先使用带版本校验的默认模型)"""
|
||||
if model_path is None:
|
||||
# 调用带版本校验的模型获取函数(自动判断是否需要重新加载)
|
||||
return get_current_yolo_model()
|
||||
try:
|
||||
# 加载指定路径模型(用于特殊场景)
|
||||
return YOLO(model_path)
|
||||
except Exception as e:
|
||||
print(f"YOLO模型加载失败(指定路径):{str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def detect(frame):
|
||||
"""执行目标检测(使用动态置信度,仅模型版本变化时重新加载)"""
|
||||
# 获取模型(内部已做版本校验,未变化则直接返回缓存)
|
||||
current_model = load_model()
|
||||
if not current_model:
|
||||
return (False, "未加载到最新YOLO模型")
|
||||
|
||||
if frame is None:
|
||||
return (False, "无效输入帧")
|
||||
|
||||
try:
|
||||
# 获取动态置信度(从全局配置中读取)
|
||||
conf_threshold = get_current_conf_threshold()
|
||||
# 用当前模型执行检测(复用缓存,无额外加载耗时)
|
||||
results = current_model(frame, conf=conf_threshold, verbose=False)
|
||||
has_results = len(results[0].boxes) > 0 if results else False
|
||||
|
||||
if not has_results:
|
||||
return (False, "未检测到目标")
|
||||
|
||||
# 构建结果字符串
|
||||
result_parts = []
|
||||
for box in results[0].boxes:
|
||||
cls = int(box.cls[0])
|
||||
conf = float(box.conf[0])
|
||||
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
|
||||
# 从当前模型中获取类别名(确保与模型匹配)
|
||||
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
|
||||
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
||||
|
||||
return (True, "; ".join(result_parts))
|
||||
|
||||
except Exception as e:
|
||||
print(f"YOLO检测过程出错:{str(e)}")
|
||||
return (False, f"检测错误:{str(e)}")
|
||||
@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
|
||||
SERVER_CONFIG = config["server"]
|
||||
MYSQL_CONFIG = config["mysql"]
|
||||
JWT_CONFIG = config["jwt"]
|
||||
LIVE_CONFIG = config["live"]
|
||||
|
||||
14
ds/db.py
@ -3,6 +3,8 @@ from mysql.connector import Error
|
||||
|
||||
from .config import MYSQL_CONFIG
|
||||
|
||||
# 关键:声明类级别的连接池实例(必须有这一行!)
|
||||
_connection_pool = None # 确保这一行存在,且拼写正确
|
||||
|
||||
class Database:
|
||||
"""MySQL 连接池管理类"""
|
||||
@ -41,6 +43,18 @@ class Database:
|
||||
except Error as e:
|
||||
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
|
||||
|
||||
@classmethod
|
||||
def close_all_connections(cls):
|
||||
"""清理连接池(服务重启前调用)"""
|
||||
try:
|
||||
# 先检查属性是否存在,再判断是否有值
|
||||
if hasattr(cls, "_connection_pool") and cls._connection_pool:
|
||||
cls._connection_pool = None # 重置连接池
|
||||
print("[Database] 连接池已重置,旧连接将被自动清理")
|
||||
else:
|
||||
print("[Database] 连接池未初始化或已重置,无需操作")
|
||||
except Exception as e:
|
||||
print(f"[Database] 重置连接池失败: {str(e)}")
|
||||
|
||||
# 暴露数据库操作工具
|
||||
db = Database()
|
||||
|
||||
43
encryption/encrypt_decorator.py
Normal file
@ -0,0 +1,43 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from encryption.encryption import aes_encrypt
|
||||
from schema.response_schema import APIResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def encrypt_response(field: str = "data"):
|
||||
"""接口返回值加密装饰器:正确序列化自定义对象为JSON"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
original_response: APIResponse = await func(*args, **kwargs)
|
||||
field_value = getattr(original_response, field)
|
||||
|
||||
if not field_value:
|
||||
return original_response
|
||||
|
||||
# 自定义JSON序列化函数:处理Pydantic模型和datetime
|
||||
def json_default(obj: Any) -> Any:
|
||||
# 处理Pydantic模型(转换为字典)
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_dump() # Pydantic v2用model_dump(),v1用dict()
|
||||
# 处理datetime(转换为ISO格式字符串)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
# 其他无法序列化的类型,可根据需要扩展
|
||||
return str(obj) # 作为最后兜底
|
||||
|
||||
# 使用自定义序列化函数,确保生成标准JSON
|
||||
field_value_json = json.dumps(field_value, default=json_default)
|
||||
encrypted_data = aes_encrypt(field_value_json)
|
||||
setattr(original_response, field, encrypted_data)
|
||||
|
||||
return original_response
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
56
encryption/encryption.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from fastapi import HTTPException
|
||||
|
||||
# 硬编码AES密钥(32字节,AES-256)
|
||||
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"
|
||||
AES_BLOCK_SIZE = 16 # AES固定块大小
|
||||
|
||||
# 校验密钥长度(确保符合AES规范)
|
||||
valid_key_lengths = [16, 24, 32]
|
||||
if len(AES_SECRET_KEY) not in valid_key_lengths:
|
||||
raise ValueError(
|
||||
f"AES密钥长度必须为{valid_key_lengths}字节,当前为{len(AES_SECRET_KEY)}字节"
|
||||
)
|
||||
|
||||
|
||||
def aes_encrypt(plaintext: str) -> dict:
|
||||
"""AES-CBC模式加密(返回密文+IV,均为Base64编码)"""
|
||||
try:
|
||||
# 生成随机IV(16字节)
|
||||
iv = os.urandom(AES_BLOCK_SIZE)
|
||||
|
||||
# 创建加密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
|
||||
# 明文填充并加密
|
||||
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
|
||||
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
|
||||
iv_base64 = base64.b64encode(iv).decode("utf-8")
|
||||
|
||||
return {
|
||||
"ciphertext": ciphertext,
|
||||
"iv": iv_base64,
|
||||
"algorithm": "AES-CBC"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES加密失败:{str(e)}") from e
|
||||
|
||||
|
||||
def aes_decrypt(ciphertext: str, iv: str) -> str:
|
||||
"""AES-CBC模式解密"""
|
||||
try:
|
||||
# 解码Base64
|
||||
ciphertext_bytes = base64.b64decode(ciphertext)
|
||||
iv_bytes = base64.b64decode(iv)
|
||||
|
||||
# 创建解密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv_bytes)
|
||||
|
||||
# 解密并去填充
|
||||
decrypted_bytes = unpad(cipher.decrypt(ciphertext_bytes), AES_BLOCK_SIZE)
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES解密失败:{str(e)}") from e
|
||||
76
main.py
@ -1,43 +1,93 @@
|
||||
import uvicorn
|
||||
import time
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# 原有业务导入
|
||||
from core.all import load_model
|
||||
from ds.config import SERVER_CONFIG
|
||||
from middle.error_handler import global_exception_handler
|
||||
from service.user_service import router as user_router
|
||||
from service.sensitive_service import router as sensitive_router
|
||||
from service.face_service import router as face_router
|
||||
from service.device_service import router as device_router
|
||||
from service.model_service import router as model_router
|
||||
from service.file_service import router as file_router
|
||||
from service.device_danger_service import router as device_danger_router
|
||||
from ws.ws import ws_router, lifespan
|
||||
from core.establish import create_directory_structure
|
||||
|
||||
# ------------------------------
|
||||
# 初始化 FastAPI 应用、指定生命周期管理
|
||||
# ------------------------------
|
||||
# 初始化 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="内容安全审核后台",
|
||||
description="内容安全审核后台",
|
||||
description="含图片访问服务和动态模型管理",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
ALLOWED_ORIGINS = [
|
||||
"*"
|
||||
]
|
||||
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有请求头
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
# ------------------------------
|
||||
app.include_router(user_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(face_router)
|
||||
app.include_router(sensitive_router)
|
||||
app.include_router(model_router)
|
||||
app.include_router(file_router)
|
||||
app.include_router(device_danger_router)
|
||||
app.include_router(ws_router)
|
||||
|
||||
# ------------------------------
|
||||
# 注册全局异常处理器
|
||||
# ------------------------------
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
|
||||
# ------------------------------
|
||||
# 启动服务
|
||||
# ------------------------------
|
||||
# 主服务启动入口
|
||||
if __name__ == "__main__":
|
||||
create_directory_structure()
|
||||
print(f"[初始化] 目录结构创建完成")
|
||||
|
||||
# 创建模型保存目录
|
||||
MODEL_SAVE_DIR = os.path.join("core", "models")
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
|
||||
|
||||
# 确保图片目录存在(原Flask服务负责的目录)
|
||||
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
|
||||
if not os.path.exists(BASE_IMAGE_DIR):
|
||||
print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
|
||||
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
|
||||
|
||||
# 加载检测模型
|
||||
try:
|
||||
load_success = load_model()
|
||||
if load_success:
|
||||
print(f"[初始化] 检测模型加载完成")
|
||||
else:
|
||||
print(f"[初始化] 未找到默认模型,可通过API上传并设置默认模型")
|
||||
except Exception as e:
|
||||
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
|
||||
|
||||
# 启动 FastAPI 主服务(仅使用8000端口)
|
||||
port = int(SERVER_CONFIG.get("port", 8000))
|
||||
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
|
||||
print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n")
|
||||
|
||||
uvicorn.run(
|
||||
app="main:app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
ws="websockets"
|
||||
workers=1,
|
||||
ws="websockets",
|
||||
reload=False
|
||||
)
|
||||
|
||||
@ -8,7 +8,6 @@ from passlib.context import CryptContext
|
||||
|
||||
from ds.config import JWT_CONFIG
|
||||
from ds.db import db
|
||||
from service.user_service import UserResponse
|
||||
|
||||
# ------------------------------
|
||||
# 密码加密配置
|
||||
@ -22,9 +21,10 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
|
||||
ALGORITHM = JWT_CONFIG["algorithm"]
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||
|
||||
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
||||
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 密码工具函数
|
||||
# ------------------------------
|
||||
@ -32,10 +32,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证明文密码与加密密码是否匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""对明文密码进行 bcrypt 加密"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# JWT 工具函数
|
||||
# ------------------------------
|
||||
@ -53,11 +55,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 认证依赖(获取当前登录用户)
|
||||
# ------------------------------
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||
# 延迟导入、打破循环依赖
|
||||
from schema.user_schema import UserResponse # 在这里导入
|
||||
|
||||
# 认证失败异常
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
||||
@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
|
||||
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
|
||||
# 1. 请求参数验证错误(Pydantic 校验失败)
|
||||
if isinstance(exc, RequestValidationError):
|
||||
error_details = []
|
||||
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=APIResponse(
|
||||
code=400,
|
||||
message=f"请求参数错误:{'; '.join(error_details)}",
|
||||
message=f"请求参数错误: {'; '.join(error_details)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"数据库错误:{str(exc)}",
|
||||
message=f"数据库错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"服务器内部错误:{str(exc)}",
|
||||
message=f"服务器内部错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
@ -1,139 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
|
||||
class FaceRecognizer:
|
||||
"""
|
||||
封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。
|
||||
"""
|
||||
|
||||
def __init__(self, known_faces_dir: str):
|
||||
self.known_faces_dir = known_faces_dir
|
||||
self.app = self._initialize_insightface()
|
||||
self.known_faces_embeddings = {}
|
||||
self.known_faces_names = []
|
||||
self._load_known_faces()
|
||||
|
||||
def _initialize_insightface(self):
|
||||
"""初始化InsightFace FaceAnalysis应用"""
|
||||
print("初始化InsightFace引擎...")
|
||||
try:
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
print("请检查依赖是否安装及模型是否可访问")
|
||||
return None
|
||||
|
||||
def _load_known_faces(self):
|
||||
"""加载已知人脸特征"""
|
||||
if not os.path.exists(self.known_faces_dir):
|
||||
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
|
||||
os.makedirs(self.known_faces_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
print(f"从目录加载人脸特征: {self.known_faces_dir}")
|
||||
for person_name in os.listdir(self.known_faces_dir):
|
||||
person_dir = os.path.join(self.known_faces_dir, person_name)
|
||||
if os.path.isdir(person_dir):
|
||||
print(f"处理人物: {person_name}")
|
||||
embeddings = []
|
||||
for filename in os.listdir(person_dir):
|
||||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||||
image_path = os.path.join(person_dir, filename)
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"无法读取图片: {image_path},已跳过")
|
||||
continue
|
||||
|
||||
faces = self.app.get(img)
|
||||
if faces:
|
||||
embeddings.append(faces[0].embedding)
|
||||
print(f"提取特征成功: {filename}")
|
||||
else:
|
||||
print(f"未检测到人脸: {filename},已跳过")
|
||||
except Exception as e:
|
||||
print(f"处理图片出错 {image_path}: {e}")
|
||||
|
||||
if embeddings:
|
||||
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
|
||||
self.known_faces_names.append(person_name)
|
||||
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
|
||||
else:
|
||||
print(f"人物 {person_name} 无有效特征,已跳过")
|
||||
print(f"人脸加载完成,共 {len(self.known_faces_names)} 人")
|
||||
|
||||
def recognize(self, frame, threshold=0.4):
|
||||
"""识别人脸并返回结果"""
|
||||
if not self.app or not self.known_faces_names:
|
||||
return False, None, None
|
||||
|
||||
faces = self.app.get(frame)
|
||||
if not faces:
|
||||
return False, None, None
|
||||
|
||||
for face in faces:
|
||||
for known_name in self.known_faces_names:
|
||||
known_embedding = self.known_faces_embeddings[known_name]
|
||||
|
||||
embedding1 = face.embedding.astype(np.float32)
|
||||
embedding2 = known_embedding.astype(np.float32)
|
||||
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
norm_embedding1 = np.linalg.norm(embedding1)
|
||||
norm_embedding2 = np.linalg.norm(embedding2)
|
||||
|
||||
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
|
||||
dot_product / (norm_embedding1 * norm_embedding2)
|
||||
)
|
||||
|
||||
if similarity >= threshold:
|
||||
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
|
||||
return True, known_name, similarity
|
||||
|
||||
return False, None, None
|
||||
|
||||
def test_single_image(self, image_path: str, threshold=0.4):
|
||||
"""测试单张图片识别"""
|
||||
if not os.path.exists(image_path):
|
||||
print(f"图片不存在: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
frame = cv2.imread(image_path)
|
||||
if frame is None:
|
||||
print(f"无法读取图片: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
result, name, similarity = self.recognize(frame, threshold)
|
||||
|
||||
if result:
|
||||
print(f"识别结果: {name} (相似度: {similarity:.4f})")
|
||||
|
||||
faces = self.app.get(frame)
|
||||
for face in faces:
|
||||
bbox = face.bbox.astype(int)
|
||||
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
text = f"{name}: {similarity:.2f}"
|
||||
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow('识别结果', frame)
|
||||
print("按任意键关闭窗口...")
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
print("未识别到已知人脸")
|
||||
|
||||
return result, name, similarity
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
|
||||
# recognizer.test_single_image(test_image_path, threshold=0.4)
|
||||
@ -1,156 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class BinaryFaceFeatureHandler:
|
||||
"""
|
||||
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.app = self._init_insightface()
|
||||
self.feature_list = [] # 存储所有图片二进制数据提取的特征
|
||||
|
||||
def _init_insightface(self):
|
||||
"""初始化InsightFace引擎"""
|
||||
try:
|
||||
print("正在初始化InsightFace引擎...")
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
return None
|
||||
|
||||
def add_binary_data(self, binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据,提取特征并保存
|
||||
|
||||
参数:
|
||||
binary_data: 图片的二进制数据(bytes类型)
|
||||
|
||||
返回:
|
||||
成功提取特征时返回 (True, 特征值numpy数组)
|
||||
失败时返回 (False, None)
|
||||
"""
|
||||
if not self.app:
|
||||
print("引擎未初始化,无法处理")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
# 直接处理二进制数据:转换为图像格式
|
||||
img = Image.open(BytesIO(binary_data))
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 提取特征
|
||||
faces = self.app.get(frame)
|
||||
if faces:
|
||||
# 获取当前提取的特征值
|
||||
current_feature = faces[0].embedding
|
||||
# 添加到特征列表
|
||||
self.feature_list.append(current_feature)
|
||||
print(f"已累计 {len(self.feature_list)} 个特征")
|
||||
# 返回成功标志和当前特征值
|
||||
return True,current_feature
|
||||
else:
|
||||
print("二进制数据中未检测到人脸")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
print(f"处理二进制数据出错: {e}")
|
||||
return False, None
|
||||
|
||||
def get_average_feature(self, features):
|
||||
"""
|
||||
计算多个特征向量的平均值
|
||||
|
||||
参数:
|
||||
features: 特征值列表,每个元素可以是字符串格式或numpy数组
|
||||
例如: [feature1, feature2, ...]
|
||||
返回:
|
||||
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
||||
"""
|
||||
try:
|
||||
# 验证输入是否为列表且不为空
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
# 处理每个特征值
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
# 处理包含括号和逗号的字符串格式
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 验证特征值格式
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||
|
||||
# 确保有有效的特征值
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
# 检查所有特征向量维度是否相同
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
||||
return None
|
||||
|
||||
# 计算平均值
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
||||
|
||||
return avg_feature
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值时出错: {e}")
|
||||
return None
|
||||
|
||||
# def clear(self):
|
||||
# """清空已存储的特征数据"""
|
||||
# self.feature_list = []
|
||||
# print("已清空所有特征数据")
|
||||
|
||||
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# handler = BinaryFaceFeatureHandler()
|
||||
#
|
||||
# # 模拟接收图片二进制数据
|
||||
# try:
|
||||
# # 第一次接收
|
||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
|
||||
# bin_data1 = f.read()
|
||||
# success, feature1 = handler.add_binary_data(bin_data1)
|
||||
# if success:
|
||||
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
|
||||
#
|
||||
# # 第二次接收
|
||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
|
||||
# bin_data2 = f.read()
|
||||
# success, feature2 = handler.add_binary_data(bin_data2)
|
||||
# if success:
|
||||
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
|
||||
#
|
||||
# # 计算平均值
|
||||
# avg_feature = handler.get_average_feature()
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"处理过程出错: {e}")
|
||||
@ -1,18 +0,0 @@
|
||||
打倒习近平
|
||||
打到习近平
|
||||
打倒毛泽东
|
||||
打到毛泽东
|
||||
打到主席
|
||||
打倒主席
|
||||
打到共产主义
|
||||
打倒共产主义
|
||||
打到共产党
|
||||
打倒共产党
|
||||
胖猫
|
||||
法轮功
|
||||
法轮大法
|
||||
法轮大法好
|
||||
法轮功大法好
|
||||
法轮
|
||||
李洪志
|
||||
习近平
|
||||
|
Before Width: | Height: | Size: 195 KiB |
|
Before Width: | Height: | Size: 208 KiB |
|
Before Width: | Height: | Size: 657 KiB |
|
Before Width: | Height: | Size: 53 KiB |
|
Before Width: | Height: | Size: 8.1 KiB |
|
Before Width: | Height: | Size: 14 KiB |
|
Before Width: | Height: | Size: 58 KiB |
|
Before Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 155 KiB |
|
Before Width: | Height: | Size: 386 KiB |
|
Before Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 62 KiB |
@ -1,49 +0,0 @@
|
||||
#日志文件
|
||||
import logging
|
||||
import sys
|
||||
|
||||
def setup_logger():
|
||||
"""
|
||||
配置一个全局日志记录器,支持输出到控制台和文件。
|
||||
"""
|
||||
# 创建一个日志记录器
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("ViolationDetectorLogger")
|
||||
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
||||
|
||||
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
|
||||
# --- 控制台处理器 ---
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
# 对于控制台,我们只显示INFO及以上级别的信息
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
# --- 文件处理器 ---
|
||||
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
|
||||
# 对于文件,我们记录所有DEBUG及以上级别的信息
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
|
||||
# 将处理器添加到日志记录器
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
# 创建并导出logger实例
|
||||
logger = setup_logger()
|
||||
@ -1,136 +0,0 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from .ocr_violation_detector import OCRViolationDetector
|
||||
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from .face_recognizer import FaceRecognizer
|
||||
|
||||
class MultiModelViolationDetector:
|
||||
"""
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
yolo_model_path: str,
|
||||
known_faces_dir: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化所有检测模型
|
||||
"""
|
||||
# 初始化OCR检测器
|
||||
self.ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=forbidden_words_path,
|
||||
ocr_config_path=ocr_config_path,
|
||||
ocr_confidence_threshold=ocr_confidence_threshold
|
||||
)
|
||||
|
||||
# 初始化人脸识别器
|
||||
self.face_recognizer = FaceRecognizer(
|
||||
known_faces_dir=known_faces_dir
|
||||
)
|
||||
|
||||
# 初始化YOLO检测器
|
||||
self.yolo_detector = YoloViolationDetector(
|
||||
model_path=yolo_model_path
|
||||
)
|
||||
|
||||
print("多模型违规检测器初始化完成")
|
||||
|
||||
def detect_violations(self, frame):
|
||||
"""
|
||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
||||
"""
|
||||
# 1. 首先进行OCR违禁词检测
|
||||
try:
|
||||
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
||||
if ocr_has_violation:
|
||||
details = {
|
||||
"words": ocr_words,
|
||||
"confidences": ocr_confs
|
||||
}
|
||||
print(f"警告: OCR检测到违禁内容: {details}")
|
||||
return (True, "ocr", details)
|
||||
except Exception as e:
|
||||
print(f"错误: OCR检测出错: {str(e)}")
|
||||
|
||||
# 2. 接着进行人脸识别检测
|
||||
try:
|
||||
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
||||
if face_has_violation:
|
||||
details = {
|
||||
"name": face_name,
|
||||
"similarity": face_similarity
|
||||
}
|
||||
print(f"警告: 人脸识别到违规人员: {details}")
|
||||
return (True, "face", details)
|
||||
except Exception as e:
|
||||
print(f"错误: 人脸识别出错: {str(e)}")
|
||||
|
||||
# 3. 最后进行YOLO目标检测
|
||||
try:
|
||||
yolo_results = self.yolo_detector.detect(frame)
|
||||
if len(yolo_results.boxes) > 0:
|
||||
details = {
|
||||
"classes": yolo_results.names,
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(),
|
||||
"confidences": yolo_results.boxes.conf.tolist(),
|
||||
"class_ids": yolo_results.boxes.cls.tolist()
|
||||
}
|
||||
print(f"警告: YOLO检测到违规目标: {details}")
|
||||
return (True, "yolo", details)
|
||||
except Exception as e:
|
||||
print(f"错误: YOLO检测出错: {str(e)}")
|
||||
|
||||
# 所有检测均未发现违规
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""加载YAML配置文件"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误: 配置文件未找到: {config_path}")
|
||||
raise
|
||||
except yaml.YAMLError as e:
|
||||
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"错误: 加载配置文件出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 加载配置文件
|
||||
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
||||
#
|
||||
# # 初始化多模型检测器
|
||||
# detector = MultiModelViolationDetector(
|
||||
# forbidden_words_path=config["forbidden_words_path"],
|
||||
# ocr_config_path=config["ocr_config_path"],
|
||||
# yolo_model_path=config["yolo_model_path"],
|
||||
# known_faces_dir=config["known_faces_dir"],
|
||||
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
||||
# )
|
||||
#
|
||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
||||
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
||||
# if test_image_path:
|
||||
# frame = cv2.imread(test_image_path)
|
||||
#
|
||||
# if frame is not None:
|
||||
# has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
# if has_violation:
|
||||
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# else:
|
||||
# print("未检测到任何违规内容")
|
||||
# else:
|
||||
# print(f"无法读取测试图像: {test_image_path}")
|
||||
# else:
|
||||
# print("配置文件中未指定测试图像路径")
|
||||
@ -1,178 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
from rapidocr import RapidOCR
|
||||
|
||||
|
||||
class OCRViolationDetector:
|
||||
"""
|
||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
||||
核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化OCR引擎和违禁词列表。
|
||||
|
||||
Args:
|
||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
||||
"""
|
||||
# 加载违禁词
|
||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
||||
|
||||
# 初始化RapidOCR引擎
|
||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
||||
|
||||
# 校验核心依赖是否就绪
|
||||
self._check_dependencies()
|
||||
|
||||
# 设置置信度阈值(限制在0~1范围)
|
||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
||||
print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _load_forbidden_words(self, path: str) -> set:
|
||||
"""
|
||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
||||
"""
|
||||
forbidden_words = set()
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
print(f"错误:违禁词文件不存在: {path}")
|
||||
return forbidden_words
|
||||
|
||||
# 读取文件并处理内容
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
forbidden_words = {
|
||||
line.strip() for line in f
|
||||
if line.strip() # 跳过空行或纯空格行
|
||||
}
|
||||
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
except UnicodeDecodeError:
|
||||
print(f"错误:违禁词文件编码错误(需UTF-8): {path}")
|
||||
except PermissionError:
|
||||
print(f"错误:无权限读取违禁词文件: {path}")
|
||||
except Exception as e:
|
||||
print(f"错误:加载违禁词失败: {str(e)}")
|
||||
|
||||
return forbidden_words
|
||||
|
||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
||||
"""
|
||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
||||
"""
|
||||
print("开始初始化RapidOCR引擎...")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"错误:OCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
# 初始化OCR引擎
|
||||
try:
|
||||
ocr_engine = RapidOCR(config_path=config_path)
|
||||
print("RapidOCR引擎初始化成功")
|
||||
return ocr_engine
|
||||
except ImportError:
|
||||
print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
except Exception as e:
|
||||
print(f"错误:RapidOCR初始化失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _check_dependencies(self) -> None:
|
||||
"""校验OCR引擎和违禁词列表是否就绪"""
|
||||
if not self.ocr_engine:
|
||||
print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
if not self.forbidden_words:
|
||||
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
|
||||
def detect(self, frame) -> tuple[bool, list, list]:
|
||||
"""
|
||||
对单帧图像进行OCR违禁词检测(核心方法)
|
||||
|
||||
Args:
|
||||
frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。
|
||||
|
||||
Returns:
|
||||
tuple[bool, list, list]:
|
||||
- 第一个元素:是否检测到违禁词(True/False);
|
||||
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
|
||||
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
|
||||
"""
|
||||
# 初始化返回结果
|
||||
has_violation = False
|
||||
violation_words = []
|
||||
violation_confs = []
|
||||
|
||||
# 前置校验
|
||||
if frame is None or frame.size == 0:
|
||||
print("警告:输入图像帧为空或无效,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
print("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
try:
|
||||
# 执行OCR识别
|
||||
print("开始执行OCR识别...")
|
||||
ocr_result = self.ocr_engine(frame)
|
||||
print(f"RapidOCR原始结果: {ocr_result}")
|
||||
|
||||
# 校验OCR结果是否有效
|
||||
if ocr_result is None:
|
||||
print("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 检查txts和scores是否存在且不为None
|
||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
||||
print("警告:OCR结果中txts为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
||||
print("警告:OCR结果中scores为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 转为列表并去None
|
||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
||||
print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
texts = []
|
||||
else:
|
||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
||||
|
||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
||||
print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
confidences = []
|
||||
else:
|
||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
||||
|
||||
# 校验文本和置信度列表长度是否一致
|
||||
if len(texts) != len(confidences):
|
||||
print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if len(texts) == 0:
|
||||
print("OCR未识别到任何有效文本")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 遍历识别结果,筛选违禁词
|
||||
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
for text, conf in zip(texts, confidences):
|
||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
continue
|
||||
matched_words = [word for word in self.forbidden_words if word in text]
|
||||
if matched_words:
|
||||
has_violation = True
|
||||
violation_words.extend(matched_words)
|
||||
violation_confs.extend([conf] * len(matched_words))
|
||||
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误:OCR检测过程异常: {str(e)}")
|
||||
|
||||
return has_violation, violation_words, violation_confs
|
||||
@ -1,47 +0,0 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
class ViolationDetector:
|
||||
"""
|
||||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
||||
"""
|
||||
def __init__(self, model_path):
|
||||
"""
|
||||
初始化检测器。
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO .pt模型的路径。
|
||||
"""
|
||||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
self.model = YOLO(model_path)
|
||||
print("YOLO模型加载成功。")
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
对单帧图像进行目标检测。
|
||||
|
||||
Args:
|
||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||||
|
||||
Returns:
|
||||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
||||
"""
|
||||
# conf可以根据您的模型效果进行调整
|
||||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
||||
results = self.model(frame, conf=0.2)
|
||||
return results[0]
|
||||
|
||||
def draw_boxes(self, frame, result):
|
||||
"""
|
||||
在图像帧上绘制检测框。
|
||||
|
||||
Args:
|
||||
frame: 原始图像帧。
|
||||
result: YOLO的检测结果对象。
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 绘制了检测框的图像帧。
|
||||
"""
|
||||
# 使用YOLO自带的plot功能,方便快捷
|
||||
annotated_frame = result.plot()
|
||||
return annotated_frame
|
||||
164
rtc/rtc.py
@ -1,164 +0,0 @@
|
||||
import queue
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import threading
|
||||
import time
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||
from aiortc.mediastreams import MediaStreamTrack
|
||||
|
||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
||||
frame_queue = queue.Queue(maxsize=1)
|
||||
|
||||
|
||||
class VideoTrack(MediaStreamTrack):
|
||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
||||
kind = "video"
|
||||
|
||||
def __init__(self, max_frames=100):
|
||||
super().__init__()
|
||||
self.frames = queue.Queue(maxsize=max_frames)
|
||||
|
||||
async def recv(self):
|
||||
return await super().recv()
|
||||
|
||||
|
||||
def webrtc_producer(webrtc_url):
|
||||
"""
|
||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
||||
仅当队列空时才放入新帧,否则丢弃
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||
video_track = VideoTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
@pc.on("track")
|
||||
async def on_track(track):
|
||||
if track.kind == "video":
|
||||
print("接收到视频轨道,开始接收视频帧")
|
||||
while True:
|
||||
# 从轨道接收视频帧
|
||||
frame = await track.recv()
|
||||
# 转换为BGR24格式的NumPy数组
|
||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 检查队列是否为空,为空则加入,否则丢弃
|
||||
if frame_queue.empty():
|
||||
try:
|
||||
frame_queue.put_nowait(frame_bgr24)
|
||||
print("帧已放入队列")
|
||||
except queue.Full:
|
||||
print("队列已满,丢弃帧")
|
||||
else:
|
||||
print("队列非空,丢弃帧")
|
||||
|
||||
async def main():
|
||||
# 创建并发送SDP Offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
# 发送Offer到服务器并接收Answer
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
||||
async with session.post(
|
||||
webrtc_url,
|
||||
data=offer.sdp.encode(),
|
||||
headers={
|
||||
"Content-Type": "application/sdp",
|
||||
"Content-Length": str(len(offer.sdp))
|
||||
},
|
||||
ssl=False
|
||||
) as response:
|
||||
print("已接收到服务器的响应")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
print("关闭RTCPeerConnection")
|
||||
await pc.close()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def frame_consumer(ip):
|
||||
"""
|
||||
消费者方法:从队列中读取帧并处理
|
||||
每次处理后休眠200ms模拟延迟
|
||||
"""
|
||||
print("消费者启动,开始等待帧...")
|
||||
try:
|
||||
while True:
|
||||
# 阻塞等待队列中的帧
|
||||
frame = frame_queue.get()
|
||||
print(f"消费帧,大小: {frame.shape}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
|
||||
|
||||
# 标记任务完成
|
||||
frame_queue.task_done()
|
||||
except KeyboardInterrupt:
|
||||
print("消费者退出")
|
||||
|
||||
|
||||
def start_webrtc_stream(ip, webrtc_url):
|
||||
"""
|
||||
启动WebRTC视频流处理的主方法
|
||||
参数: webrtc_url - WebRTC服务器地址
|
||||
"""
|
||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
||||
|
||||
# 启动生产者线程
|
||||
producer_thread = threading.Thread(
|
||||
target=webrtc_producer,
|
||||
args=(webrtc_url,),
|
||||
daemon=True,
|
||||
name="webrtc-producer"
|
||||
)
|
||||
|
||||
# 启动消费者线程
|
||||
consumer_thread = threading.Thread(
|
||||
target=frame_consumer(ip),
|
||||
daemon=True,
|
||||
name="frame-consumer"
|
||||
)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
print("生产者和消费者线程已启动")
|
||||
|
||||
try:
|
||||
# 保持主线程运行
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("程序正在退出...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
# 实际使用时替换为真实的WebRTC服务器地址
|
||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
||||
start_webrtc_stream(webrtc_server_url)
|
||||
101
rtmp/rtmp.py
@ -1,101 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
|
||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理
|
||||
|
||||
Args:
|
||||
rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key)
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环)
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析
|
||||
)
|
||||
|
||||
# 2. 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 3. 异步获取RTMP流基础信息(分辨率、帧率)
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况:部分RTMP流未返回帧率时默认30FPS
|
||||
fps = fps if fps > 0 else 30.0
|
||||
# 分辨率转为整数(视频尺寸必然是整数)
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息(与WHEP连接成功信息风格一致)
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 4. 初始化帧统计参数
|
||||
frame_count = 0 # 总接收帧数
|
||||
start_time = time.time() # 统计起始时间
|
||||
|
||||
# 5. 循环异步读取视频帧(核心逻辑)
|
||||
while True:
|
||||
# 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境)
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
# 检查帧是否读取成功(流中断/结束时ret为False)
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
# 帧计数累加
|
||||
frame_count += 1
|
||||
|
||||
# 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐)
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
|
||||
# 示例:yield frame (若需生成器模式,可调整函数为异步生成器)
|
||||
|
||||
# 8. 异常处理(覆盖用户中断、通用错误)
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行RTMP拉流任务(与WHEP一致的异步执行方式)
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
||||
36
schema/device_action_schema.py
Normal file
@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceActionCreate(BaseModel):
|
||||
"""设备操作记录创建模型(0=离线、1=上线)"""
|
||||
client_ip: str = Field(..., description="客户端IP")
|
||||
action: int = Field(..., ge=0, le=1, description="操作状态(0=离线、1=上线)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(单条记录)
|
||||
# ------------------------------
|
||||
class DeviceActionResponse(BaseModel):
|
||||
"""设备操作记录响应模型(与自增表对齐)"""
|
||||
id: int = Field(..., description="自增主键ID")
|
||||
client_ip: Optional[str] = Field(None, description="客户端IP")
|
||||
action: Optional[int] = Field(None, description="操作状态(0=离线、1=上线)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库结果直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 列表响应模型(仅含 total + device_actions)
|
||||
# ------------------------------
|
||||
class DeviceActionListResponse(BaseModel):
|
||||
"""设备操作记录列表(仅核心返回字段)"""
|
||||
total: int = Field(..., description="总记录数")
|
||||
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")
|
||||
33
schema/device_danger_schema.py
Normal file
@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceDangerCreateRequest(BaseModel):
|
||||
"""设备危险记录创建请求模型"""
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)")
|
||||
type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
|
||||
result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒,已隔离;端口22异常开放,已关闭)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型
|
||||
# ------------------------------
|
||||
class DeviceDangerResponse(BaseModel):
|
||||
"""单条设备危险记录响应模型(与device_danger表字段对齐,updated_at允许为null)"""
|
||||
id: int = Field(..., description="危险记录主键ID")
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址")
|
||||
type: str = Field(..., max_length=50, description="危险类型")
|
||||
result: str = Field(..., description="危险检测结果/处理结果")
|
||||
created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
|
||||
updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)")
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceDangerListResponse(BaseModel):
|
||||
"""设备危险记录列表响应模型(带分页)"""
|
||||
total: int = Field(..., description="危险记录总数")
|
||||
dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")
|
||||
@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
@ -6,46 +5,50 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceCreateRequest(BaseModel):
|
||||
"""设备流信息创建请求模型"""
|
||||
"""设备创建请求模型"""
|
||||
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
params: Optional[Dict] = Field(None, description="设备详细信息")
|
||||
|
||||
|
||||
def md5_encrypt(text: str) -> str:
|
||||
"""对字符串进行MD5加密"""
|
||||
if not text:
|
||||
return ""
|
||||
md5_hash = hashlib.md5()
|
||||
md5_hash.update(text.encode('utf-8'))
|
||||
return md5_hash.hexdigest()
|
||||
params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回设备数据)
|
||||
# 响应模型
|
||||
# ------------------------------
|
||||
class DeviceResponse(BaseModel):
|
||||
"""设备流信息响应模型(字段与表结构完全对齐)"""
|
||||
id: int = Field(..., description="设备ID")
|
||||
"""单设备信息响应模型(与数据库表字段对齐)"""
|
||||
id: int = Field(..., description="设备主键ID")
|
||||
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
|
||||
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
|
||||
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
|
||||
device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)")
|
||||
device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)")
|
||||
device_type: Optional[str] = Field(None, description="设备类型")
|
||||
alarm_count: int = Field(..., description="报警次数")
|
||||
params: Optional[str] = Field(None, description="设备详细信息")
|
||||
params: Optional[str] = Field(None, description="扩展参数(JSON字符串)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库查询结果转换
|
||||
model_config = {"from_attributes": True}
|
||||
model_config = {"from_attributes": True} # 支持从数据库结果直接转换
|
||||
|
||||
|
||||
class DeviceListResponse(BaseModel):
|
||||
"""设备流信息列表响应模型"""
|
||||
"""设备列表响应模型"""
|
||||
total: int = Field(..., description="设备总数")
|
||||
devices: List[DeviceResponse] = Field(..., description="设备列表")
|
||||
|
||||
|
||||
class DeviceStatusHistoryResponse(BaseModel):
|
||||
"""设备上下线记录响应模型"""
|
||||
id: int = Field(..., description="记录ID")
|
||||
device_id: int = Field(..., description="关联设备ID")
|
||||
client_ip: Optional[str] = Field(None, description="设备IP地址")
|
||||
status: int = Field(..., description="状态(1-在线、0-离线)")
|
||||
status_time: datetime = Field(..., description="状态变更时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceStatusHistoryListResponse(BaseModel):
|
||||
"""设备上下线记录列表响应模型"""
|
||||
total: int = Field(..., description="记录总数")
|
||||
history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")
|
||||
@ -1,30 +1,41 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)
|
||||
# 请求模型(前端传参校验)- 保留update的eigenvalue(如需更新特征值)
|
||||
# ------------------------------
|
||||
class FaceCreateRequest(BaseModel):
|
||||
"""创建人脸记录请求模型(无需ID,由数据库自增)"""
|
||||
name: str = Field(None, max_length=255, description="名称(可选,最长255字符)")
|
||||
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
||||
name: Optional[str] = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
||||
|
||||
|
||||
class FaceUpdateRequest(BaseModel):
|
||||
"""更新人脸记录请求模型(不变)"""
|
||||
name: str = Field(None, max_length=255, description="名称")
|
||||
eigenvalue: str = Field(None, max_length=255, description="特征(文件处理后可更新)")
|
||||
"""更新人脸记录请求模型 - 保留eigenvalue(如需更新特征值,不影响返回)"""
|
||||
name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
|
||||
eigenvalue: Optional[str] = Field(None, description="特征值(可选,文件处理后可更新)") # 保留更新能力
|
||||
address: Optional[str] = Field(None, description="图片完整路径(可选,更新图片时使用)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回数据)
|
||||
# 响应模型(后端返回数据)- 核心修改:删除eigenvalue字段
|
||||
# ------------------------------
|
||||
class FaceResponse(BaseModel):
|
||||
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
||||
"""人脸记录响应模型(仅返回需要的字段,移除eigenvalue)"""
|
||||
id: int = Field(..., description="主键ID(数据库自增)")
|
||||
name: str = Field(None, description="名称")
|
||||
eigenvalue: str = Field(None, description="特征(暂为None)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
name: Optional[str] = Field(None, description="名称")
|
||||
address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)") # 仅保留address
|
||||
created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
|
||||
updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
|
||||
|
||||
# 关键配置:支持从数据库查询结果(字典)直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class FaceListResponse(BaseModel):
|
||||
"""人脸列表分页响应模型(结构不变,内部FaceResponse已移除eigenvalue)"""
|
||||
total: int = Field(..., description="筛选后的总记录数")
|
||||
faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
37
schema/model_schema.py
Normal file
@ -0,0 +1,37 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# 请求模型
|
||||
class ModelCreateRequest(BaseModel):
|
||||
name: str = Field(..., max_length=255, description="模型名称(必填,如:yolo-v8s-car)")
|
||||
description: Optional[str] = Field(None, description="模型描述(可选)")
|
||||
is_default: Optional[bool] = Field(False, description="是否设为默认模型")
|
||||
|
||||
|
||||
class ModelUpdateRequest(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
|
||||
description: Optional[str] = Field(None, description="模型描述(可选修改)")
|
||||
is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
|
||||
|
||||
|
||||
# 响应模型
|
||||
class ModelResponse(BaseModel):
|
||||
id: int = Field(..., description="模型ID")
|
||||
name: str = Field(..., description="模型名称")
|
||||
path: str = Field(..., description="模型文件相对路径")
|
||||
is_default: bool = Field(..., description="是否默认模型")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
file_size: Optional[int] = Field(None, description="文件大小(字节)")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
total: int = Field(..., description="总记录数")
|
||||
models: List[ModelResponse] = Field(..., description="当前页模型列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class APIResponse(BaseModel):
|
||||
"""统一 API 响应模型(所有接口必返此格式)"""
|
||||
code: int = Field(..., description="状态码:200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||
message: str = Field(..., description="响应信息:成功/错误描述")
|
||||
data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None")
|
||||
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||
message: str = Field(..., description="响应信息: 成功/错误描述")
|
||||
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
|
||||
|
||||
# Pydantic V2 配置(支持从 ORM 对象转换)
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@ -7,24 +8,31 @@ from pydantic import BaseModel, Field
|
||||
# ------------------------------
|
||||
class SensitiveCreateRequest(BaseModel):
|
||||
"""创建敏感信息记录请求模型"""
|
||||
# 移除了id字段,由数据库自动生成
|
||||
name: str = Field(None, max_length=255, description="名称")
|
||||
name: str = Field(..., max_length=255, description="敏感词内容(必填)")
|
||||
|
||||
|
||||
class SensitiveUpdateRequest(BaseModel):
|
||||
"""更新敏感信息记录请求模型"""
|
||||
name: str = Field(None, max_length=255, description="名称")
|
||||
name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回数据)
|
||||
# ------------------------------
|
||||
class SensitiveResponse(BaseModel):
|
||||
"""敏感信息记录响应模型"""
|
||||
id: int = Field(..., description="主键ID") # 响应中仍然包含ID
|
||||
name: str = Field(None, description="名称")
|
||||
"""敏感信息单条记录响应模型"""
|
||||
id: int = Field(..., description="主键ID")
|
||||
name: str = Field(..., description="敏感词内容")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库查询结果转换
|
||||
# 支持从数据库查询结果(字典/对象)自动转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SensitiveListResponse(BaseModel):
|
||||
"""敏感信息分页列表响应模型(新增)"""
|
||||
total: int = Field(..., description="敏感词总记录数")
|
||||
sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@ -30,3 +30,11 @@ class UserResponse(BaseModel):
|
||||
|
||||
# Pydantic V2 配置(支持从数据库查询结果转换)
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
"""用户列表分页响应模型(与设备/人脸列表结构对齐)"""
|
||||
total: int = Field(..., description="用户总数")
|
||||
users: List[UserResponse] = Field(..., description="当前页用户列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
161
service/device_action_service.py
Normal file
@ -0,0 +1,161 @@
|
||||
from fastapi import APIRouter, Query, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_action_schema import (
|
||||
DeviceActionCreate,
|
||||
DeviceActionResponse,
|
||||
DeviceActionListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/api/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部方法: 新增设备操作记录(适配id自增)
|
||||
# ------------------------------
|
||||
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
||||
"""
|
||||
新增设备操作记录(内部方法、非接口)
|
||||
:param action_data: 含client_ip和action(0/1)
|
||||
:return: 新增的完整记录
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入SQL(id自增、依赖数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO device_action
|
||||
(client_ip, action, created_at, updated_at)
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
action_data.client_ip,
|
||||
action_data.action
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录(通过自增ID查询)
|
||||
new_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
|
||||
new_action = cursor.fetchone()
|
||||
|
||||
return DeviceActionResponse(**new_action)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"新增记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口: 分页查询操作记录列表(仅返回 total + device_actions)
|
||||
# ------------------------------
|
||||
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_action_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认1"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
|
||||
client_ip: str = Query(None, description="按客户端IP筛选"),
|
||||
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 构建筛选条件(参数化查询、避免注入)
|
||||
where_clause = []
|
||||
params = []
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if action is not None:
|
||||
where_clause.append("action = %s")
|
||||
params.append(action)
|
||||
|
||||
# 2. 查询总记录数(用于返回 total)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 分页查询记录(按创建时间倒序、确保最新记录在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM device_action"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数(page/page_size仅用于查询、不返回)
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 4. 仅返回 total + device_actions
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_actions_by_ip(
|
||||
client_ip: str = Path(..., description="客户端IP地址")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 查询总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
|
||||
cursor.execute(count_sql, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 查询该IP的所有记录(按创建时间倒序)
|
||||
list_sql = """
|
||||
SELECT * FROM device_action
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
cursor.execute(list_sql, (client_ip,))
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 3. 返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
267
service/device_danger_service.py
Normal file
@ -0,0 +1,267 @@
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_danger_schema import (
|
||||
DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由初始化(前缀与设备管理相关,标签区分功能)
|
||||
router = APIRouter(
|
||||
prefix="/api/devices/dangers",
|
||||
tags=["设备管理-危险记录"]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
|
||||
# ------------------------------
|
||||
def check_device_exist(client_ip: str) -> bool:
|
||||
"""
|
||||
检查指定IP的设备是否在devices表中存在
|
||||
|
||||
:param client_ip: 设备IP地址
|
||||
:return: 存在返回True,不存在返回False
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("设备IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
return cursor.fetchone() is not None
|
||||
except MySQLError as e:
|
||||
raise Exception(f"检查设备存在性失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
|
||||
# ------------------------------
|
||||
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
|
||||
"""
|
||||
内部工具方法:向device_danger表插入新的危险记录
|
||||
|
||||
:param danger_data: 危险记录创建请求数据
|
||||
:return: 创建成功的危险记录模型对象
|
||||
"""
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(danger_data.client_ip):
|
||||
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入危险记录(id自增,时间自动填充)
|
||||
insert_query = """
|
||||
INSERT INTO device_danger
|
||||
(client_ip, type, result, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
danger_data.client_ip,
|
||||
danger_data.type,
|
||||
danger_data.result
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的记录(用自增ID查询)
|
||||
danger_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
|
||||
new_danger = cursor.fetchone()
|
||||
|
||||
return DeviceDangerResponse(**new_danger)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"插入危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口1:创建设备危险记录
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备危险记录")
|
||||
@encrypt_response()
|
||||
async def add_device_danger(danger_data: DeviceDangerCreateRequest):
|
||||
try:
|
||||
# 调用内部方法创建记录
|
||||
new_danger = create_danger_record(danger_data)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备[{danger_data.client_ip}]危险记录创建成功",
|
||||
data=new_danger
|
||||
)
|
||||
except ValueError as e:
|
||||
# 设备不存在等业务异常
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
# 数据库异常等系统错误
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口2:获取危险记录列表(支持多条件筛选+分页)
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
|
||||
@encrypt_response()
|
||||
async def get_danger_list(
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
|
||||
danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"),
|
||||
start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"),
|
||||
end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建筛选条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if danger_type:
|
||||
where_clause.append("type = %s")
|
||||
params.append(danger_type)
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 1. 统计符合条件的总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询记录(按创建时间倒序,最新的在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM device_danger"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
# 转换为响应模型
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录列表成功",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口3:获取单个设备的所有危险记录
|
||||
# ------------------------------
|
||||
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
|
||||
# @encrypt_response()
|
||||
async def get_device_dangers(
|
||||
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间")
|
||||
):
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(client_ip):
|
||||
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 统计该设备的危险记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
|
||||
cursor.execute(count_query, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询该设备的危险记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT * FROM device_danger
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
cursor.execute(list_query, (client_ip, page_size, offset))
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口4:根据ID获取单个危险记录详情
|
||||
# ------------------------------
|
||||
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
|
||||
@encrypt_response()
|
||||
async def get_danger_detail(
|
||||
danger_id: int = Path(..., ge=1, description="危险记录ID")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询单个危险记录
|
||||
query = "SELECT * FROM device_danger WHERE id = %s"
|
||||
cursor.execute(query, (danger_id,))
|
||||
danger = cursor.fetchone()
|
||||
|
||||
if not danger:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录详情成功",
|
||||
data=DeviceDangerResponse(**danger)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
@ -1,116 +1,210 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import date
|
||||
|
||||
from fastapi import HTTPException, Query, APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Query, HTTPException, Request, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.config import LIVE_CONFIG
|
||||
from ds.db import db
|
||||
from middle.auth_middleware import get_current_user
|
||||
# 注意:导入的Schema已更新字段
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_schema import (
|
||||
DeviceCreateRequest,
|
||||
DeviceResponse,
|
||||
DeviceListResponse,
|
||||
md5_encrypt
|
||||
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
|
||||
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
# 导入之前封装的WEBRTC处理函数
|
||||
from core.rtmp import rtmp_pull_video_stream
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
prefix="/api/devices",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
|
||||
|
||||
# 在后台线程中运行WEBRTC处理
|
||||
def run_webrtc_processing(ip, webrtc_url):
|
||||
try:
|
||||
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
|
||||
rtmp_pull_video_stream(webrtc_url)
|
||||
except Exception as e:
|
||||
print(f"WEBRTC处理出错: {str(e)}")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 1. 创建设备信息
|
||||
# 内部工具方法 - 记录设备状态变更历史
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
def record_status_change(client_ip: str, status: int) -> bool:
|
||||
"""
|
||||
记录设备状态变更历史(写入 device_action 表)
|
||||
|
||||
:param client_ip: 设备IP
|
||||
:param status: 状态(1-在线、0-离线)
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
if status not in (0, 1):
|
||||
raise ValueError("状态必须是0(离线)或1(在线)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查client_ip是否已存在
|
||||
# 插入状态变更记录到 device_action
|
||||
insert_query = """
|
||||
INSERT INTO device_action
|
||||
(client_ip, action, created_at, updated_at)
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (client_ip, status))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"记录设备状态变更失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 通过客户端IP增加设备报警次数
|
||||
# ------------------------------
|
||||
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP增加设备的报警次数
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 报警次数加1、并更新时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET alarm_count = alarm_count + 1,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (client_ip,))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新报警次数失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 通过客户端IP更新设备在线状态
|
||||
# ------------------------------
|
||||
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的在线状态
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:param online_status: 在线状态(1-在线、0-离线)
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
if online_status not in (0, 1):
|
||||
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在并获取设备ID
|
||||
cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 状态无变化则不操作
|
||||
if device['device_online_status'] == online_status:
|
||||
return True
|
||||
|
||||
# 更新在线状态和时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET device_online_status = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (online_status, client_ip))
|
||||
|
||||
# 记录状态变更历史
|
||||
record_status_change(client_ip, online_status)
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 创建设备信息接口
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
@encrypt_response()
|
||||
async def create_device(device_data: DeviceCreateRequest, request: Request):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否已存在
|
||||
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||
existing_device = cursor.fetchone()
|
||||
if existing_device:
|
||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||
threading.Thread(
|
||||
target=run_webrtc_processing,
|
||||
# args=(device_data.ip, existing_device["live_webrtc_url"]),
|
||||
args=(device_data.ip, existing_device["rtmp_push_url"]),
|
||||
|
||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||
).start()
|
||||
# IP已存在时返回该设备信息
|
||||
# 更新设备为在线状态
|
||||
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"客户端IP {device_data.ip} 已存在",
|
||||
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||
data=DeviceResponse(**existing_device)
|
||||
)
|
||||
|
||||
# 获取RTMP URL和WEBRTC URL配置
|
||||
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
|
||||
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
|
||||
|
||||
# 将设备详细信息(params)转换为JSON字符串
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# 对JSON字符串进行MD5加密
|
||||
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
|
||||
|
||||
# 解析User-Agent获取设备类型
|
||||
# 通过 User-Agent 判断设备类型
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
|
||||
# 优先处理User-Agent为default的情况
|
||||
if user_agent == "default":
|
||||
# 检查params中是否存在os键
|
||||
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params:
|
||||
device_type = device_data.params["os"]
|
||||
else:
|
||||
device_type = "unknown"
|
||||
if user_agent == "default":
|
||||
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
||||
elif "windows" in user_agent:
|
||||
device_type = "windows"
|
||||
elif "android" in user_agent:
|
||||
device_type = "android"
|
||||
elif "linux" in user_agent:
|
||||
device_type = "linux"
|
||||
else:
|
||||
device_type = "unknown"
|
||||
|
||||
# 构建完整的WEBRTC URL
|
||||
full_webrtc_url = webrtc_url + device_md5
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# SQL插入语句
|
||||
# 插入新设备
|
||||
insert_query = """
|
||||
INSERT INTO devices
|
||||
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url,
|
||||
device_online_status, device_type, alarm_count, params)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
(client_ip, hostname, device_online_status, device_type, alarm_count, params)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
device_data.ip,
|
||||
device_data.hostname,
|
||||
rtmp_url + device_md5,
|
||||
full_webrtc_url, # 存储完整的WEBRTC URL
|
||||
"",
|
||||
1,
|
||||
device_type,
|
||||
0,
|
||||
@ -118,28 +212,26 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的设备信息
|
||||
# 获取新设备并返回
|
||||
device_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||
device = cursor.fetchone()
|
||||
new_device = cursor.fetchone()
|
||||
|
||||
# 记录上线历史
|
||||
record_status_change(device_data.ip, 1)
|
||||
|
||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||
threading.Thread(
|
||||
target=run_webrtc_processing,
|
||||
args=(device_data.ip, full_webrtc_url),
|
||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||
).start()
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="设备创建成功,已开始处理WEBRTC流",
|
||||
data=DeviceResponse(**device)
|
||||
message="设备创建成功",
|
||||
data=DeviceResponse(**new_device)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建设备失败:{str(e)}") from e
|
||||
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(f"设备信息JSON序列化失败:{str(e)}") from e
|
||||
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
@ -149,14 +241,15 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 2. 获取设备列表
|
||||
# 获取设备列表接口
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表")
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||
@encrypt_response()
|
||||
async def get_device_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数"),
|
||||
device_type: str = Query(None, description="设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选(1-在线、0-离线)")
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
device_type: str = Query(None, description="按设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -164,58 +257,60 @@ async def get_device_list(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建查询条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if device_type:
|
||||
where_clause.append("device_type = %s")
|
||||
params.append(device_type)
|
||||
|
||||
if online_status is not None:
|
||||
where_clause.append("device_online_status = %s")
|
||||
params.append(online_status)
|
||||
|
||||
# 总条数查询
|
||||
count_query = "SELECT COUNT(*) as total FROM devices"
|
||||
# 统计总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM devices"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询(SELECT * 会自动匹配表字段、响应模型已对齐)
|
||||
# 分页查询列表
|
||||
offset = (page - 1) * page_size
|
||||
query = "SELECT * FROM devices"
|
||||
list_query = "SELECT * FROM devices"
|
||||
if where_clause:
|
||||
query += " WHERE " + " AND ".join(where_clause)
|
||||
query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(query, params)
|
||||
devices = cursor.fetchall()
|
||||
|
||||
# 响应模型已更新为params字段、直接转换即可
|
||||
device_list = [DeviceResponse(**device) for device in devices]
|
||||
cursor.execute(list_query, params)
|
||||
device_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备列表成功",
|
||||
data=DeviceListResponse(total=total, devices=device_list)
|
||||
data=DeviceListResponse(
|
||||
total=total,
|
||||
devices=[DeviceResponse(**device) for device in device_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败:{str(e)}") from e
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 3. 获取单个设备详情
|
||||
# 获取设备上下线记录接口
|
||||
# ------------------------------
|
||||
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
|
||||
async def get_device_detail(
|
||||
device_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
@router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录")
|
||||
@encrypt_response()
|
||||
async def get_device_status_history(
|
||||
device_id: int = Path(..., description="设备ID"),
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
start_date: date = Query(None, description="开始日期,格式YYYY-MM-DD"),
|
||||
end_date: date = Query(None, description="结束日期,格式YYYY-MM-DD")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -223,36 +318,76 @@ async def get_device_detail(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备信息(SELECT * 匹配表字段)
|
||||
query = "SELECT * FROM devices WHERE id = %s"
|
||||
cursor.execute(query, (device_id,))
|
||||
# 检查设备是否存在并获取 client_ip
|
||||
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
|
||||
device = cursor.fetchone()
|
||||
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
|
||||
client_ip = device['client_ip']
|
||||
|
||||
where_clause = ["client_ip = %s"]
|
||||
params = [client_ip]
|
||||
|
||||
# 日期筛选
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 统计记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = f"""
|
||||
SELECT * FROM device_action
|
||||
WHERE {' AND '.join(where_clause)}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
params.extend([page_size, offset])
|
||||
cursor.execute(list_query, params)
|
||||
history_list = cursor.fetchall()
|
||||
|
||||
# 格式化为响应模型结构
|
||||
formatted_history = []
|
||||
for item in history_list:
|
||||
formatted_item = {
|
||||
"id": item["id"],
|
||||
"device_id": device_id,
|
||||
"client_ip": item["client_ip"],
|
||||
"status": item["action"],
|
||||
"status_time": item["created_at"]
|
||||
}
|
||||
formatted_history.append(formatted_item)
|
||||
|
||||
# 响应模型已更新为params字段
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备详情成功",
|
||||
data=DeviceResponse(**device)
|
||||
message="获取设备上下线记录成功",
|
||||
data=DeviceStatusHistoryListResponse(
|
||||
total=total,
|
||||
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备详情失败:{str(e)}") from e
|
||||
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 删除设备信息
|
||||
# 手动更新设备在线状态接口
|
||||
# ------------------------------
|
||||
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
|
||||
async def delete_device(
|
||||
device_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
@router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态")
|
||||
@encrypt_response()
|
||||
async def update_device_status(
|
||||
device_id: int = Path(..., description="设备ID"),
|
||||
status: int = Query(..., ge=0, le=1, description="在线状态(1-在线、0-离线)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -260,27 +395,51 @@ async def delete_device(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||
)
|
||||
# 获取设备 client_ip
|
||||
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
|
||||
|
||||
# 执行删除
|
||||
delete_query = "DELETE FROM devices WHERE id = %s"
|
||||
cursor.execute(delete_query, (device_id,))
|
||||
conn.commit()
|
||||
# 更新状态
|
||||
success = update_online_status_by_ip(device['client_ip'], status)
|
||||
|
||||
if success:
|
||||
status_text = "在线" if status == 1 else "离线"
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备ID为 {device_id} 的设备已成功删除",
|
||||
message=f"设备已更新为{status_text}状态",
|
||||
data={"device_id": device_id, "status": status, "status_text": status_text}
|
||||
)
|
||||
return APIResponse(
|
||||
code=500,
|
||||
message="更新设备状态失败",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除设备失败:{str(e)}") from e
|
||||
raise Exception(f"更新设备状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 获取所有去重的客户端IP列表
|
||||
# ------------------------------
|
||||
def get_unique_client_ips() -> list[str]:
|
||||
"""获取所有去重的客户端IP列表"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
return [item['client_ip'] for item in results]
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
@ -1,118 +1,144 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ds.db import db
|
||||
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/faces",
|
||||
tags=["人脸管理"]
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.face_schema import (
|
||||
FaceCreateRequest,
|
||||
FaceUpdateRequest,
|
||||
FaceResponse,
|
||||
FaceListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
from util.face_util import add_binary_data, get_average_feature
|
||||
from util.file_util import save_face_to_up_images
|
||||
|
||||
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 1. 创建人脸记录(核心修正:ID 数据库自增,前端无需传)
|
||||
# 1. 创建人脸记录(使用修复后的路径)
|
||||
# ------------------------------
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件,ID自增)")
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录")
|
||||
@encrypt_response()
|
||||
async def create_face(
|
||||
# 前端仅需传:name(可选,Form格式)、file(必传,文件)
|
||||
request: Request,
|
||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||
file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容)")
|
||||
file: UploadFile = File(..., description="人脸文件(必传)")
|
||||
):
|
||||
"""
|
||||
创建人脸记录:
|
||||
- 需登录认证
|
||||
- 前端传参:multipart/form-data 表单(name 可选,file 必传)
|
||||
- ID 由数据库自动生成,无需前端传入
|
||||
- 暂不处理文件内容,eigenvalue 设为 None
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 用模型校验 name(仅校验长度,无需ID)
|
||||
face_create = FaceCreateRequest(name=name)
|
||||
client_ip = request.client.host if request.client else ""
|
||||
if not client_ip:
|
||||
raise HTTPException(status_code=400, detail="无法获取客户端IP")
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 把文件转为二进制数组
|
||||
# 读取图片并保存(使用修复后的路径逻辑)
|
||||
file_content = await file.read()
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
|
||||
save_result = save_face_to_up_images(
|
||||
client_ip=client_ip,
|
||||
face_name=name,
|
||||
image_bytes=file_content,
|
||||
image_format=file_ext
|
||||
)
|
||||
if not save_result["success"]:
|
||||
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
|
||||
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
|
||||
|
||||
# 调用人脸识别得到特征值
|
||||
# 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
|
||||
eigenvalue = detect_result
|
||||
|
||||
|
||||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
||||
# 插入数据库
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue)
|
||||
VALUES (%s, %s)
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, None))
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
|
||||
conn.commit()
|
||||
|
||||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
||||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
||||
cursor.execute(select_new_query)
|
||||
# 查询新记录
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = LAST_INSERT_ID()
|
||||
""")
|
||||
created_face = cursor.fetchone()
|
||||
if not created_face:
|
||||
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
||||
code=200,
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']})",
|
||||
data=FaceResponse(**created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
finally:
|
||||
await file.close() # 关闭文件流
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
|
||||
# ------------------------------
|
||||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
||||
# 2. 获取单个人脸记录
|
||||
# ------------------------------
|
||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||
async def get_face(
|
||||
face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
@encrypt_response()
|
||||
async def get_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT * FROM face WHERE id = %s"
|
||||
query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="人脸记录查询成功",
|
||||
message="查询成功",
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
|
||||
# ------------------------------
|
||||
# 3. 获取所有人脸记录(不变)
|
||||
# 3. 获取人脸列表
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||||
async def get_all_faces(
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
|
||||
@encrypt_response()
|
||||
async def get_face_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
has_eigenvalue: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -120,47 +146,67 @@ async def get_all_faces(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if has_eigenvalue is not None:
|
||||
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
|
||||
|
||||
# 总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM face"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 列表数据
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
face_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有人脸记录查询成功",
|
||||
data=[FaceResponse(**face) for face in faces]
|
||||
message=f"获取成功(共{total}条)",
|
||||
data=FaceListResponse(
|
||||
total=total,
|
||||
faces=[FaceResponse(**face) for face in face_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 更新人脸记录(不变,用自增ID更新)
|
||||
# 4. 更新人脸记录
|
||||
# ------------------------------
|
||||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||||
async def update_face(
|
||||
face_id: int,
|
||||
face_update: FaceUpdateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
@encrypt_response()
|
||||
async def update_face(face_id: int, face_update: FaceUpdateRequest):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查记录是否存在
|
||||
check_query = "SELECT id FROM face WHERE id = %s"
|
||||
cursor.execute(check_query, (face_id,))
|
||||
existing_face = cursor.fetchone()
|
||||
if not existing_face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
# 构建更新语句
|
||||
update_fields = []
|
||||
params = []
|
||||
if face_update.name is not None:
|
||||
@ -169,6 +215,18 @@ async def update_face(
|
||||
if face_update.eigenvalue is not None:
|
||||
update_fields.append("eigenvalue = %s")
|
||||
params.append(face_update.eigenvalue)
|
||||
if face_update.address is not None:
|
||||
# 删除旧图片(相对路径转绝对路径)
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink() # 使用Path方法删除更安全
|
||||
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
|
||||
update_fields.append("address = %s")
|
||||
params.append(face_update.address)
|
||||
|
||||
if not update_fields:
|
||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||
@ -178,92 +236,145 @@ async def update_face(
|
||||
cursor.execute(update_query, params)
|
||||
conn.commit()
|
||||
|
||||
# 查询更新后记录
|
||||
select_query = "SELECT * FROM face WHERE id = %s"
|
||||
cursor.execute(select_query, (face_id,))
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
""", (face_id,))
|
||||
updated_face = cursor.fetchone()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="人脸记录更新成功",
|
||||
message="更新成功",
|
||||
data=FaceResponse(**updated_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 5. 删除人脸记录(不变,用自增ID删除)
|
||||
# 5. 删除人脸记录
|
||||
# ------------------------------
|
||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||
async def delete_face(
|
||||
face_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
@encrypt_response()
|
||||
async def delete_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
check_query = "SELECT id FROM face WHERE id = %s"
|
||||
cursor.execute(check_query, (face_id,))
|
||||
existing_face = cursor.fetchone()
|
||||
if not existing_face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
delete_query = "DELETE FROM face WHERE id = %s"
|
||||
cursor.execute(delete_query, (face_id,))
|
||||
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
|
||||
conn.commit()
|
||||
|
||||
# 删除图片
|
||||
extra_msg = ""
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink()
|
||||
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
|
||||
extra_msg = "(已同步删除图片)"
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除图片失败:{str(e)}")
|
||||
extra_msg = "(图片删除失败)"
|
||||
else:
|
||||
extra_msg = "(图片不存在)"
|
||||
else:
|
||||
extra_msg = "(无关联图片)"
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {face_id} 的人脸记录删除成功",
|
||||
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
"""
|
||||
获取所有人脸的名称及其对应的特征值,组成字典返回
|
||||
key: 人脸名称(name)
|
||||
value: 人脸特征值(eigenvalue)
|
||||
注:过滤掉name为None的记录,避免字典key为None的情况
|
||||
"""
|
||||
# ------------------------------
|
||||
# 6. 获取人脸图片
|
||||
# ------------------------------
|
||||
@router.get("/{face_id}/image", summary="获取人脸图片")
|
||||
@encrypt_response()
|
||||
async def get_face_image(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT address, name FROM face WHERE id = %s"
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
db_path = face["address"]
|
||||
abs_path = Path(db_path).resolve() # 转为绝对路径
|
||||
if not db_path or not abs_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path})")
|
||||
|
||||
return FileResponse(
|
||||
path=abs_path,
|
||||
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
|
||||
media_type=f"image/{db_path.split('.')[-1]}"
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法
|
||||
# ------------------------------
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 只查询需要的字段,提高效率
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
|
||||
# 构建name到eigenvalue的映射字典
|
||||
face_dict = {
|
||||
face["name"]: face["eigenvalue"]
|
||||
for face in faces
|
||||
}
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
if len(eigenvalues) > 1:
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
||||
raise Exception(f"获取人脸特征失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保资源释放
|
||||
db.close_connection(conn, cursor)
|
||||
174
service/file_service.py
Normal file
@ -0,0 +1,174 @@
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
import os
|
||||
import logging
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import Annotated
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/file",
|
||||
tags=["文件管理"]
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
# 4. 路径配置
|
||||
# ------------------------------
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
|
||||
|
||||
# 资源目录定义
|
||||
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
|
||||
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
|
||||
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
|
||||
|
||||
# 确保资源目录存在
|
||||
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
print(f"[创建目录] {dir_path}")
|
||||
|
||||
# ------------------------------
|
||||
# 5. 安全依赖项(替代Flask装饰器)
|
||||
# ------------------------------
|
||||
def safe_path_check(root_dir: str):
|
||||
"""
|
||||
安全路径校验依赖项:
|
||||
1. 禁止路径遍历(确保请求文件在根目录内)
|
||||
2. 校验文件存在且为有效文件(非目录)
|
||||
3. 限制文件大小(模型200MB,图片10MB)
|
||||
"""
|
||||
async def dependency(request: Request, resource_path: str):
|
||||
# 统一路径分隔符
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
# 拼接完整路径
|
||||
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||
|
||||
# 校验1:禁止路径遍历
|
||||
if not full_file_path.startswith(root_dir):
|
||||
print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}")
|
||||
raise HTTPException(status_code=403, detail="非法路径访问")
|
||||
|
||||
# 校验2:文件存在且为有效文件
|
||||
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||
print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
# 校验3:文件大小限制
|
||||
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||
if os.path.getsize(full_file_path) > max_size:
|
||||
print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
|
||||
|
||||
return full_file_path
|
||||
return dependency
|
||||
|
||||
# ------------------------------
|
||||
# 6. 核心接口
|
||||
# ------------------------------
|
||||
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
|
||||
async def download_model(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
|
||||
request: Request
|
||||
):
|
||||
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 额外校验:仅允许 YOLO 模型格式(.pt)
|
||||
if not file_name.lower().endswith(".pt"):
|
||||
print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
|
||||
|
||||
print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
# 强制下载
|
||||
return FileResponse(
|
||||
full_file_path,
|
||||
filename=file_name,
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
|
||||
async def get_face_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
|
||||
request: Request
|
||||
):
|
||||
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
|
||||
async def get_dect_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@router.get("/images/{resource_path:path}", summary="兼容旧接口")
|
||||
async def get_compatible_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
686
service/model_service.py
Normal file
@ -0,0 +1,686 @@
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
# 复用项目依赖
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.model_schema import (
|
||||
ModelCreateRequest,
|
||||
ModelUpdateRequest,
|
||||
ModelResponse,
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from util.model_util import load_yolo_model # 模型加载工具
|
||||
|
||||
# 路径配置
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
|
||||
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
|
||||
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||
|
||||
# 模型限制
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# 全局模型变量(带版本标识和置信度)
|
||||
global _yolo_model, _current_model_version, _current_conf_threshold
|
||||
_yolo_model = None
|
||||
_current_model_version = None # 模型版本标识
|
||||
_current_conf_threshold = 0.8 # 默认置信度初始值
|
||||
|
||||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 服务重启核心工具函数(保持不变)
|
||||
def restart_service():
|
||||
"""重启当前FastAPI服务进程"""
|
||||
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
||||
try:
|
||||
# 关闭所有WebSocket连接
|
||||
try:
|
||||
from ws import connected_clients
|
||||
if connected_clients:
|
||||
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
|
||||
for ip, conn in list(connected_clients.items()):
|
||||
try:
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
conn.websocket.close(code=1001, reason="模型更新,服务重启")
|
||||
connected_clients.pop(ip)
|
||||
except Exception as e:
|
||||
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
|
||||
except ImportError:
|
||||
print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭")
|
||||
|
||||
# 关闭数据库连接
|
||||
if hasattr(db, "close_all_connections"):
|
||||
db.close_all_connections()
|
||||
else:
|
||||
print("[警告] db模块未实现close_all_connections,可能存在连接泄漏")
|
||||
|
||||
# 启动新进程
|
||||
python_exec = sys.executable
|
||||
current_argv = sys.argv
|
||||
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
|
||||
subprocess.Popen(
|
||||
[python_exec] + current_argv,
|
||||
close_fds=True,
|
||||
start_new_session=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE
|
||||
)
|
||||
|
||||
# 退出当前进程
|
||||
print("[服务重启] 新进程已启动,当前进程退出")
|
||||
sys.exit(0)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[服务重启] 重启失败:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
||||
|
||||
|
||||
# 模型路径验证工具函数(保持不变)
|
||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
try:
|
||||
relative_path = relative_path.replace("/", os.sep)
|
||||
model_abs_path = PROJECT_ROOT / relative_path
|
||||
model_abs_path = model_abs_path.resolve()
|
||||
model_abs_path_str = str(model_abs_path)
|
||||
|
||||
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"模型文件不存在!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.is_file():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"路径不是文件!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
file_size = model_abs_path.stat().st_size
|
||||
if file_size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
file_ext = model_abs_path.suffix.lower()
|
||||
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
|
||||
)
|
||||
|
||||
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
|
||||
return model_abs_path_str
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"路径处理失败:{str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# 对外提供当前模型(带版本校验)(保持不变)
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
||||
global _yolo_model, _current_model_version
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
if not default_model:
|
||||
print("[get_current_yolo_model] 暂无默认模型")
|
||||
return None
|
||||
|
||||
# 1. 计算当前默认模型的唯一版本标识
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
|
||||
# 2. 版本未变化则复用已有模型
|
||||
if _yolo_model and _current_model_version == model_version:
|
||||
return _yolo_model
|
||||
|
||||
# 3. 版本变化时重新加载模型
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
_current_model_version = model_version
|
||||
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
||||
return _yolo_model
|
||||
|
||||
except Exception as e:
|
||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||
return None
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 新增:获取当前置信度阈值
|
||||
def get_current_conf_threshold():
|
||||
"""供检测模块获取当前设置的置信度阈值"""
|
||||
global _current_conf_threshold
|
||||
return _current_conf_threshold
|
||||
|
||||
|
||||
# 1. 上传模型(保持不变)
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
@encrypt_response()
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
description: str = Form(None, description="模型描述"),
|
||||
is_default: bool = Form(False, description="是否设为默认模型"),
|
||||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
saved_file_path = None
|
||||
try:
|
||||
# 校验文件
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if file_ext not in ALLOWED_MODEL_EXT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
|
||||
)
|
||||
if file.size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
# 保存文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
|
||||
saved_file_path = MODEL_SAVE_ROOT / safe_filename
|
||||
with open(saved_file_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
saved_file_path.chmod(0o644) # 设置权限
|
||||
|
||||
# 数据库路径处理
|
||||
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
|
||||
|
||||
# 数据库操作
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
if is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
|
||||
insert_sql = """
|
||||
INSERT INTO model (name, path, is_default, description, file_size)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||
new_model = cursor.fetchone()
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
# 加载默认模型并更新版本
|
||||
global _yolo_model, _current_model_version
|
||||
if is_default:
|
||||
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型上传成功!ID:{new_model['id']}",
|
||||
data=ModelResponse(** new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 2. 获取模型列表(保持不变)
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
@encrypt_response()
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
is_default: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if is_default is not None:
|
||||
where_clause.append("is_default = %s")
|
||||
params.append(1 if is_default else 0)
|
||||
|
||||
# 总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页数据
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM model"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
model_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功!共{total}条记录",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(** model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 3. 获取默认模型(保持不变)
|
||||
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||
@encrypt_response()
|
||||
async def get_default_model():
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
|
||||
if not default_model:
|
||||
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
global _yolo_model, _current_model_version
|
||||
|
||||
if not _yolo_model:
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="默认模型查询成功",
|
||||
data=ModelResponse(**default_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 4. 获取单个模型详情(保持不变)
|
||||
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||
@encrypt_response()
|
||||
async def get_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
try:
|
||||
model_abs_path = get_valid_model_abs_path(model["path"])
|
||||
except HTTPException as e:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"查询成功,但路径异常:{e.detail}",
|
||||
data=ModelResponse(** model)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5. 更新模型信息(保持不变)
|
||||
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||
@encrypt_response()
|
||||
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
update_fields = []
|
||||
params = []
|
||||
if model_update.name is not None:
|
||||
update_fields.append("name = %s")
|
||||
params.append(model_update.name)
|
||||
if model_update.description is not None:
|
||||
update_fields.append("description = %s")
|
||||
params.append(model_update.description)
|
||||
|
||||
need_load_default = False
|
||||
if model_update.is_default is not None:
|
||||
if model_update.is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
update_fields.append("is_default = 1")
|
||||
need_load_default = True
|
||||
else:
|
||||
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
|
||||
default_count = cursor.fetchone()["cnt"]
|
||||
if default_count == 1 and exist_model["is_default"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="当前是唯一默认模型,不可取消!"
|
||||
)
|
||||
update_fields.append("is_default = 0")
|
||||
|
||||
if not update_fields:
|
||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||
|
||||
params.append(model_id)
|
||||
update_sql = f"""
|
||||
UPDATE model
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(update_sql, params)
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
# 更新模型后重置版本标识
|
||||
global _yolo_model, _current_model_version
|
||||
if need_load_default:
|
||||
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="模型更新成功",
|
||||
data=ModelResponse(** updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5.1 更换默认模型(添加置信度参数)
|
||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
||||
@encrypt_response()
|
||||
async def set_default_model(
|
||||
model_id: int,
|
||||
conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值(0.01-0.99)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
conn.autocommit = False # 开启事务
|
||||
|
||||
# 1. 校验目标模型是否存在
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
target_model = cursor.fetchone()
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}")
|
||||
|
||||
# 2. 检查是否已为默认模型
|
||||
if target_model["is_default"]:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型ID:{model_id} 已是默认模型,无需更换和重启",
|
||||
data=ModelResponse(**target_model)
|
||||
)
|
||||
|
||||
# 3. 校验目标模型文件合法性
|
||||
try:
|
||||
valid_abs_path = get_valid_model_abs_path(target_model["path"])
|
||||
except HTTPException as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
|
||||
) from e
|
||||
|
||||
# 4. 数据库事务:更新默认模型状态
|
||||
try:
|
||||
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||||
cursor.execute(
|
||||
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||||
(model_id,)
|
||||
)
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||||
) from e
|
||||
|
||||
# 5. 验证新模型可加载性
|
||||
test_model = load_yolo_model(valid_abs_path)
|
||||
if not test_model:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
# 6. 重新查询更新后的模型信息
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
# 7. 重置版本标识和更新置信度
|
||||
global _current_model_version, _current_conf_threshold
|
||||
_current_model_version = None
|
||||
_current_conf_threshold = conf_threshold # 保存动态置信度
|
||||
print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}")
|
||||
|
||||
# 8. 延迟重启服务
|
||||
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
||||
threading.Timer(
|
||||
interval=1.0,
|
||||
function=restart_service
|
||||
).start()
|
||||
|
||||
# 9. 返回成功响应
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"已成功更换默认模型(ID:{model_id}),置信度:{conf_threshold}!服务将在1秒后自动重启以应用新模型",
|
||||
data=ModelResponse(** updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
if conn:
|
||||
conn.autocommit = True
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 6. 删除模型(保持不变)
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
@encrypt_response()
|
||||
async def delete_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
if exist_model["is_default"]:
|
||||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||
|
||||
try:
|
||||
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
|
||||
model_abs_path = Path(model_abs_path_str)
|
||||
except HTTPException as e:
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"记录删除成功,文件异常:{e.detail}",
|
||||
data=None
|
||||
)
|
||||
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
|
||||
extra_msg = ""
|
||||
try:
|
||||
model_abs_path.unlink()
|
||||
extra_msg = f"(已删除文件)"
|
||||
except Exception as e:
|
||||
extra_msg = f"(文件删除失败:{str(e)})"
|
||||
|
||||
# 如果删除的是当前加载的模型,重置缓存
|
||||
global _yolo_model, _current_model_version
|
||||
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
|
||||
_yolo_model = None
|
||||
_current_model_version = None
|
||||
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型删除成功!ID:{model_id} {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 7. 下载模型文件(保持不变)
|
||||
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||
@encrypt_response()
|
||||
async def download_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(model["path"])
|
||||
model_abs_path = Path(valid_abs_path)
|
||||
|
||||
return FileResponse(
|
||||
path=model_abs_path,
|
||||
filename=f"model_{model_id}_{model['name']}.pt",
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
@ -1,15 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
from typing import Optional
|
||||
|
||||
from ds.db import db
|
||||
from schema.sensitive_schema import SensitiveCreateRequest, SensitiveUpdateRequest, SensitiveResponse
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.sensitive_schema import (
|
||||
SensitiveCreateRequest,
|
||||
SensitiveUpdateRequest,
|
||||
SensitiveResponse,
|
||||
SensitiveListResponse # 导入新增的分页响应模型
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
# 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/sensitives",
|
||||
prefix="/api/sensitives",
|
||||
tags=["敏感信息管理"]
|
||||
)
|
||||
|
||||
@ -18,10 +25,13 @@ router = APIRouter(
|
||||
# 1. 创建敏感信息记录
|
||||
# ------------------------------
|
||||
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def create_sensitive(
|
||||
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
|
||||
sensitive: SensitiveCreateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user) # 补充登录认证依赖(与其他接口保持一致)
|
||||
):
|
||||
"""
|
||||
创建敏感信息记录:
|
||||
创建敏感信息记录:
|
||||
- 需登录认证
|
||||
- 插入新的敏感信息记录到数据库(ID由数据库自动生成)
|
||||
- 返回创建成功信息
|
||||
@ -32,10 +42,10 @@ async def create_sensitive(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入新敏感信息记录到数据库(不包含ID,由数据库自动生成)
|
||||
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name)
|
||||
VALUES (%s)
|
||||
INSERT INTO sensitives (name, created_at, updated_at)
|
||||
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (sensitive.name,))
|
||||
conn.commit()
|
||||
@ -49,29 +59,33 @@ async def create_sensitive(
|
||||
created_sensitive = cursor.fetchone()
|
||||
|
||||
return APIResponse(
|
||||
code=201, # 201 表示资源创建成功
|
||||
code=200, # 200 表示资源创建成功
|
||||
message="敏感信息记录创建成功",
|
||||
data=SensitiveResponse(**created_sensitive)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建敏感信息记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 以下接口代码保持不变
|
||||
# ------------------------------
|
||||
# 2. 获取单个敏感信息记录
|
||||
# ------------------------------
|
||||
|
||||
@router.get("/{sensitive_id}", response_model=APIResponse, summary="获取单个敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def get_sensitive(
|
||||
sensitive_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
获取单个敏感信息记录:
|
||||
获取单个敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID查询敏感信息记录
|
||||
- 返回查询到的敏感信息
|
||||
@ -98,21 +112,29 @@ async def get_sensitive(
|
||||
data=SensitiveResponse(**sensitive)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询敏感信息记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 3. 获取所有敏感信息记录
|
||||
# 3. 获取敏感信息分页列表(重构:支持分页+关键词搜索)
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
|
||||
async def get_all_sensitives():
|
||||
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
|
||||
@encrypt_response()
|
||||
async def get_sensitive_list(
|
||||
page: int = Query(1, ge=1, description="页码(默认1,最小1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10,1-100)"),
|
||||
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
|
||||
):
|
||||
"""
|
||||
获取所有敏感信息记录:
|
||||
获取敏感信息分页列表:
|
||||
- 需登录认证
|
||||
- 查询所有敏感信息记录(不需要分页)
|
||||
- 返回所有敏感信息列表
|
||||
- 支持分页(page/page_size)和敏感词关键词模糊搜索(name)
|
||||
- 返回总记录数+当前页数据
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -120,17 +142,49 @@ async def get_all_sensitives():
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT * FROM sensitives ORDER BY id"
|
||||
cursor.execute(query)
|
||||
sensitives = cursor.fetchall()
|
||||
# 1. 构建查询条件(支持关键词搜索)
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%") # 模糊匹配关键词
|
||||
|
||||
# 2. 查询总记录数(用于分页计算)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params.copy()) # 复制参数列表,避免后续污染
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 计算分页偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 4. 分页查询敏感词数据(按更新时间倒序,最新的在前)
|
||||
list_sql = "SELECT * FROM sensitives"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
# 排序+分页(LIMIT 条数 OFFSET 偏移量)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
# 补充分页参数(page_size和offset)
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
sensitive_list = cursor.fetchall()
|
||||
|
||||
# 5. 构造分页响应数据
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有敏感信息记录查询成功",
|
||||
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
|
||||
message=f"敏感信息列表查询成功(共{total}条记录,当前第{page}页)",
|
||||
data=SensitiveListResponse(
|
||||
total=total,
|
||||
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询所有敏感信息记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询敏感信息列表失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -139,13 +193,14 @@ async def get_all_sensitives():
|
||||
# 4. 更新敏感信息记录
|
||||
# ------------------------------
|
||||
@router.put("/{sensitive_id}", response_model=APIResponse, summary="更新敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def update_sensitive(
|
||||
sensitive_id: int,
|
||||
sensitive_update: SensitiveUpdateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
更新敏感信息记录:
|
||||
更新敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID更新敏感信息记录
|
||||
- 返回更新后的敏感信息
|
||||
@ -177,14 +232,16 @@ async def update_sensitive(
|
||||
if not update_fields:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="至少需要提供一个字段进行更新"
|
||||
detail="至少需要提供一个字段进行更新(如:name)"
|
||||
)
|
||||
|
||||
params.append(sensitive_id) # WHERE条件的参数
|
||||
# 补充更新时间和WHERE条件参数
|
||||
update_fields.append("updated_at = CURRENT_TIMESTAMP")
|
||||
params.append(sensitive_id)
|
||||
|
||||
update_query = f"""
|
||||
UPDATE sensitives
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
SET {', '.join(update_fields)}
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(update_query, params)
|
||||
@ -203,7 +260,10 @@ async def update_sensitive(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新敏感信息记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -212,12 +272,13 @@ async def update_sensitive(
|
||||
# 5. 删除敏感信息记录
|
||||
# ------------------------------
|
||||
@router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def delete_sensitive(
|
||||
sensitive_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
删除敏感信息记录:
|
||||
删除敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID删除敏感信息记录
|
||||
- 返回删除成功信息
|
||||
@ -251,14 +312,20 @@ async def delete_sensitive(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除敏感信息记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 6. 业务辅助函数:获取所有敏感词(供其他模块调用)
|
||||
# ------------------------------
|
||||
def get_all_sensitive_words() -> list[str]:
|
||||
"""
|
||||
获取所有敏感词,返回字符串数组
|
||||
获取所有敏感词(返回纯字符串列表,用于过滤业务)
|
||||
|
||||
返回:
|
||||
list[str]: 包含所有敏感词的数组
|
||||
@ -273,17 +340,17 @@ def get_all_sensitive_words() -> list[str]:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 执行查询,只获取敏感词字段
|
||||
# 执行查询(只获取敏感词字段,按ID排序)
|
||||
query = "SELECT name FROM sensitives ORDER BY id"
|
||||
cursor.execute(query)
|
||||
sensitive_records = cursor.fetchall()
|
||||
|
||||
# 提取敏感词到数组中
|
||||
# 提取敏感词到纯字符串数组
|
||||
return [record['name'] for record in sensitive_records]
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库错误处理
|
||||
raise MySQLError(f"查询敏感词失败:{str(e)}") from e
|
||||
# 数据库错误向上抛出,由调用方处理
|
||||
raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保资源正确释放
|
||||
# 确保数据库连接正确释放
|
||||
db.close_connection(conn, cursor)
|
||||
@ -1,9 +1,11 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import (
|
||||
@ -11,12 +13,12 @@ from middle.auth_middleware import (
|
||||
verify_password,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
get_current_user
|
||||
get_current_user # 仅保留登录用户校验,移除is_admin导入
|
||||
)
|
||||
|
||||
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
prefix="/api/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
@ -25,9 +27,10 @@ router = APIRouter(
|
||||
# 1. 用户注册接口
|
||||
# ------------------------------
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
@encrypt_response()
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
用户注册:
|
||||
- 校验用户名是否已存在
|
||||
- 加密密码后插入数据库
|
||||
- 返回注册成功信息
|
||||
@ -61,13 +64,13 @@ async def user_register(request: UserRegisterRequest):
|
||||
|
||||
# 4. 返回注册成功响应
|
||||
return APIResponse(
|
||||
code=201, # 201 表示资源创建成功
|
||||
code=200, # 200 表示资源创建成功
|
||||
message=f"用户 '{request.username}' 注册成功",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
conn.rollback() # 数据库错误时回滚事务
|
||||
raise Exception(f"注册失败:{str(e)}") from e
|
||||
raise Exception(f"注册失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -76,9 +79,10 @@ async def user_register(request: UserRegisterRequest):
|
||||
# 2. 用户登录接口
|
||||
# ------------------------------
|
||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||
@encrypt_response()
|
||||
async def user_login(request: UserLoginRequest):
|
||||
"""
|
||||
用户登录:
|
||||
用户登录:
|
||||
- 校验用户名是否存在
|
||||
- 校验密码是否正确
|
||||
- 生成 JWT Token 并返回
|
||||
@ -89,7 +93,7 @@ async def user_login(request: UserLoginRequest):
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 修复:SQL查询添加 created_at 和 updated_at 字段
|
||||
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
||||
query = """
|
||||
SELECT id, username, password, created_at, updated_at
|
||||
FROM users
|
||||
@ -129,7 +133,7 @@ async def user_login(request: UserLoginRequest):
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"登录失败:{str(e)}") from e
|
||||
raise Exception(f"登录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -138,12 +142,13 @@ async def user_login(request: UserLoginRequest):
|
||||
# 3. 获取当前登录用户信息(需认证)
|
||||
# ------------------------------
|
||||
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
|
||||
@encrypt_response()
|
||||
async def get_current_user_info(
|
||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||
):
|
||||
"""
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式:Bearer <token>)
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||
- 认证通过后返回用户信息
|
||||
"""
|
||||
return APIResponse(
|
||||
@ -152,3 +157,99 @@ async def get_current_user_info(
|
||||
data=current_user
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 获取用户列表(仅需登录权限)
|
||||
# ------------------------------
|
||||
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
||||
@encrypt_response()
|
||||
async def get_user_list(
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
||||
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
||||
):
|
||||
"""
|
||||
获取用户列表:
|
||||
- 需登录权限(请求头携带 Token: Bearer <token>)
|
||||
- 支持分页查询(page=页码,page_size=每页条数)
|
||||
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
||||
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 基础查询(仅查非敏感字段)
|
||||
base_query = """
|
||||
SELECT id, username, created_at, updated_at
|
||||
FROM users
|
||||
"""
|
||||
# 总条数查询(用于分页计算)
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
|
||||
# 条件拼接(支持用户名模糊搜索)
|
||||
conditions = []
|
||||
params = []
|
||||
if username:
|
||||
conditions.append("username LIKE %s")
|
||||
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
||||
|
||||
# 构建最终查询语句
|
||||
if conditions:
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
||||
final_count_query = f"{count_query}{where_clause}"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
else:
|
||||
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
||||
final_count_query = count_query
|
||||
params = [page_size, offset]
|
||||
|
||||
# 1. 查询用户列表数据
|
||||
cursor.execute(final_query, params)
|
||||
users = cursor.fetchall()
|
||||
|
||||
# 2. 查询总条数(用于计算总页数)
|
||||
count_params = [f"%{username}%"] if username else []
|
||||
cursor.execute(final_count_query, count_params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 转换为UserResponse模型(确保字段匹配)
|
||||
user_list = [
|
||||
UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user["created_at"],
|
||||
updated_at=user["updated_at"]
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
# 4. 计算总页数(向上取整,如11条数据每页10条=2页)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# 返回结果(包含列表和分页信息)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户列表成功",
|
||||
data={
|
||||
"users": user_list,
|
||||
"pagination": {
|
||||
"page": page, # 当前页码
|
||||
"page_size": page_size, # 每页条数
|
||||
"total": total, # 总数据量
|
||||
"total_pages": total_pages # 总页数
|
||||
}
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败,都关闭数据库连接
|
||||
db.close_connection(conn, cursor)
|
||||
156
util/face_util.py
Normal file
@ -0,0 +1,156 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import logging
|
||||
|
||||
# 配置日志(便于排查)
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - [FaceUtil] - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局变量存储InsightFace引擎和特征列表
|
||||
_insightface_app = None
|
||||
_feature_list = []
|
||||
|
||||
|
||||
def init_insightface():
|
||||
"""初始化InsightFace引擎(确保成功后再使用)"""
|
||||
global _insightface_app
|
||||
try:
|
||||
if _insightface_app is not None:
|
||||
logger.info("InsightFace引擎已初始化,无需重复执行")
|
||||
return _insightface_app
|
||||
|
||||
logger.info("正在初始化InsightFace引擎(模型:buffalo_l)...")
|
||||
# 手动指定模型下载路径(避免权限问题,可选)
|
||||
app = FaceAnalysis(
|
||||
name='buffalo_l',
|
||||
root='~/.insightface', # 模型默认下载路径
|
||||
providers=['CPUExecutionProvider'] # 强制用CPU(若有GPU可加'CUDAExecutionProvider')
|
||||
)
|
||||
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size越大,小人脸检测越准
|
||||
logger.info("InsightFace引擎初始化完成")
|
||||
_insightface_app = app
|
||||
return app
|
||||
except Exception as e:
|
||||
logger.error(f"InsightFace初始化失败:{str(e)}", exc_info=True) # 打印详细堆栈
|
||||
_insightface_app = None
|
||||
return None
|
||||
|
||||
|
||||
def add_binary_data(binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据、提取特征并保存
|
||||
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
|
||||
"""
|
||||
global _insightface_app, _feature_list
|
||||
|
||||
# 1. 先检查引擎是否初始化成功
|
||||
if not _insightface_app:
|
||||
init_result = init_insightface() # 尝试重新初始化
|
||||
if not init_result:
|
||||
error_msg = "InsightFace引擎未初始化,无法检测人脸"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
# 2. 验证二进制数据有效性
|
||||
if len(binary_data) < 1024: # 过滤过小的无效图片(小于1KB)
|
||||
error_msg = f"图片过小({len(binary_data)}字节),可能不是有效图片"
|
||||
logger.warning(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 3. 二进制数据转CV2格式(关键步骤,避免通道错误)
|
||||
try:
|
||||
img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
|
||||
except Exception as e:
|
||||
error_msg = f"图片格式转换失败:{str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
|
||||
height, width = frame.shape[:2]
|
||||
if height < 64 or width < 64: # 人脸检测最小建议尺寸
|
||||
error_msg = f"图片尺寸过小({width}x{height}),需至少64x64像素"
|
||||
logger.warning(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 5. 调用InsightFace检测人脸
|
||||
logger.info(f"开始检测人脸(图片尺寸:{width}x{height},格式:BGR)")
|
||||
faces = _insightface_app.get(frame)
|
||||
|
||||
if not faces:
|
||||
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸,无遮挡、不模糊)"
|
||||
logger.warning(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 6. 提取特征并保存
|
||||
current_feature = faces[0].embedding
|
||||
_feature_list.append(current_feature)
|
||||
logger.info(f"人脸检测成功,提取特征值(维度:{current_feature.shape[0]}),累计特征数:{len(_feature_list)}")
|
||||
return True, current_feature
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理图片时发生异常:{str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
# 以下函数保持不变(get_average_feature/clear_features/get_feature_list)
|
||||
def get_average_feature(features=None):
|
||||
global _feature_list
|
||||
try:
|
||||
if features is None:
|
||||
features = _feature_list
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
logger.warning("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
logger.info(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
logger.warning(f"跳过第 {i + 1} 个特征值:不是一维数组")
|
||||
except Exception as e:
|
||||
logger.error(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
|
||||
|
||||
if not processed_features:
|
||||
logger.warning("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
logger.error(f"特征值维度不一致:{dims},无法计算平均值")
|
||||
return None
|
||||
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
logger.info(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})")
|
||||
return avg_feature
|
||||
except Exception as e:
|
||||
logger.error(f"计算平均特征值出错:{str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def clear_features():
|
||||
global _feature_list
|
||||
_feature_list = []
|
||||
logger.info("已清空所有特征数据")
|
||||
|
||||
|
||||
def get_feature_list():
|
||||
global _feature_list
|
||||
logger.info(f"当前特征列表长度:{len(_feature_list)}")
|
||||
return _feature_list.copy()
|
||||
86
util/file_util.py
Normal file
@ -0,0 +1,86 @@
|
||||
import os
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def save_face_to_up_images(
|
||||
client_ip: str,
|
||||
face_name: str,
|
||||
image_bytes: bytes,
|
||||
image_format: str = "jpg"
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
|
||||
确保db_path以 /api/file/up_images 开头,且统一使用正斜杠
|
||||
本地不创建/api/file/文件夹,仅URL访问时使用该前缀路由
|
||||
|
||||
参数:
|
||||
client_ip: 客户端IP(原始格式,如192.168.1.101)
|
||||
face_name: 人脸名称(用户输入,可为空)
|
||||
image_bytes: 人脸图片二进制数据
|
||||
image_format: 图片格式(默认jpg)
|
||||
|
||||
返回:
|
||||
字典:success(是否成功)、db_path(存数据库的路径,带/api/file/前缀)、local_abs_path(本地绝对路径)、msg(提示)
|
||||
"""
|
||||
try:
|
||||
# 1. 基础参数校验(不变)
|
||||
if not client_ip.strip():
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
|
||||
if not image_bytes:
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "图片二进制数据为空"}
|
||||
if image_format.lower() not in ["jpg", "jpeg", "png"]:
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
|
||||
|
||||
# 2. 处理特殊字符(避免路径错误)(不变)
|
||||
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
|
||||
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
|
||||
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
|
||||
|
||||
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
|
||||
root_dir = Path("up_images").resolve()
|
||||
if not root_dir.exists():
|
||||
root_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[FileUtil] 已创建up_images根目录:{root_dir}")
|
||||
|
||||
# 4. 构建文件层级路径(确保在root_dir子目录下)(不变)
|
||||
ip_dir = root_dir / safe_ip
|
||||
face_name_dir = ip_dir / safe_face_name
|
||||
face_name_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[FileUtil] 图片存储目录(本地):{face_name_dir}")
|
||||
|
||||
# 5. 生成唯一文件名(毫秒级时间戳)(不变)
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||
|
||||
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
|
||||
local_abs_path = face_name_dir / image_filename
|
||||
|
||||
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
|
||||
raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}")
|
||||
|
||||
# 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀
|
||||
relative_path = local_abs_path.relative_to(root_dir.parent)
|
||||
|
||||
relative_path_str = str(relative_path).replace("\\", "/")
|
||||
# 2. 再拼接/api/file/前缀
|
||||
db_path = f"/api/file/{relative_path_str}"
|
||||
|
||||
# 7. 写入图片文件(不变)
|
||||
with open(local_abs_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
print(f"[FileUtil] 图片保存成功:")
|
||||
print(f" 数据库路径(带/api/file/前缀):{db_path}")
|
||||
print(f" 本地绝对路径(无/api/file/):{local_abs_path}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"db_path": db_path,
|
||||
"local_abs_path": str(local_abs_path),
|
||||
"msg": "图片保存成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片保存失败:{str(e)}"
|
||||
print(f"[FileUtil] 错误:{error_msg}")
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}
|
||||
61
util/model_util.py
Normal file
@ -0,0 +1,61 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import traceback
|
||||
from ultralytics import YOLO
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def load_yolo_model(model_path: str) -> Optional[YOLO]:
|
||||
"""
|
||||
加载YOLO模型(支持v5/v8),并校验模型有效性
|
||||
:param model_path: 模型文件的绝对路径
|
||||
:return: 加载成功返回YOLO模型实例,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 加载前的基础信息检查
|
||||
print(f"\n[模型工具] 开始加载模型:{model_path}")
|
||||
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
|
||||
if os.path.exists(model_path):
|
||||
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
|
||||
|
||||
# 强制重新加载模型,避免缓存问题
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 兼容性校验:使用numpy空数组测试模型
|
||||
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
|
||||
|
||||
try:
|
||||
# 优先使用新版本参数
|
||||
model.predict(
|
||||
source=dummy_image,
|
||||
imgsz=640,
|
||||
conf=0.25,
|
||||
verbose=False,
|
||||
stream=False
|
||||
)
|
||||
except Exception as pred_e:
|
||||
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
|
||||
# 兼容旧版本YOLO参数
|
||||
model.predict(
|
||||
img=dummy_image,
|
||||
imgsz=640,
|
||||
conf=0.25,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# 验证模型基本属性
|
||||
if not hasattr(model, 'names'):
|
||||
print("[模型工具] 警告:模型缺少类别名称属性")
|
||||
else:
|
||||
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
|
||||
|
||||
print(f"[模型工具] 模型加载成功!")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
# 详细错误信息输出
|
||||
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
|
||||
print(f"[模型工具] 异常类型:{type(e).__name__}")
|
||||
print(f"[模型工具] 异常详情:{str(e)}")
|
||||
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
|
||||
return None
|
||||
482
ws.html
@ -1,482 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>WebSocket 测试工具</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
|
||||
}
|
||||
|
||||
body {
|
||||
max-width: 1200px;
|
||||
margin: 20px auto;
|
||||
padding: 0 20px;
|
||||
background-color: #f5f7fa;
|
||||
}
|
||||
|
||||
.container {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
padding: 25px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
margin-bottom: 20px;
|
||||
font-size: 24px;
|
||||
border-bottom: 2px solid #3498db;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.status-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 12px 15px;
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.status-label {
|
||||
font-weight: bold;
|
||||
color: #495057;
|
||||
}
|
||||
|
||||
.status-value {
|
||||
padding: 4px 10px;
|
||||
border-radius: 4px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.status-connected {
|
||||
background-color: #d4edda;
|
||||
color: #155724;
|
||||
}
|
||||
|
||||
.status-disconnected {
|
||||
background-color: #f8d7da;
|
||||
color: #721c24;
|
||||
}
|
||||
|
||||
.status-connecting {
|
||||
background-color: #fff3cd;
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 8px 16px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background-color: #3498db;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: #2980b9;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background-color: #e74c3c;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background-color: #c0392b;
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background-color: #2ecc71;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background-color: #27ae60;
|
||||
}
|
||||
|
||||
.control-group {
|
||||
display: flex;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
color: #495057;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.input-group input, .input-group select {
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.message-area {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.message-input {
|
||||
width: 100%;
|
||||
height: 100px;
|
||||
padding: 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
resize: none;
|
||||
font-size: 14px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.log-area {
|
||||
width: 100%;
|
||||
height: 300px;
|
||||
padding: 15px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
background-color: #f8f9fa;
|
||||
overflow-y: auto;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.log-item {
|
||||
margin-bottom: 8px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px dashed #e9ecef;
|
||||
}
|
||||
|
||||
.log-time {
|
||||
color: #6c757d;
|
||||
font-size: 12px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.log-send {
|
||||
color: #2980b9;
|
||||
}
|
||||
|
||||
.log-receive {
|
||||
color: #27ae60;
|
||||
}
|
||||
|
||||
.log-status {
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.log-error {
|
||||
color: #e74c3c;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>WebSocket 测试工具</h1>
|
||||
|
||||
<!-- 连接状态区 -->
|
||||
<div class="status-bar">
|
||||
<div class="status-label">连接状态:</div>
|
||||
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
|
||||
<div class="status-label">服务地址:</div>
|
||||
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
|
||||
<div class="status-label">连接时间:</div>
|
||||
<div id="connectTime" class="status-value">-</div>
|
||||
</div>
|
||||
|
||||
<!-- 控制按钮区 -->
|
||||
<div class="control-group">
|
||||
<button id="connectBtn" class="btn btn-primary">建立连接</button>
|
||||
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
|
||||
|
||||
<!-- 心跳控制 -->
|
||||
<div class="input-group">
|
||||
<label>自动心跳:</label>
|
||||
<select id="autoHeartbeat">
|
||||
<option value="on">开启</option>
|
||||
<option value="off">关闭</option>
|
||||
</select>
|
||||
<label>间隔(秒):</label>
|
||||
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
|
||||
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 自定义消息发送区 -->
|
||||
<div class="message-area">
|
||||
<h3>发送自定义消息</h3>
|
||||
<textarea id="messageInput" class="message-input"
|
||||
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
|
||||
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
|
||||
</div>
|
||||
|
||||
<!-- 日志显示区 -->
|
||||
<div class="message-area">
|
||||
<h3>消息日志</h3>
|
||||
<div id="logContainer" class="log-area">
|
||||
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
|
||||
</div>
|
||||
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// 全局变量
|
||||
let ws = null;
|
||||
let heartbeatTimer = null;
|
||||
const wsUrl = "ws://192.168.110.25:8000/ws";
|
||||
|
||||
// DOM 元素
|
||||
const connectionStatus = document.getElementById('connectionStatus');
|
||||
const connectTime = document.getElementById('connectTime');
|
||||
const connectBtn = document.getElementById('connectBtn');
|
||||
const disconnectBtn = document.getElementById('disconnectBtn');
|
||||
const sendMessageBtn = document.getElementById('sendMessageBtn');
|
||||
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
|
||||
const autoHeartbeat = document.getElementById('autoHeartbeat');
|
||||
const heartbeatInterval = document.getElementById('heartbeatInterval');
|
||||
const messageInput = document.getElementById('messageInput');
|
||||
const logContainer = document.getElementById('logContainer');
|
||||
const clearLogBtn = document.getElementById('clearLogBtn');
|
||||
|
||||
// 工具函数:添加日志
|
||||
function addLog(content, type = 'status') {
|
||||
const now = new Date().toLocaleString('zh-CN', {
|
||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||
hour: '2-digit', minute: '2-digit', second: '2-digit'
|
||||
});
|
||||
const logItem = document.createElement('div');
|
||||
logItem.className = 'log-item';
|
||||
|
||||
let logClass = '';
|
||||
switch (type) {
|
||||
case 'send':
|
||||
logClass = 'log-send';
|
||||
break;
|
||||
case 'receive':
|
||||
logClass = 'log-receive';
|
||||
break;
|
||||
case 'error':
|
||||
logClass = 'log-error';
|
||||
break;
|
||||
default:
|
||||
logClass = 'log-status';
|
||||
}
|
||||
|
||||
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
|
||||
logContainer.appendChild(logItem);
|
||||
// 滚动到最新日志
|
||||
logContainer.scrollTop = logContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// 工具函数:格式化JSON(便于日志显示)
|
||||
function formatJson(jsonStr) {
|
||||
try {
|
||||
const obj = JSON.parse(jsonStr);
|
||||
return JSON.stringify(obj, null, 2);
|
||||
} catch (e) {
|
||||
return jsonStr; // 非JSON格式直接返回
|
||||
}
|
||||
}
|
||||
|
||||
// 建立WebSocket连接
|
||||
function connectWebSocket() {
|
||||
if (ws) {
|
||||
addLog('已存在连接,无需重复建立', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
ws = new WebSocket(wsUrl);
|
||||
|
||||
// 连接成功
|
||||
ws.onopen = function () {
|
||||
connectionStatus.className = 'status-value status-connected';
|
||||
connectionStatus.textContent = '已连接';
|
||||
const now = new Date().toLocaleString('zh-CN');
|
||||
connectTime.textContent = now;
|
||||
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = true;
|
||||
disconnectBtn.disabled = false;
|
||||
sendMessageBtn.disabled = false;
|
||||
|
||||
// 开启自动心跳(默认开启)
|
||||
if (autoHeartbeat.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
};
|
||||
|
||||
// 接收消息
|
||||
ws.onmessage = function (event) {
|
||||
const message = event.data;
|
||||
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
|
||||
};
|
||||
|
||||
// 连接关闭
|
||||
ws.onclose = function (event) {
|
||||
connectionStatus.className = 'status-value status-disconnected';
|
||||
connectionStatus.textContent = '已断开';
|
||||
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
|
||||
|
||||
// 清除自动心跳
|
||||
stopAutoHeartbeat();
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = false;
|
||||
disconnectBtn.disabled = true;
|
||||
sendMessageBtn.disabled = true;
|
||||
|
||||
// 重置WebSocket对象
|
||||
ws = null;
|
||||
};
|
||||
|
||||
// 连接错误
|
||||
ws.onerror = function (error) {
|
||||
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
|
||||
};
|
||||
|
||||
} catch (e) {
|
||||
addLog(`建立连接失败:${e.message}`, 'error');
|
||||
ws = null;
|
||||
}
|
||||
}
|
||||
|
||||
// 断开WebSocket连接
|
||||
function disconnectWebSocket() {
|
||||
if (!ws) {
|
||||
addLog('当前无连接,无需断开', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
ws.close(1000, '手动断开连接');
|
||||
}
|
||||
|
||||
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"})
|
||||
function sendHeartbeat() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送心跳失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const heartbeatMsg = {
|
||||
timestamp: Date.now(), // 当前毫秒时间戳
|
||||
type: "heartbeat"
|
||||
};
|
||||
const msgStr = JSON.stringify(heartbeatMsg);
|
||||
|
||||
ws.send(msgStr);
|
||||
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
|
||||
}
|
||||
|
||||
// 开启自动心跳
|
||||
function startAutoHeartbeat() {
|
||||
// 先停止已有定时器
|
||||
stopAutoHeartbeat();
|
||||
|
||||
const interval = parseInt(heartbeatInterval.value) * 1000;
|
||||
if (isNaN(interval) || interval < 10000) {
|
||||
addLog('自动心跳间隔无效,已重置为30秒', 'error');
|
||||
heartbeatInterval.value = 30;
|
||||
return startAutoHeartbeat();
|
||||
}
|
||||
|
||||
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}秒`, 'status');
|
||||
heartbeatTimer = setInterval(sendHeartbeat, interval);
|
||||
}
|
||||
|
||||
// 停止自动心跳
|
||||
function stopAutoHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
addLog('已停止自动心跳', 'status');
|
||||
}
|
||||
}
|
||||
|
||||
// 发送自定义消息
|
||||
function sendCustomMessage() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送消息失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const msgStr = messageInput.value.trim();
|
||||
if (!msgStr) {
|
||||
addLog('发送消息失败:消息内容不能为空', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 验证JSON格式(可选,仅提示不强制)
|
||||
JSON.parse(msgStr);
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
|
||||
} catch (e) {
|
||||
addLog(`JSON格式错误:${e.message},仍尝试发送原始内容`, 'error');
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息(非JSON):\n${msgStr}`, 'send');
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定按钮事件
|
||||
connectBtn.addEventListener('click', connectWebSocket);
|
||||
disconnectBtn.addEventListener('click', disconnectWebSocket);
|
||||
sendMessageBtn.addEventListener('click', sendCustomMessage);
|
||||
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
|
||||
clearLogBtn.addEventListener('click', () => {
|
||||
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
|
||||
});
|
||||
|
||||
// 自动心跳开关变更事件
|
||||
autoHeartbeat.addEventListener('change', function () {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
if (this.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
} else {
|
||||
stopAutoHeartbeat();
|
||||
}
|
||||
} else {
|
||||
addLog('需先建立有效连接才能控制自动心跳', 'error');
|
||||
// 重置选择
|
||||
this.value = 'off';
|
||||
}
|
||||
});
|
||||
|
||||
// 心跳间隔变更事件(实时生效)
|
||||
heartbeatInterval.addEventListener('change', function () {
|
||||
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
});
|
||||
|
||||
// 快捷键支持(Ctrl+Enter发送消息)
|
||||
messageInput.addEventListener('keydown', function (e) {
|
||||
if (e.ctrlKey && e.key === 'Enter') {
|
||||
sendCustomMessage();
|
||||
e.preventDefault();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
384
ws/ws.py
@ -2,289 +2,309 @@ import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, AsyncGenerator
|
||||
from typing import Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
||||
from service.device_action_service import add_device_action
|
||||
from schema.device_action_schema import DeviceActionCreate
|
||||
from core.all import detect, load_model
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
|
||||
# -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)--------------------------
|
||||
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥(32字节)
|
||||
AES_BLOCK_SIZE = 16 # AES固定块大小
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# -------------------------- 配置常量 --------------------------
|
||||
def aes_encrypt(plaintext: str) -> dict:
|
||||
"""AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息"""
|
||||
try:
|
||||
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV(16字节)
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
|
||||
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
|
||||
iv_base64 = base64.b64encode(iv).decode("utf-8")
|
||||
return {
|
||||
"ciphertext": ciphertext,
|
||||
"iv": iv_base64,
|
||||
"algorithm": "AES-CBC"
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"AES加密失败: {str(e)}") from e
|
||||
|
||||
|
||||
# -------------------------- 2. 配置常量(保持原有)--------------------------
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
||||
WS_ENDPOINT = "/ws" # WebSocket端点路径
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制(保持1,确保单帧处理)
|
||||
|
||||
# -------------------------- 核心数据结构与全局变量 --------------------------
|
||||
ws_router = APIRouter()
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||
|
||||
|
||||
# 客户端连接封装(包含帧队列)
|
||||
# -------------------------- 3. 工具函数(保持原有)--------------------------
|
||||
def get_current_time_str() -> str:
|
||||
"""获取格式化时间字符串"""
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_current_time_file_str() -> str:
|
||||
"""获取文件命名用时间字符串"""
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
|
||||
# -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)--------------------------
|
||||
class ClientConnection:
|
||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||
self.websocket = websocket
|
||||
self.client_ip = client_ip
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列,长度为1
|
||||
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
|
||||
self.consumer_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 更新心跳时间
|
||||
def update_heartbeat(self):
|
||||
"""更新心跳时间"""
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
|
||||
# 检查是否存活(超时返回False)
|
||||
def is_alive(self) -> bool:
|
||||
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||
return timeout < HEARTBEAT_TIMEOUT
|
||||
"""判断客户端是否存活"""
|
||||
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||
return timeout_seconds < HEARTBEAT_TIMEOUT
|
||||
|
||||
# 启动帧消费任务
|
||||
def start_consumer(self):
|
||||
"""启动帧消费任务"""
|
||||
self.consumer_task = asyncio.create_task(self.consume_frames())
|
||||
return self.consumer_task
|
||||
|
||||
# ---------- 新增:发送“允许发送二进制帧”的信号给客户端 ----------
|
||||
async def send_allow_send_frame(self):
|
||||
"""向客户端发送JSON信号,通知其可发送下一帧二进制数据"""
|
||||
async def send_frame_permit(self):
|
||||
"""发送加密的帧许可信号(服务器→客户端:加密)"""
|
||||
try:
|
||||
allow_msg = {
|
||||
"type": "allow_send_frame", # 信号类型,与客户端约定
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"status": "ready", # 表示服务器已准备好接收下一帧
|
||||
"client_ip": self.client_ip # 可选:便于客户端确认自身身份
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip
|
||||
}
|
||||
await self.websocket.send_json(allow_msg)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:已发送「允许发送帧」信号")
|
||||
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密
|
||||
await self.websocket.send_json(encrypted_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
|
||||
except Exception as e:
|
||||
# 发送失败大概率是客户端已断开,不影响主流程,仅日志记录
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:发送「允许发送帧」信号失败 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
|
||||
|
||||
# 帧消费协程
|
||||
async def consume_frames(self) -> None:
|
||||
"""从队列中获取帧并进行处理,处理完后通知客户端可发送下一帧"""
|
||||
"""消费队列中的明文图像帧并处理"""
|
||||
try:
|
||||
while True:
|
||||
# 从队列获取帧数据(队列空时会阻塞,等待客户端发送)
|
||||
frame_data = await self.frame_queue.get()
|
||||
await self.send_frame_permit() # 回复仍加密
|
||||
try:
|
||||
# 处理帧数据
|
||||
await self.process_frame(frame_data)
|
||||
finally:
|
||||
# 标记任务完成(队列计数-1,此时队列回到空状态)
|
||||
self.frame_queue.task_done()
|
||||
# ---------- 修改:处理完当前帧后,立即通知客户端可发送下一帧 ----------
|
||||
await self.send_allow_send_frame()
|
||||
except asyncio.CancelledError:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧消费任务已取消")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
|
||||
|
||||
async def process_frame(self, frame_data: bytes) -> None:
|
||||
"""处理单帧图像数据(原有逻辑不变)"""
|
||||
# 将二进制数据转换为NumPy数组(uint8类型)
|
||||
"""处理明文图像帧(危险通知仍加密发送)"""
|
||||
# 二进制转OpenCV图像(客户端发的是明文二进制,直接解析)
|
||||
nparr = np.frombuffer(frame_data, np.uint8)
|
||||
# 解码为图像,返回与cv2.imread相同的格式(BGR通道的ndarray)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
# 确保images文件夹存在
|
||||
if not os.path.exists('images'):
|
||||
os.makedirs('images')
|
||||
|
||||
# 生成唯一的文件名,包含时间戳和客户端IP,避免文件名冲突
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
|
||||
if img is None:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像")
|
||||
return
|
||||
|
||||
try:
|
||||
# 保存图像到本地
|
||||
cv2.imwrite(filename, img)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
|
||||
has_violation, data, detector_type = await asyncio.to_thread(
|
||||
detect, self.client_ip, img
|
||||
)
|
||||
print(
|
||||
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
|
||||
|
||||
# 进行检测
|
||||
if img is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(img)
|
||||
# 违规通知:服务器→客户端,仍加密
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# 发送检测结果回客户端(原有逻辑不变)
|
||||
await self.websocket.send_json({
|
||||
"type": "detection_result",
|
||||
"has_violation": has_violation,
|
||||
"violation_type": violation_type,
|
||||
"details": details,
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
})
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:未检测到任何违规内容")
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:无法解析图像数据")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
|
||||
# 违规次数+1
|
||||
try:
|
||||
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
|
||||
|
||||
# 构建危险通知并加密发送
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip,
|
||||
"detail": data
|
||||
}
|
||||
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密
|
||||
await self.websocket.send_json(encrypted_danger_msg)
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}")
|
||||
|
||||
|
||||
# 全局连接管理(IP -> 连接实例)
|
||||
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
|
||||
connected_clients: Dict[str, ClientConnection] = {}
|
||||
# 心跳任务(全局引用,用于关闭时清理)
|
||||
heartbeat_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
# -------------------------- 心跳检查逻辑(原有逻辑不变) --------------------------
|
||||
async def heartbeat_checker():
|
||||
"""全局心跳检查任务"""
|
||||
while True:
|
||||
now = datetime.datetime.now()
|
||||
# 1. 筛选超时客户端(避免遍历中修改字典)
|
||||
current_time = get_current_time_str()
|
||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||
|
||||
# 2. 处理超时连接(关闭+移除)
|
||||
if timeout_ips:
|
||||
print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips})")
|
||||
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips})")
|
||||
for ip in timeout_ips:
|
||||
try:
|
||||
# 取消消费者任务
|
||||
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done():
|
||||
connected_clients[ip].consumer_task.cancel()
|
||||
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时")
|
||||
conn = connected_clients[ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
await conn.websocket.close(code=1008, reason="心跳超时")
|
||||
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{ip}: 已标记为离线")
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{ip}: 离线处理失败 - {str(e)}")
|
||||
finally:
|
||||
connected_clients.pop(ip, None)
|
||||
else:
|
||||
print(f"[{now:%H:%M:%S}] 心跳检查:{len(connected_clients)}个客户端在线,无超时")
|
||||
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
|
||||
|
||||
# 3. 等待下一轮检查
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
|
||||
# -------------------------- 应用生命周期(原有逻辑不变) --------------------------
|
||||
# -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)--------------------------
|
||||
async def send_heartbeat_ack(conn: ClientConnection):
|
||||
"""发送加密的心跳确认(服务器→客户端:加密)"""
|
||||
try:
|
||||
heartbeat_ack_msg = {
|
||||
"type": "heart",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": conn.client_ip
|
||||
}
|
||||
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密
|
||||
await conn.websocket.send_json(encrypted_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
|
||||
return True
|
||||
except Exception as e:
|
||||
connected_clients.pop(conn.client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认失败 - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_text_msg(conn: ClientConnection, text: str):
|
||||
"""处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON"""
|
||||
try:
|
||||
# 客户端发的是明文JSON,直接解析(删除原解密步骤)
|
||||
msg = json.loads(text)
|
||||
if msg.get("type") == "heart":
|
||||
conn.update_heartbeat()
|
||||
await send_heartbeat_ack(conn) # 服务器回复仍加密
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})")
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}")
|
||||
|
||||
|
||||
# -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)--------------------------
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期:启动/停止心跳任务"""
|
||||
global heartbeat_task
|
||||
# 启动心跳任务
|
||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动(ID:{id(heartbeat_task)})")
|
||||
print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})")
|
||||
yield
|
||||
# 关闭时取消心跳任务
|
||||
if heartbeat_task and not heartbeat_task.done():
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print(f"[{get_current_time_str()}] 心跳检查任务已取消")
|
||||
|
||||
|
||||
# -------------------------- 消息处理(文本/心跳逻辑不变,二进制逻辑保留) --------------------------
|
||||
async def send_heartbeat_ack(client_ip: str):
|
||||
"""回复心跳确认(原有逻辑不变)"""
|
||||
if client_ip not in connected_clients:
|
||||
return False
|
||||
try:
|
||||
ack = {
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"type": "heartbeat"
|
||||
}
|
||||
await connected_clients[client_ip].websocket.send_json(ack)
|
||||
return True
|
||||
except Exception:
|
||||
connected_clients.pop(client_ip, None)
|
||||
return False
|
||||
|
||||
|
||||
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
|
||||
"""处理文本消息(核心:心跳+JSON解析,原有逻辑不变)"""
|
||||
try:
|
||||
msg = json.loads(text)
|
||||
# 仅处理心跳类型消息
|
||||
if msg.get("type") == "heartbeat":
|
||||
conn.update_heartbeat()
|
||||
await send_heartbeat_ack(client_ip)
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到文本消息:{msg}")
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:无效JSON消息")
|
||||
|
||||
|
||||
async def handle_binary_msg(client_ip: str, data: bytes):
|
||||
"""处理二进制消息(原有逻辑不变,因客户端仅在收到允许信号后发送,队列不会满)"""
|
||||
if client_ip not in connected_clients:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
|
||||
return
|
||||
|
||||
conn = connected_clients[client_ip]
|
||||
|
||||
# 检查队列是否已满(理论上不会触发,因客户端按信号发送)
|
||||
if conn.frame_queue.full():
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
|
||||
return
|
||||
|
||||
# 队列未满,添加帧到队列
|
||||
try:
|
||||
conn.frame_queue.put_nowait(data)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:已接收{len(data)}字节二进制数据,加入队列")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据")
|
||||
|
||||
|
||||
# -------------------------- WebSocket核心端点(修改连接初始化逻辑) --------------------------
|
||||
@ws_router.websocket(WS_ENDPOINT)
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
# 接受连接 + 获取客户端IP
|
||||
"""WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像"""
|
||||
load_model() # 加载检测模型(建议移到全局,避免重复加载)
|
||||
await websocket.accept()
|
||||
client_ip = websocket.client.host if websocket.client else "unknown"
|
||||
now = datetime.datetime.now()
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功")
|
||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||
current_time = get_current_time_str()
|
||||
print(f"[{current_time}] 客户端{client_ip}: 连接已建立")
|
||||
is_online_updated = False
|
||||
|
||||
consumer_task = None
|
||||
try:
|
||||
# 处理重复连接(关闭旧连接)
|
||||
if client_ip in connected_clients:
|
||||
# 取消旧连接的消费者任务
|
||||
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
|
||||
connected_clients[client_ip].consumer_task.cancel()
|
||||
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接")
|
||||
old_conn = connected_clients[client_ip]
|
||||
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
||||
old_conn.consumer_task.cancel()
|
||||
await old_conn.websocket.close(code=1008, reason="新连接建立")
|
||||
connected_clients.pop(client_ip)
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接")
|
||||
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
|
||||
|
||||
# 注册新连接
|
||||
new_conn = ClientConnection(websocket, client_ip)
|
||||
connected_clients[client_ip] = new_conn
|
||||
new_conn.start_consumer()
|
||||
await new_conn.send_frame_permit() # 首次许可仍加密
|
||||
|
||||
# 启动帧消费任务
|
||||
consumer_task = new_conn.start_consumer()
|
||||
# ---------- 修改:客户端刚连接时,队列空,立即发送「允许发送帧」信号 ----------
|
||||
await new_conn.send_allow_send_frame()
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}个")
|
||||
# 标记客户端上线
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线")
|
||||
is_online_updated = True
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
|
||||
|
||||
# 循环接收消息(原有逻辑不变)
|
||||
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
|
||||
|
||||
# 消息循环:接收客户端明文消息(关键修改)
|
||||
while True:
|
||||
data = await websocket.receive()
|
||||
if "text" in data:
|
||||
await handle_text_msg(client_ip, data["text"], new_conn)
|
||||
# 处理客户端明文文本(如心跳:{"type":"heart",...})
|
||||
await handle_text_msg(new_conn, data["text"])
|
||||
elif "bytes" in data:
|
||||
await handle_binary_msg(client_ip, data["bytes"])
|
||||
|
||||
# 异常处理(断开/错误)
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:主动断开(代码:{e.code})")
|
||||
# 处理客户端明文二进制图像(直接入队,无需解密)
|
||||
frame_data = data["bytes"]
|
||||
try:
|
||||
new_conn.frame_queue.put_nowait(frame_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]})")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}")
|
||||
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
|
||||
finally:
|
||||
# 清理连接和任务
|
||||
# 清理资源
|
||||
if client_ip in connected_clients:
|
||||
# 取消消费者任务
|
||||
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
|
||||
connected_clients[client_ip].consumer_task.cancel()
|
||||
conn = connected_clients[client_ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
if is_online_updated:
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
|
||||
connected_clients.pop(client_ip, None)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}个")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")
|
||||