From 81f7031555552521a7650e9e325398ddbe2e18d8 Mon Sep 17 00:00:00 2001 From: cosmo Date: Tue, 19 May 2026 19:04:13 +0200 Subject: [PATCH] implement get_db helper to centralize connection management and enable WAL mode --- main.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 9f5d2c6..8cc96ae 100644 --- a/main.py +++ b/main.py @@ -44,8 +44,13 @@ def get_password_hash(password): def verify_password(plain_password, hashed_password): return pwd_context.verify(plain_password[:72], hashed_password) +def get_db(): + conn = sqlite3.connect(DB_PATH, timeout=20.0) + conn.execute("PRAGMA journal_mode=WAL;") + return conn + def init_db(): - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() # Create users table @@ -148,7 +153,7 @@ async def login(request: Request): username = data.get("username") password = data.get("password") - conn = sqlite3.connect(DB_PATH) + conn = get_db() conn.row_factory = sqlite3.Row c = conn.cursor() c.execute('SELECT * FROM users WHERE username = ?', (username,)) @@ -182,7 +187,7 @@ async def auth_status(request: Request): # User Management (Admin Only) @app.get("/api/users", dependencies=[Depends(is_authenticated), Depends(is_admin)]) def get_users(): - conn = sqlite3.connect(DB_PATH) + conn = get_db() conn.row_factory = sqlite3.Row c = conn.cursor() c.execute('SELECT id, username, is_admin FROM users') @@ -202,7 +207,7 @@ async def create_user(request: Request): hashed = get_password_hash(password) try: - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('INSERT INTO users (username, hashed_password, is_admin) VALUES (?, ?, ?)', (username, hashed, is_admin)) @@ -216,7 +221,7 @@ async def create_user(request: Request): def delete_user(id: int, request: Request): if id == request.session.get("user_id"): raise HTTPException(status_code=400, detail="Cannot delete yourself") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('DELETE FROM users WHERE id=?', (id,)) # Also delete their instances? Let's say yes for cleanliness. @@ -232,7 +237,7 @@ async def admin_reset_password(id: int, request: Request): if not new_password: raise HTTPException(status_code=400, detail="New password required") hashed = get_password_hash(new_password) - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, id)) conn.commit() @@ -248,7 +253,7 @@ async def change_own_password(request: Request): raise HTTPException(status_code=400, detail="New password required") user_id = request.session.get("user_id") hashed = get_password_hash(new_password) - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, user_id)) conn.commit() @@ -263,7 +268,7 @@ async def change_own_username(request: Request): raise HTTPException(status_code=400, detail="New username required") user_id = request.session.get("user_id") try: - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('UPDATE users SET username = ? WHERE id = ?', (new_username, user_id)) conn.commit() @@ -308,7 +313,7 @@ def clean_secret(secret_input): @app.get("/api/otp-secrets", dependencies=[Depends(is_authenticated)]) def get_otp_secrets(request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() conn.row_factory = sqlite3.Row c = conn.cursor() c.execute('SELECT id, name FROM otp_secrets WHERE user_id = ?', (user_id,)) @@ -321,7 +326,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request): user_id = request.session.get("user_id") cleaned = clean_secret(secret_data.secret) encrypted = fernet.encrypt(cleaned.encode()).decode() - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute(''' INSERT INTO otp_secrets (name, encrypted_secret, user_id) @@ -334,7 +339,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request): @app.put("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)]) def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() # Verify ownership @@ -363,7 +368,7 @@ def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request): @app.delete("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)]) def delete_otp_secret(id: int, request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() # Check if used by any instances @@ -380,7 +385,7 @@ def delete_otp_secret(id: int, request: Request): async def poll_instances(): while True: try: - conn = sqlite3.connect(DB_PATH) + conn = get_db() conn.row_factory = sqlite3.Row c = conn.cursor() c.execute(''' @@ -422,7 +427,7 @@ async def poll_instances(): except Exception as e: status_msg = f"Error: {str(e)}" - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute(''' UPDATE instances @@ -443,7 +448,7 @@ async def startup_event(): @app.get("/api/instances", dependencies=[Depends(is_authenticated)]) def get_instances(request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() conn.row_factory = sqlite3.Row c = conn.cursor() c.execute(''' @@ -459,12 +464,12 @@ def get_instances(request: Request): @app.post("/api/instances", dependencies=[Depends(is_authenticated)]) def create_instance(inst: InstanceCreate, request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute(''' - INSERT INTO instances (name, ip, port, otp_secret_id, user_id) - VALUES (?, ?, ?, ?, ?) - ''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id)) + INSERT INTO instances (name, ip, port, otp_secret_id, user_id, encrypted_secret) + VALUES (?, ?, ?, ?, ?, ?) + ''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id, "")) conn.commit() conn.close() return {"status": "ok"} @@ -472,7 +477,7 @@ def create_instance(inst: InstanceCreate, request: Request): @app.put("/api/instances/{id}", dependencies=[Depends(is_authenticated)]) def update_instance(id: int, inst: InstanceUpdate, request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() # Verify ownership @@ -497,7 +502,7 @@ def get_config(): @app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)]) def delete_instance(id: int, request: Request): user_id = request.session.get("user_id") - conn = sqlite3.connect(DB_PATH) + conn = get_db() c = conn.cursor() c.execute('DELETE FROM instances WHERE id=? AND user_id=?', (id, user_id)) conn.commit()