implement get_db helper to centralize connection management and enable WAL mode

This commit is contained in:
CPTN Cosmo 2026-05-19 19:04:13 +02:00
parent 90ed01c8a1
commit 81f7031555

47
main.py
View file

@ -44,8 +44,13 @@ def get_password_hash(password):
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password[:72], hashed_password)
def get_db():
conn = sqlite3.connect(DB_PATH, timeout=20.0)
conn.execute("PRAGMA journal_mode=WAL;")
return conn
def init_db():
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
# Create users table
@ -148,7 +153,7 @@ async def login(request: Request):
username = data.get("username")
password = data.get("password")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
conn.row_factory = sqlite3.Row
c = conn.cursor()
c.execute('SELECT * FROM users WHERE username = ?', (username,))
@ -182,7 +187,7 @@ async def auth_status(request: Request):
# User Management (Admin Only)
@app.get("/api/users", dependencies=[Depends(is_authenticated), Depends(is_admin)])
def get_users():
conn = sqlite3.connect(DB_PATH)
conn = get_db()
conn.row_factory = sqlite3.Row
c = conn.cursor()
c.execute('SELECT id, username, is_admin FROM users')
@ -202,7 +207,7 @@ async def create_user(request: Request):
hashed = get_password_hash(password)
try:
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('INSERT INTO users (username, hashed_password, is_admin) VALUES (?, ?, ?)',
(username, hashed, is_admin))
@ -216,7 +221,7 @@ async def create_user(request: Request):
def delete_user(id: int, request: Request):
if id == request.session.get("user_id"):
raise HTTPException(status_code=400, detail="Cannot delete yourself")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('DELETE FROM users WHERE id=?', (id,))
# Also delete their instances? Let's say yes for cleanliness.
@ -232,7 +237,7 @@ async def admin_reset_password(id: int, request: Request):
if not new_password:
raise HTTPException(status_code=400, detail="New password required")
hashed = get_password_hash(new_password)
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, id))
conn.commit()
@ -248,7 +253,7 @@ async def change_own_password(request: Request):
raise HTTPException(status_code=400, detail="New password required")
user_id = request.session.get("user_id")
hashed = get_password_hash(new_password)
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, user_id))
conn.commit()
@ -263,7 +268,7 @@ async def change_own_username(request: Request):
raise HTTPException(status_code=400, detail="New username required")
user_id = request.session.get("user_id")
try:
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('UPDATE users SET username = ? WHERE id = ?', (new_username, user_id))
conn.commit()
@ -308,7 +313,7 @@ def clean_secret(secret_input):
@app.get("/api/otp-secrets", dependencies=[Depends(is_authenticated)])
def get_otp_secrets(request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
conn.row_factory = sqlite3.Row
c = conn.cursor()
c.execute('SELECT id, name FROM otp_secrets WHERE user_id = ?', (user_id,))
@ -321,7 +326,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
user_id = request.session.get("user_id")
cleaned = clean_secret(secret_data.secret)
encrypted = fernet.encrypt(cleaned.encode()).decode()
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('''
INSERT INTO otp_secrets (name, encrypted_secret, user_id)
@ -334,7 +339,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
@app.put("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
# Verify ownership
@ -363,7 +368,7 @@ def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
@app.delete("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
def delete_otp_secret(id: int, request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
# Check if used by any instances
@ -380,7 +385,7 @@ def delete_otp_secret(id: int, request: Request):
async def poll_instances():
while True:
try:
conn = sqlite3.connect(DB_PATH)
conn = get_db()
conn.row_factory = sqlite3.Row
c = conn.cursor()
c.execute('''
@ -422,7 +427,7 @@ async def poll_instances():
except Exception as e:
status_msg = f"Error: {str(e)}"
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('''
UPDATE instances
@ -443,7 +448,7 @@ async def startup_event():
@app.get("/api/instances", dependencies=[Depends(is_authenticated)])
def get_instances(request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
conn.row_factory = sqlite3.Row
c = conn.cursor()
c.execute('''
@ -459,12 +464,12 @@ def get_instances(request: Request):
@app.post("/api/instances", dependencies=[Depends(is_authenticated)])
def create_instance(inst: InstanceCreate, request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('''
INSERT INTO instances (name, ip, port, otp_secret_id, user_id)
VALUES (?, ?, ?, ?, ?)
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id))
INSERT INTO instances (name, ip, port, otp_secret_id, user_id, encrypted_secret)
VALUES (?, ?, ?, ?, ?, ?)
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id, ""))
conn.commit()
conn.close()
return {"status": "ok"}
@ -472,7 +477,7 @@ def create_instance(inst: InstanceCreate, request: Request):
@app.put("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
def update_instance(id: int, inst: InstanceUpdate, request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
# Verify ownership
@ -497,7 +502,7 @@ def get_config():
@app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
def delete_instance(id: int, request: Request):
user_id = request.session.get("user_id")
conn = sqlite3.connect(DB_PATH)
conn = get_db()
c = conn.cursor()
c.execute('DELETE FROM instances WHERE id=? AND user_id=?', (id, user_id))
conn.commit()