From 4e89158d961474f35aacbe423af31b7c3c30a801 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 14 Jan 2025 19:04:24 +0000 Subject: [PATCH] add forward auth header reading --- src/main.py | 105 ++++++++++++++++++++++++++---- src/models.py | 1 + src/settings.py | 3 + src/templates/admin_feeds.html.j2 | 3 +- 4 files changed, 99 insertions(+), 13 deletions(-) diff --git a/src/main.py b/src/main.py index 0ea6b90..89ec5e1 100644 --- a/src/main.py +++ b/src/main.py @@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse, RedirectResponse from fastapi.templating import Jinja2Templates from PIL import Image -from sqlmodel import Session, and_, select +from sqlmodel import Session, and_, or_, select import models from process import AudioProcessor @@ -30,7 +30,21 @@ def get_session() -> Generator[Session, Any, None]: yield session +def handle_user_auth(request: Request) -> tuple[str, str]: + if ( + settings.forward_auth_name_header is None + or settings.forward_auth_uid_header is None + ): + return ("default", "Admin") + + return ( + request.headers.get(settings.forward_auth_uid_header, "default"), + request.headers.get(settings.forward_auth_name_header, "Admin"), + ) + + SessionDep = Annotated[Session, Depends(get_session)] +AuthDep = Annotated[tuple[str, str], Depends(handle_user_auth)] log = structlog.get_logger() @@ -46,13 +60,23 @@ audio_processor.start_processing() @app.get("/admin") -def admin_list_podcasts(session: SessionDep, request: Request): - podcasts = session.exec(select(models.Podcast)).all() +def admin_list_podcasts(session: SessionDep, request: Request, user: AuthDep): + podcasts = session.exec( + select(models.Podcast).where( + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ) + ) + ).all() return templates.TemplateResponse( request=request, name="admin_feeds.html.j2", - context={"podcasts": podcasts}, + context={ + "podcasts": podcasts, + "user_name": user[1], + }, ) @@ -68,6 +92,7 @@ def admin_create_podcast(request: Request): def admin_create_podcast_post( session: SessionDep, request: Request, + user: AuthDep, name: Annotated[str, Form()], ): if name.strip() == "": @@ -81,7 +106,7 @@ def admin_create_podcast_post( }, ) - podcast = models.Podcast(name=name, description=name) + podcast = models.Podcast(name=name, description=name, owner_id=user[0]) session.add(podcast) session.commit() @@ -90,9 +115,19 @@ def admin_create_podcast_post( @app.get("/admin/{podcast_id}") -def admin_list_podcast(session: SessionDep, request: Request, podcast_id: str): +def admin_list_podcast( + session: SessionDep, request: Request, podcast_id: str, user: AuthDep +): podcast = session.exec( - select(models.Podcast).where(models.Podcast.id == podcast_id) + select(models.Podcast).where( + and_( + models.Podcast.id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), + ) + ) ).first() if podcast is None: @@ -151,7 +186,11 @@ def finish_processing( @app.post("/admin/{podcast_id}/upload") async def admin_upload_episode( - session: SessionDep, request: Request, podcast_id: str, file: UploadFile + session: SessionDep, + request: Request, + podcast_id: str, + file: UploadFile, + user: AuthDep, ): file_id = request.headers.get("uploader-file-id") chunks_total = int(request.headers.get("uploader-chunks-total")) @@ -164,7 +203,15 @@ async def admin_upload_episode( file_id = "".join(c for c in file_id if c.isalnum()).strip() podcast = session.exec( - select(models.Podcast).where(models.Podcast.id == podcast_id) + select(models.Podcast).where( + and_( + models.Podcast.id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), + ) + ) ).first() if podcast is None: @@ -216,12 +263,17 @@ def admin_delete_episode( request: Request, podcast_id: str, episode_id: str, + user: AuthDep, ): episode = session.exec( select(models.PodcastEpisode).where( and_( models.PodcastEpisode.id == episode_id, models.PodcastEpisode.podcast_id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), ) ) ).first() @@ -250,12 +302,17 @@ def admin_edit_episode( request: Request, podcast_id: str, episode_id: str, + user: AuthDep, ): episode = session.exec( select(models.PodcastEpisode).where( and_( models.PodcastEpisode.id == episode_id, models.PodcastEpisode.podcast_id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), ) ) ).first() @@ -284,6 +341,7 @@ def admin_edit_episode_post( request: Request, podcast_id: str, episode_id: str, + user: AuthDep, name: Annotated[str, Form()], description: Annotated[str, Form()], ): @@ -292,6 +350,10 @@ def admin_edit_episode_post( and_( models.PodcastEpisode.id == episode_id, models.PodcastEpisode.podcast_id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), ) ) ).first() @@ -322,9 +384,19 @@ def admin_edit_episode_post( @app.get("/admin/{podcast_id}/edit") -def admin_edit_podcast(session: SessionDep, request: Request, podcast_id: str): +def admin_edit_podcast( + session: SessionDep, request: Request, user: AuthDep, podcast_id: str +): podcast = session.exec( - select(models.Podcast).where(models.Podcast.id == podcast_id) + select(models.Podcast).where( + and_( + models.Podcast.id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), + ) + ) ).first() if podcast is None: @@ -350,12 +422,21 @@ def admin_edit_podcast_post( session: SessionDep, request: Request, podcast_id: str, + user: AuthDep, name: Annotated[str, Form()], description: Annotated[str, Form()], image: Optional[UploadFile] = None, ): podcast = session.exec( - select(models.Podcast).where(models.Podcast.id == podcast_id) + select(models.Podcast).where( + and_( + models.Podcast.id == podcast_id, + or_( + models.Podcast.owner_id == user[0], + models.Podcast.owner_id == None, + ), + ) + ) ).first() if podcast is None: diff --git a/src/models.py b/src/models.py index bf46ab0..d64f4e0 100644 --- a/src/models.py +++ b/src/models.py @@ -15,6 +15,7 @@ class Podcast(SQLModel, table=True): description: str explicit: bool = Field(default=True) image_filename: Optional[str] = Field(default=None) + owner_id: Optional[str] = Field(default=None) episodes: list["PodcastEpisode"] = Relationship(back_populates="podcast") diff --git a/src/settings.py b/src/settings.py index ca12d4c..709e3b8 100644 --- a/src/settings.py +++ b/src/settings.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -7,6 +8,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): directory: Path = Field(default=Path.cwd() / "data") uploads_directory: Path = Field(default=Path.cwd() / "uploads") + forward_auth_name_header: Optional[str] = Field(default=None) + forward_auth_uid_header: Optional[str] = Field(default=None) model_config = SettingsConfigDict(env_nested_delimiter="__", env_prefix="PG_") diff --git a/src/templates/admin_feeds.html.j2 b/src/templates/admin_feeds.html.j2 index 6b4fd7f..1ce44b2 100644 --- a/src/templates/admin_feeds.html.j2 +++ b/src/templates/admin_feeds.html.j2 @@ -1,7 +1,8 @@ {% extends 'admin_layout.html.j2' %} {% block content %} {{ super() }} -

Podcasts

+

Hello {{ user_name }}!

+

My Podcasts

Actions: Create