implement get_db helper to centralize connection management and enable WAL mode
This commit is contained in:
parent
90ed01c8a1
commit
81f7031555
1 changed files with 26 additions and 21 deletions
47
main.py
47
main.py
|
|
@ -44,8 +44,13 @@ def get_password_hash(password):
|
|||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password[:72], hashed_password)
|
||||
|
||||
def get_db():
|
||||
conn = sqlite3.connect(DB_PATH, timeout=20.0)
|
||||
conn.execute("PRAGMA journal_mode=WAL;")
|
||||
return conn
|
||||
|
||||
def init_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Create users table
|
||||
|
|
@ -148,7 +153,7 @@ async def login(request: Request):
|
|||
username = data.get("username")
|
||||
password = data.get("password")
|
||||
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
conn.row_factory = sqlite3.Row
|
||||
c = conn.cursor()
|
||||
c.execute('SELECT * FROM users WHERE username = ?', (username,))
|
||||
|
|
@ -182,7 +187,7 @@ async def auth_status(request: Request):
|
|||
# User Management (Admin Only)
|
||||
@app.get("/api/users", dependencies=[Depends(is_authenticated), Depends(is_admin)])
|
||||
def get_users():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
conn.row_factory = sqlite3.Row
|
||||
c = conn.cursor()
|
||||
c.execute('SELECT id, username, is_admin FROM users')
|
||||
|
|
@ -202,7 +207,7 @@ async def create_user(request: Request):
|
|||
|
||||
hashed = get_password_hash(password)
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('INSERT INTO users (username, hashed_password, is_admin) VALUES (?, ?, ?)',
|
||||
(username, hashed, is_admin))
|
||||
|
|
@ -216,7 +221,7 @@ async def create_user(request: Request):
|
|||
def delete_user(id: int, request: Request):
|
||||
if id == request.session.get("user_id"):
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('DELETE FROM users WHERE id=?', (id,))
|
||||
# Also delete their instances? Let's say yes for cleanliness.
|
||||
|
|
@ -232,7 +237,7 @@ async def admin_reset_password(id: int, request: Request):
|
|||
if not new_password:
|
||||
raise HTTPException(status_code=400, detail="New password required")
|
||||
hashed = get_password_hash(new_password)
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, id))
|
||||
conn.commit()
|
||||
|
|
@ -248,7 +253,7 @@ async def change_own_password(request: Request):
|
|||
raise HTTPException(status_code=400, detail="New password required")
|
||||
user_id = request.session.get("user_id")
|
||||
hashed = get_password_hash(new_password)
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, user_id))
|
||||
conn.commit()
|
||||
|
|
@ -263,7 +268,7 @@ async def change_own_username(request: Request):
|
|||
raise HTTPException(status_code=400, detail="New username required")
|
||||
user_id = request.session.get("user_id")
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('UPDATE users SET username = ? WHERE id = ?', (new_username, user_id))
|
||||
conn.commit()
|
||||
|
|
@ -308,7 +313,7 @@ def clean_secret(secret_input):
|
|||
@app.get("/api/otp-secrets", dependencies=[Depends(is_authenticated)])
|
||||
def get_otp_secrets(request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
conn.row_factory = sqlite3.Row
|
||||
c = conn.cursor()
|
||||
c.execute('SELECT id, name FROM otp_secrets WHERE user_id = ?', (user_id,))
|
||||
|
|
@ -321,7 +326,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
|
|||
user_id = request.session.get("user_id")
|
||||
cleaned = clean_secret(secret_data.secret)
|
||||
encrypted = fernet.encrypt(cleaned.encode()).decode()
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
INSERT INTO otp_secrets (name, encrypted_secret, user_id)
|
||||
|
|
@ -334,7 +339,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
|
|||
@app.put("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
||||
def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Verify ownership
|
||||
|
|
@ -363,7 +368,7 @@ def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
|
|||
@app.delete("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
||||
def delete_otp_secret(id: int, request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Check if used by any instances
|
||||
|
|
@ -380,7 +385,7 @@ def delete_otp_secret(id: int, request: Request):
|
|||
async def poll_instances():
|
||||
while True:
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
conn.row_factory = sqlite3.Row
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
|
|
@ -422,7 +427,7 @@ async def poll_instances():
|
|||
except Exception as e:
|
||||
status_msg = f"Error: {str(e)}"
|
||||
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
UPDATE instances
|
||||
|
|
@ -443,7 +448,7 @@ async def startup_event():
|
|||
@app.get("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||
def get_instances(request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
conn.row_factory = sqlite3.Row
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
|
|
@ -459,12 +464,12 @@ def get_instances(request: Request):
|
|||
@app.post("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||
def create_instance(inst: InstanceCreate, request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
INSERT INTO instances (name, ip, port, otp_secret_id, user_id)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id))
|
||||
INSERT INTO instances (name, ip, port, otp_secret_id, user_id, encrypted_secret)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id, ""))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return {"status": "ok"}
|
||||
|
|
@ -472,7 +477,7 @@ def create_instance(inst: InstanceCreate, request: Request):
|
|||
@app.put("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
||||
def update_instance(id: int, inst: InstanceUpdate, request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Verify ownership
|
||||
|
|
@ -497,7 +502,7 @@ def get_config():
|
|||
@app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
||||
def delete_instance(id: int, request: Request):
|
||||
user_id = request.session.get("user_id")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
c.execute('DELETE FROM instances WHERE id=? AND user_id=?', (id, user_id))
|
||||
conn.commit()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue