implement get_db helper to centralize connection management and enable WAL mode
This commit is contained in:
parent
90ed01c8a1
commit
81f7031555
1 changed files with 26 additions and 21 deletions
47
main.py
47
main.py
|
|
@ -44,8 +44,13 @@ def get_password_hash(password):
|
||||||
def verify_password(plain_password, hashed_password):
|
def verify_password(plain_password, hashed_password):
|
||||||
return pwd_context.verify(plain_password[:72], hashed_password)
|
return pwd_context.verify(plain_password[:72], hashed_password)
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
conn = sqlite3.connect(DB_PATH, timeout=20.0)
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL;")
|
||||||
|
return conn
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
|
||||||
# Create users table
|
# Create users table
|
||||||
|
|
@ -148,7 +153,7 @@ async def login(request: Request):
|
||||||
username = data.get("username")
|
username = data.get("username")
|
||||||
password = data.get("password")
|
password = data.get("password")
|
||||||
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('SELECT * FROM users WHERE username = ?', (username,))
|
c.execute('SELECT * FROM users WHERE username = ?', (username,))
|
||||||
|
|
@ -182,7 +187,7 @@ async def auth_status(request: Request):
|
||||||
# User Management (Admin Only)
|
# User Management (Admin Only)
|
||||||
@app.get("/api/users", dependencies=[Depends(is_authenticated), Depends(is_admin)])
|
@app.get("/api/users", dependencies=[Depends(is_authenticated), Depends(is_admin)])
|
||||||
def get_users():
|
def get_users():
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('SELECT id, username, is_admin FROM users')
|
c.execute('SELECT id, username, is_admin FROM users')
|
||||||
|
|
@ -202,7 +207,7 @@ async def create_user(request: Request):
|
||||||
|
|
||||||
hashed = get_password_hash(password)
|
hashed = get_password_hash(password)
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('INSERT INTO users (username, hashed_password, is_admin) VALUES (?, ?, ?)',
|
c.execute('INSERT INTO users (username, hashed_password, is_admin) VALUES (?, ?, ?)',
|
||||||
(username, hashed, is_admin))
|
(username, hashed, is_admin))
|
||||||
|
|
@ -216,7 +221,7 @@ async def create_user(request: Request):
|
||||||
def delete_user(id: int, request: Request):
|
def delete_user(id: int, request: Request):
|
||||||
if id == request.session.get("user_id"):
|
if id == request.session.get("user_id"):
|
||||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('DELETE FROM users WHERE id=?', (id,))
|
c.execute('DELETE FROM users WHERE id=?', (id,))
|
||||||
# Also delete their instances? Let's say yes for cleanliness.
|
# Also delete their instances? Let's say yes for cleanliness.
|
||||||
|
|
@ -232,7 +237,7 @@ async def admin_reset_password(id: int, request: Request):
|
||||||
if not new_password:
|
if not new_password:
|
||||||
raise HTTPException(status_code=400, detail="New password required")
|
raise HTTPException(status_code=400, detail="New password required")
|
||||||
hashed = get_password_hash(new_password)
|
hashed = get_password_hash(new_password)
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, id))
|
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -248,7 +253,7 @@ async def change_own_password(request: Request):
|
||||||
raise HTTPException(status_code=400, detail="New password required")
|
raise HTTPException(status_code=400, detail="New password required")
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
hashed = get_password_hash(new_password)
|
hashed = get_password_hash(new_password)
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, user_id))
|
c.execute('UPDATE users SET hashed_password = ? WHERE id = ?', (hashed, user_id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -263,7 +268,7 @@ async def change_own_username(request: Request):
|
||||||
raise HTTPException(status_code=400, detail="New username required")
|
raise HTTPException(status_code=400, detail="New username required")
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('UPDATE users SET username = ? WHERE id = ?', (new_username, user_id))
|
c.execute('UPDATE users SET username = ? WHERE id = ?', (new_username, user_id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -308,7 +313,7 @@ def clean_secret(secret_input):
|
||||||
@app.get("/api/otp-secrets", dependencies=[Depends(is_authenticated)])
|
@app.get("/api/otp-secrets", dependencies=[Depends(is_authenticated)])
|
||||||
def get_otp_secrets(request: Request):
|
def get_otp_secrets(request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('SELECT id, name FROM otp_secrets WHERE user_id = ?', (user_id,))
|
c.execute('SELECT id, name FROM otp_secrets WHERE user_id = ?', (user_id,))
|
||||||
|
|
@ -321,7 +326,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
cleaned = clean_secret(secret_data.secret)
|
cleaned = clean_secret(secret_data.secret)
|
||||||
encrypted = fernet.encrypt(cleaned.encode()).decode()
|
encrypted = fernet.encrypt(cleaned.encode()).decode()
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''
|
c.execute('''
|
||||||
INSERT INTO otp_secrets (name, encrypted_secret, user_id)
|
INSERT INTO otp_secrets (name, encrypted_secret, user_id)
|
||||||
|
|
@ -334,7 +339,7 @@ def create_otp_secret(secret_data: OTPSecretCreate, request: Request):
|
||||||
@app.put("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
@app.put("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
||||||
def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
|
def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
|
||||||
# Verify ownership
|
# Verify ownership
|
||||||
|
|
@ -363,7 +368,7 @@ def update_otp_secret(id: int, secret_data: OTPSecretUpdate, request: Request):
|
||||||
@app.delete("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
@app.delete("/api/otp-secrets/{id}", dependencies=[Depends(is_authenticated)])
|
||||||
def delete_otp_secret(id: int, request: Request):
|
def delete_otp_secret(id: int, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
|
||||||
# Check if used by any instances
|
# Check if used by any instances
|
||||||
|
|
@ -380,7 +385,7 @@ def delete_otp_secret(id: int, request: Request):
|
||||||
async def poll_instances():
|
async def poll_instances():
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''
|
c.execute('''
|
||||||
|
|
@ -422,7 +427,7 @@ async def poll_instances():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status_msg = f"Error: {str(e)}"
|
status_msg = f"Error: {str(e)}"
|
||||||
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''
|
c.execute('''
|
||||||
UPDATE instances
|
UPDATE instances
|
||||||
|
|
@ -443,7 +448,7 @@ async def startup_event():
|
||||||
@app.get("/api/instances", dependencies=[Depends(is_authenticated)])
|
@app.get("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||||
def get_instances(request: Request):
|
def get_instances(request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''
|
c.execute('''
|
||||||
|
|
@ -459,12 +464,12 @@ def get_instances(request: Request):
|
||||||
@app.post("/api/instances", dependencies=[Depends(is_authenticated)])
|
@app.post("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||||
def create_instance(inst: InstanceCreate, request: Request):
|
def create_instance(inst: InstanceCreate, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''
|
c.execute('''
|
||||||
INSERT INTO instances (name, ip, port, otp_secret_id, user_id)
|
INSERT INTO instances (name, ip, port, otp_secret_id, user_id, encrypted_secret)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id))
|
''', (inst.name, inst.ip, inst.port, inst.otp_secret_id, user_id, ""))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
@ -472,7 +477,7 @@ def create_instance(inst: InstanceCreate, request: Request):
|
||||||
@app.put("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
@app.put("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
||||||
def update_instance(id: int, inst: InstanceUpdate, request: Request):
|
def update_instance(id: int, inst: InstanceUpdate, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
|
||||||
# Verify ownership
|
# Verify ownership
|
||||||
|
|
@ -497,7 +502,7 @@ def get_config():
|
||||||
@app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
@app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
||||||
def delete_instance(id: int, request: Request):
|
def delete_instance(id: int, request: Request):
|
||||||
user_id = request.session.get("user_id")
|
user_id = request.session.get("user_id")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = get_db()
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('DELETE FROM instances WHERE id=? AND user_id=?', (id, user_id))
|
c.execute('DELETE FROM instances WHERE id=? AND user_id=?', (id, user_id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue