124 lines
3.5 KiB
Python
124 lines
3.5 KiB
Python
import sqlite3
|
|
import numpy as np
|
|
import pickle
|
|
import os
|
|
|
|
class FaceDatabase:
    """SQLite-backed store for persons and their face feature vectors.

    Two tables are used: ``persons`` (id, name, register_time) and
    ``face_features`` (id, person_id, feature_vector).  Feature vectors are
    stored as pickled blobs.

    The instance owns one ``sqlite3`` connection (``self.conn``) that stays
    open until :meth:`close` is called.  The class can also be used as a
    context manager, which closes the connection on exit.
    """

    def __init__(self, db_path: str = 'data/face_database.db') -> None:
        """Open (or create) the database and make sure the tables exist.

        Args:
            db_path: Path of the SQLite database file.  A bare file name or
                ``":memory:"`` (no directory component) is accepted.
        """
        # Only create the parent directory when the path actually has one:
        # os.makedirs('') raises FileNotFoundError.
        parent_dir = os.path.dirname(db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        try:
            self.create_tables()
        except Exception:
            # Do not leak the open connection if schema creation fails.
            self.conn.close()
            self.conn = None
            raise

    def __enter__(self):
        """Return the database itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def create_tables(self) -> None:
        """Create the ``persons`` and ``face_features`` tables if missing."""
        # ``with self.conn`` commits on success and rolls back on error.
        with self.conn:
            # Person table.
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS persons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    register_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Face feature table.
            # NOTE: SQLite only enforces FOREIGN KEY constraints when
            # ``PRAGMA foreign_keys = ON`` is set on the connection; it is
            # not set here, so the reference below is declarative only.
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS face_features (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    person_id INTEGER NOT NULL,
                    feature_vector BLOB NOT NULL,
                    FOREIGN KEY (person_id) REFERENCES persons (id)
                )
            ''')

    def add_person(self, name: str) -> int:
        """Insert a new person.

        Args:
            name: Name of the person.

        Returns:
            The id of the newly inserted ``persons`` row.
        """
        with self.conn:
            cursor = self.conn.execute(
                "INSERT INTO persons (name) VALUES (?)", (name,)
            )
        return cursor.lastrowid

    def add_face_feature(self, person_id: int, feature_vector) -> int:
        """Store a face feature vector for a person.

        Args:
            person_id: Id of the person the feature belongs to.
            feature_vector: Face feature vector (a numpy array per the
                original documentation; any picklable object is stored).

        Returns:
            The id of the newly inserted ``face_features`` row.
        """
        # Serialize the vector to bytes so it can be stored as a BLOB.
        serialized_feature = pickle.dumps(feature_vector)

        with self.conn:
            cursor = self.conn.execute(
                "INSERT INTO face_features (person_id, feature_vector) VALUES (?, ?)",
                (person_id, serialized_feature),
            )
        return cursor.lastrowid

    def get_all_features(self) -> list:
        """Return every stored feature together with its owner.

        Returns:
            A list of ``(person_id, name, feature_vector)`` tuples, one per
            row of ``face_features`` that joins to a person.
        """
        cursor = self.conn.execute("""
            SELECT p.id, p.name, f.feature_vector
            FROM persons p
            JOIN face_features f ON p.id = f.person_id
        """)

        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # blob.  This is only safe while the database file is written
        # exclusively by trusted code; do not open untrusted database files.
        return [
            (person_id, name, pickle.loads(serialized_feature))
            for person_id, name, serialized_feature in cursor.fetchall()
        ]

    def get_person_by_id(self, person_id: int):
        """Look up a person by id.

        Args:
            person_id: Id of the person.

        Returns:
            ``{"id": ..., "name": ...}`` for the matching person, or ``None``
            when no row has that id.
        """
        cursor = self.conn.execute(
            "SELECT id, name FROM persons WHERE id = ?", (person_id,)
        )
        row = cursor.fetchone()
        if row is None:
            return None
        return {"id": row[0], "name": row[1]}

    def close(self) -> None:
        """Close the database connection; calling it again is a no-op."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None
|