replace Basic Auth with session-based authentication and a login overlay
This commit is contained in:
parent
2304717de5
commit
844879d301
5 changed files with 132 additions and 21 deletions
51
main.py
51
main.py
|
|
@ -7,9 +7,10 @@ import httpx
|
|||
import base64
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import Response, FileResponse
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException
|
||||
from fastapi.responses import Response, FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from pydantic import BaseModel
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ except ValueError:
|
|||
exit(1)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SessionMiddleware, secret_key=ENCRYPTION_KEY)
|
||||
|
||||
def init_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
|
|
@ -50,21 +52,28 @@ def init_db():
|
|||
|
||||
init_db()
|
||||
|
||||
@app.middleware("http")
|
||||
async def basic_auth_middleware(request: Request, call_next):
|
||||
if request.url.path.startswith("/api/") or request.url.path == "/" or request.url.path.startswith("/static/"):
|
||||
auth = request.headers.get("Authorization")
|
||||
if not auth or not auth.startswith("Basic "):
|
||||
return Response(status_code=401, headers={"WWW-Authenticate": 'Basic realm="Login Required"'}, content="Unauthorized")
|
||||
try:
|
||||
decoded = base64.b64decode(auth[6:]).decode()
|
||||
username, password = decoded.split(":", 1)
|
||||
# Accept any username, just check password
|
||||
if not secrets.compare_digest(password, WEB_PASSWORD):
|
||||
return Response(status_code=401, headers={"WWW-Authenticate": 'Basic realm="Login Required"'}, content="Unauthorized")
|
||||
except:
|
||||
return Response(status_code=401, headers={"WWW-Authenticate": 'Basic realm="Login Required"'}, content="Unauthorized")
|
||||
return await call_next(request)
|
||||
def is_authenticated(request: Request):
|
||||
if not request.session.get("authenticated"):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return True
|
||||
|
||||
@app.post("/api/login")
|
||||
async def login(request: Request):
|
||||
data = await request.json()
|
||||
password = data.get("password")
|
||||
if secrets.compare_digest(password, WEB_PASSWORD):
|
||||
request.session["authenticated"] = True
|
||||
return {"status": "ok"}
|
||||
raise HTTPException(status_code=401, detail="Invalid password")
|
||||
|
||||
@app.get("/api/logout")
|
||||
async def logout(request: Request):
|
||||
request.session.clear()
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/auth/status")
|
||||
async def auth_status(request: Request):
|
||||
return {"authenticated": request.session.get("authenticated", False)}
|
||||
|
||||
class InstanceCreate(BaseModel):
|
||||
name: str
|
||||
|
|
@ -143,7 +152,7 @@ async def poll_instances():
|
|||
async def startup_event():
|
||||
asyncio.create_task(poll_instances())
|
||||
|
||||
@app.get("/api/instances")
|
||||
@app.get("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||
def get_instances():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
|
@ -153,7 +162,7 @@ def get_instances():
|
|||
conn.close()
|
||||
return instances
|
||||
|
||||
@app.post("/api/instances")
|
||||
@app.post("/api/instances", dependencies=[Depends(is_authenticated)])
|
||||
def create_instance(inst: InstanceCreate):
|
||||
cleaned = clean_secret(inst.secret)
|
||||
encrypted = fernet.encrypt(cleaned.encode()).decode()
|
||||
|
|
@ -167,11 +176,11 @@ def create_instance(inst: InstanceCreate):
|
|||
conn.close()
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/config")
|
||||
@app.get("/api/config", dependencies=[Depends(is_authenticated)])
|
||||
def get_config():
|
||||
return {"firewall_host_ip": FIREWALL_HOST_IP}
|
||||
|
||||
@app.delete("/api/instances/{id}")
|
||||
@app.delete("/api/instances/{id}", dependencies=[Depends(is_authenticated)])
|
||||
def delete_instance(id: int):
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
c = conn.cursor()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue