# facerecognition/face_db.py
# 2025-04-07 08:08:39 +08:00
#
# 124 lines
# 3.5 KiB
# Python

import sqlite3
import numpy as np
import pickle
import os
class FaceDatabase:
    """SQLite-backed store of people and their face feature vectors.

    People live in the ``persons`` table; each person may have any number of
    feature vectors in ``face_features``, linked by ``person_id``. Feature
    vectors are stored as ``pickle``-serialized BLOBs.

    The object can be used as a context manager; the connection is closed
    when the ``with`` block exits.
    """

    def __init__(self, db_path='data/face_database.db'):
        """Open (creating if necessary) the face database.

        Args:
            db_path: Path to the SQLite database file. Missing parent
                directories are created. A bare file name (no directory part)
                or ``':memory:'`` is also accepted.
        """
        # os.makedirs('') raises FileNotFoundError, so only create a parent
        # directory when the path actually has one (not for 'faces.db' or
        # ':memory:').
        directory = os.path.dirname(db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        try:
            # SQLite ignores FOREIGN KEY clauses unless this pragma is enabled
            # on every connection; without it, features referencing a
            # non-existent person would be silently accepted.
            self.conn.execute("PRAGMA foreign_keys = ON")
            self.create_tables()
        except Exception:
            # Do not leak the connection if initialization fails.
            self.conn.close()
            raise

    def create_tables(self):
        """Create the ``persons`` and ``face_features`` tables if missing."""
        # `with conn:` commits on success and rolls back on an exception.
        with self.conn:
            # Person records.
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS persons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    register_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            # Face feature vectors, one row per vector.
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS face_features (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    person_id INTEGER NOT NULL,
                    feature_vector BLOB NOT NULL,
                    FOREIGN KEY (person_id) REFERENCES persons (id)
                )
            ''')

    def add_person(self, name):
        """Insert a new person.

        Args:
            name: The person's name. The column is NOT NULL, so ``None``
                raises ``sqlite3.IntegrityError``.

        Returns:
            int: ID of the newly inserted person.
        """
        with self.conn:
            cursor = self.conn.execute(
                "INSERT INTO persons (name) VALUES (?)", (name,)
            )
        return cursor.lastrowid

    def add_face_feature(self, person_id, feature_vector):
        """Store a face feature vector for an existing person.

        Args:
            person_id: ID of the person the vector belongs to.
            feature_vector: The feature vector (expected to be a numpy array);
                it is serialized with ``pickle``.

        Returns:
            int: ID of the newly inserted feature row.

        Raises:
            sqlite3.IntegrityError: If ``person_id`` does not refer to an
                existing person (foreign keys are enforced).
        """
        serialized_feature = pickle.dumps(feature_vector)
        with self.conn:
            cursor = self.conn.execute(
                "INSERT INTO face_features (person_id, feature_vector) VALUES (?, ?)",
                (person_id, serialized_feature),
            )
        return cursor.lastrowid

    def get_all_features(self):
        """Return every stored feature vector together with its owner.

        Returns:
            list of tuples: ``[(person_id, name, feature_vector), ...]`` in
            the order the features were inserted.
        """
        rows = self.conn.execute("""
            SELECT p.id, p.name, f.feature_vector
            FROM persons p
            JOIN face_features f ON p.id = f.person_id
            ORDER BY f.id
        """).fetchall()
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # blob. This is only safe while the database file itself is trusted;
        # never point this class at a database from an untrusted source.
        return [
            (person_id, name, pickle.loads(blob))
            for person_id, name, blob in rows
        ]

    def get_person_by_id(self, person_id):
        """Look up a person by ID.

        Args:
            person_id: ID of the person to fetch.

        Returns:
            dict or None: ``{"id": ..., "name": ...}`` if found, else ``None``.
        """
        row = self.conn.execute(
            "SELECT id, name FROM persons WHERE id = ?", (person_id,)
        ).fetchone()
        if row is None:
            return None
        return {"id": row[0], "name": row[1]}

    def close(self):
        """Close the database connection. Calling it again is a no-op."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def __enter__(self):
        """Support ``with FaceDatabase(...) as db:``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block.

        Exceptions are not suppressed.
        """
        self.close()