模型添加置信度设置,敏感词分页
This commit is contained in:
@ -31,15 +31,16 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# 全局模型变量(带版本标识)
|
||||
global _yolo_model, _current_model_version
|
||||
# 全局模型变量(带版本标识和置信度)
|
||||
global _yolo_model, _current_model_version, _current_conf_threshold
|
||||
_yolo_model = None
|
||||
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
|
||||
_current_model_version = None # 模型版本标识
|
||||
_current_conf_threshold = 0.8 # 默认置信度初始值
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 服务重启核心工具函数
|
||||
# 服务重启核心工具函数(保持不变)
|
||||
def restart_service():
|
||||
"""重启当前FastAPI服务进程"""
|
||||
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
||||
@ -87,7 +88,7 @@ def restart_service():
|
||||
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
||||
|
||||
|
||||
# 模型路径验证工具函数
|
||||
# 模型路径验证工具函数(保持不变)
|
||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
try:
|
||||
relative_path = relative_path.replace("/", os.sep)
|
||||
@ -139,7 +140,7 @@ def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
) from e
|
||||
|
||||
|
||||
# 对外提供当前模型(带版本校验)
|
||||
# 对外提供当前模型(带版本校验)(保持不变)
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
||||
global _yolo_model, _current_model_version
|
||||
@ -155,21 +156,19 @@ def get_current_yolo_model():
|
||||
return None
|
||||
|
||||
# 1. 计算当前默认模型的唯一版本标识
|
||||
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
|
||||
# 2. 版本未变化则复用已有模型(核心优化点)
|
||||
# 2. 版本未变化则复用已有模型
|
||||
if _yolo_model and _current_model_version == model_version:
|
||||
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)")
|
||||
return _yolo_model
|
||||
|
||||
# 3. 版本变化时重新加载模型
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
_current_model_version = model_version # 更新版本标识
|
||||
_current_model_version = model_version
|
||||
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
||||
@ -182,7 +181,14 @@ def get_current_yolo_model():
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 1. 上传模型
|
||||
# 新增:获取当前置信度阈值
|
||||
def get_current_conf_threshold():
|
||||
"""供检测模块获取当前设置的置信度阈值"""
|
||||
global _current_conf_threshold
|
||||
return _current_conf_threshold
|
||||
|
||||
|
||||
# 1. 上传模型(保持不变)
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
@ -255,7 +261,7 @@ async def upload_model(
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"模型上传成功!ID:{new_model['id']}",
|
||||
data=ModelResponse(**new_model)
|
||||
data=ModelResponse(** new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
@ -273,7 +279,7 @@ async def upload_model(
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 2. 获取模型列表
|
||||
# 2. 获取模型列表(保持不变)
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
@ -319,7 +325,7 @@ async def get_model_list(
|
||||
message=f"获取成功!共{total}条记录",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(**model) for model in model_list]
|
||||
models=[ModelResponse(** model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
@ -329,7 +335,7 @@ async def get_model_list(
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 3. 获取默认模型
|
||||
# 3. 获取默认模型(保持不变)
|
||||
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||
async def get_default_model():
|
||||
conn = None
|
||||
@ -371,7 +377,7 @@ async def get_default_model():
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 4. 获取单个模型详情
|
||||
# 4. 获取单个模型详情(保持不变)
|
||||
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||
async def get_model(model_id: int):
|
||||
conn = None
|
||||
@ -392,7 +398,7 @@ async def get_model(model_id: int):
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"查询成功,但路径异常:{e.detail}",
|
||||
data=ModelResponse(**model)
|
||||
data=ModelResponse(** model)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
@ -400,14 +406,13 @@ async def get_model(model_id: int):
|
||||
message="查询成功",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5. 更新模型信息
|
||||
# 5. 更新模型信息(保持不变)
|
||||
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
conn = None
|
||||
@ -479,7 +484,7 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="模型更新成功",
|
||||
data=ModelResponse(**updated_model)
|
||||
data=ModelResponse(** updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
@ -490,9 +495,12 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5.1 更换默认模型(自动重启服务)
|
||||
# 5.1 更换默认模型(添加置信度参数)
|
||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
||||
async def set_default_model(model_id: int):
|
||||
async def set_default_model(
|
||||
model_id: int,
|
||||
conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值(0.01-0.99)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
@ -551,10 +559,11 @@ async def set_default_model(model_id: int):
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
# 7. 重置版本标识(关键:确保下次检测加载新模型)
|
||||
global _current_model_version
|
||||
# 7. 重置版本标识和更新置信度
|
||||
global _current_model_version, _current_conf_threshold
|
||||
_current_model_version = None
|
||||
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
|
||||
_current_conf_threshold = conf_threshold # 保存动态置信度
|
||||
print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}")
|
||||
|
||||
# 8. 延迟重启服务
|
||||
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
||||
@ -566,8 +575,8 @@ async def set_default_model(model_id: int):
|
||||
# 9. 返回成功响应
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型",
|
||||
data=ModelResponse(**updated_model)
|
||||
message=f"已成功更换默认模型(ID:{model_id}),置信度:{conf_threshold}!服务将在1秒后自动重启以应用新模型",
|
||||
data=ModelResponse(** updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
@ -580,7 +589,7 @@ async def set_default_model(model_id: int):
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 6. 删除模型
|
||||
# 6. 删除模型(保持不变)
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
async def delete_model(model_id: int):
|
||||
conn = None
|
||||
@ -639,7 +648,7 @@ async def delete_model(model_id: int):
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 7. 下载模型文件
|
||||
# 7. 下载模型文件(保持不变)
|
||||
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||
async def download_model(model_id: int):
|
||||
conn = None
|
||||
@ -665,4 +674,4 @@ async def download_model(model_id: int):
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
db.close_connection(conn, cursor)
|
||||
Reference in New Issue
Block a user