内容安全审核
This commit is contained in:
131
service/ocr_service.py
Normal file
131
service/ocr_service.py
Normal file
@ -0,0 +1,131 @@
|
||||
# 首先添加NumPy兼容处理
|
||||
import numpy as np
|
||||
|
||||
# Restore the deprecated ``np.int`` alias when this NumPy build no longer
# provides it. This has to run before any library that still references the
# alias is imported (presumably paddleocr, imported below -- confirm).
try:
    np.int  # probe only; raises AttributeError when the alias is gone
except AttributeError:
    np.int = int
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
# Minimum recognition confidence a text must reach before detect() checks it
# against the forbidden words.
_conf_threshold = 0.5

# Words whose presence in recognised text counts as a violation.
_forbidden_words = set()

# Shared PaddleOCR instance; stays None until init_ocr_engine() succeeds.
_ocr_engine = None
|
||||
|
||||
def set_forbidden_words(new_words):
    """Replace the module-wide forbidden-word collection.

    Args:
        new_words: A set, list or tuple of words. The value is copied into a
            new set, so duplicates collapse.

    Raises:
        TypeError: If ``new_words`` is not a set, list or tuple.
    """
    global _forbidden_words
    accepted_types = (set, list, tuple)
    if isinstance(new_words, accepted_types):
        # Always store a set, whatever container the caller handed in.
        _forbidden_words = set(new_words)
        print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}")
        return
    raise TypeError("新违禁词必须是集合、列表或元组类型")
|
||||
|
||||
def load_forbidden_words():
    """Reload the forbidden-word set from ``get_all_sensitive_words()``.

    The fetched words are converted to a set in a local variable first and
    only then published to the module-level ``_forbidden_words``. If fetching
    or converting fails, the previously loaded words stay in place instead of
    being replaced by a half-validated value.

    Returns:
        True when the words were loaded and stored, False when loading failed
        (the error is printed, not raised).
    """
    global _forbidden_words
    try:
        # set() both validates that the result is iterable and keeps the
        # stored type consistent with set_forbidden_words().
        loaded_words = set(get_all_sensitive_words())
        print(f"加载的违禁词数量: {len(loaded_words)}")
    except Exception as e:
        # Best-effort boundary: the caller decides how to react to False.
        print(f"Forbidden words load error: {e}")
        return False
    _forbidden_words = loaded_words
    return True
|
||||
|
||||
|
||||
def init_ocr_engine(*, use_gpu=True):
    """Create the shared PaddleOCR engine and load the forbidden words.

    Args:
        use_gpu: Passed to PaddleOCR as its ``use_gpu`` option. Defaults to
            True, the value that used to be hard-coded, so existing callers
            are unaffected.

    Returns:
        True when the engine was created. A failed forbidden-word load only
        prints a warning and still yields True. False when anything in the
        initialisation raised; the module-level engine is then reset to None.
    """
    global _ocr_engine
    try:
        _ocr_engine = PaddleOCR(
            use_angle_cls=True,
            lang="ch",
            show_log=False,
            use_gpu=use_gpu,
            max_text_length=1024,
        )
        if not load_forbidden_words():
            print("警告:违禁词加载失败,可能影响检测功能")
        print("OCR引擎初始化完成")
        return True
    except Exception as e:
        # Top-level boundary: report the failure and leave no engine behind.
        print(f"OCR引擎初始化错误: {e}")
        _ocr_engine = None
        return False
|
||||
|
||||
|
||||
def _is_geometry(item):
    """Return True when item holds only coordinates (4-point box or flat numeric list)."""
    if not isinstance(item, list):
        return False
    if len(item) == 4 and all(
        isinstance(point, list)
        and len(point) == 2
        and all(isinstance(coord, (int, float)) for coord in point)
        for point in item
    ):
        return True
    # Flat list of numbers; an empty list also lands here and is skipped.
    return all(isinstance(value, (int, float)) for value in item)


def _parse_text_item(item):
    """Return a (text, confidence) pair for a recognised-text item, or None."""
    if isinstance(item, tuple) and len(item) == 2:
        candidate = item
    elif isinstance(item, list) and len(item) >= 2:
        # [box, (text, conf)] or [box, text]: the text part is at index 1.
        candidate = item[1]
        if isinstance(candidate, str):
            # No confidence supplied; treat the text as fully trusted.
            return candidate.strip(), 1.0
    else:
        return None

    if isinstance(candidate, tuple) and len(candidate) == 2:
        text, conf = candidate
        if isinstance(text, str) and isinstance(conf, (int, float)):
            return text.strip(), float(conf)
    return None


def _iter_text_items(ocr_res):
    """Yield (text, confidence) pairs found in an OCR result list."""
    for line in ocr_res:
        if line is None:
            continue
        items = line if isinstance(line, list) else [line]
        for item in items:
            if _is_geometry(item):
                continue
            parsed = _parse_text_item(item)
            if parsed is None:
                print(f"无法解析的OCR结果格式: {item}")
                continue
            yield parsed


def detect(frame, conf_threshold=None):
    """Run OCR on a frame and report forbidden words found in the text.

    Args:
        frame: Image passed unchanged to the OCR engine's ``ocr`` method.
        conf_threshold: Minimum confidence a recognised text needs in order
            to be checked. When None (the default) the module-level
            ``_conf_threshold`` is used, which is the filtering this function
            applied before the argument was honoured.

    Returns:
        A ``(has_violation, message)`` tuple. ``has_violation`` is True only
        when at least one forbidden word was found; ``message`` is then the
        matched words joined by ", " (de-duplicated, in first-seen order).
        Otherwise ``message`` explains why nothing was reported: no OCR
        result, no recognised text, no forbidden word, or an error.
    """
    print("开始进行OCR检测...")
    if _ocr_engine is None:
        # Explicit message instead of an AttributeError caught further down.
        return (False, "检测错误: OCR引擎未初始化")

    try:
        ocr_res = _ocr_engine.ocr(frame, cls=True)
        if not ocr_res or not isinstance(ocr_res, list):
            return (False, "无OCR结果")

        recognised = list(_iter_text_items(ocr_res))
        if not recognised:
            return (False, "未识别到文本")

        threshold = _conf_threshold if conf_threshold is None else conf_threshold
        # An empty word is a substring of every text and a non-string word
        # cannot be substring-tested, so neither may take part in matching.
        words = [w for w in _forbidden_words if isinstance(w, str) and w]

        # dict keys give de-duplication with first-seen order preserved.
        violations = {}
        for text, conf in recognised:
            if conf < threshold:
                continue
            for word in words:
                if word in text:
                    violations.setdefault(word, None)

        if violations:
            return (True, ", ".join(violations))
        return (False, "未检测到违禁词")

    except Exception as e:
        # Detection is best-effort: report the failure as a non-violation.
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")
|
||||
Reference in New Issue
Block a user