diff --git a/src/db/auth.rs b/src/db/auth.rs new file mode 100644 index 0000000..da93f36 --- /dev/null +++ b/src/db/auth.rs @@ -0,0 +1,82 @@ +use anyhow::bail; +use axum::{ + http::{Response, StatusCode}, + response::{Html, IntoResponse}, +}; + +use crate::db::{ + Database, + tables::{sessions::Session, user::User}, +}; + +pub enum AuthResultApi { + Ok(User), + Error(Response), +} + +pub enum AuthResultWeb { + Ok(User), + Error((StatusCode, Html)), +} + +#[derive(Debug)] +pub struct Auth; + +impl Auth { + pub async fn check_auth( + db: &Database, + session: tower_sessions::Session, + token: &String, + ) -> anyhow::Result { + let mut user_session: crate::Session = session + .get(crate::Session::KEY) + .await + .unwrap() + .unwrap_or_default(); + + if let Some(uid) = user_session.user_id { + return Ok(User::get_by_id(db.pool(), uid).await?); + } + + if let Ok(session_token) = Session::get_by_session_key(db.pool(), token).await { + user_session.user_id = Some(session_token.user_id); + session + .insert(crate::Session::KEY, user_session) + .await + .unwrap(); + return Ok(User::get_by_id(db.pool(), session_token.user_id).await?); + } + + bail!("Unable to find user by this token") + } + + pub async fn check_auth_api( + db: &Database, + session: tower_sessions::Session, + token: &String, + ) -> AuthResultApi { + match Self::check_auth(db, session, token).await { + Ok(r) => AuthResultApi::Ok(r), + Err(_) => AuthResultApi::Error( + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("Location", "/") + .body(R::default()) + .unwrap(), + ), + } + } + + pub async fn check_auth_web( + db: &Database, + session: tower_sessions::Session, + token: &String, + ) -> AuthResultWeb { + match Self::check_auth(db, session, token).await { + Ok(r) => AuthResultWeb::Ok(r), + Err(_) => { + AuthResultWeb::Error((StatusCode::UNAUTHORIZED, Html("UNAUTHORIZED".to_string()))) + } + } + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 8f5d9b7..711a3a2 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,8 +1,8 @@ use anyhow::Result; use sqlx::{Pool, Postgres, postgres::PgPoolOptions}; +pub mod auth; pub mod tables; - pub type CurrDb = Postgres; pub type CurrPool = Pool; diff --git a/src/db/tables/sessions.rs b/src/db/tables/sessions.rs index 2606b1d..4f3c4f8 100644 --- a/src/db/tables/sessions.rs +++ b/src/db/tables/sessions.rs @@ -41,7 +41,7 @@ impl Session { .await?; Ok(session) } - pub async fn get_by_session_key(pool: &CurrPool, session_key: String) -> anyhow::Result { + pub async fn get_by_session_key(pool: &CurrPool, session_key: &String) -> anyhow::Result { let session = sqlx::query_as!( Session, "SELECT * FROM sessions WHERE session_key = $1", diff --git a/src/main.rs b/src/main.rs index 5167852..9d720cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use axum::Router; +use serde::{Deserialize, Serialize}; use tower::ServiceBuilder; use tower_cookies::cookie::time::Duration; use tower_http::{services::ServeDir, trace::TraceLayer}; @@ -10,6 +11,16 @@ mod api; mod db; mod pages; +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Session { + // None if unauthed yet + user_id: Option, +} + +impl Session { + const KEY: &'static str = "user_session"; +} + #[tokio::main] async fn main() { let sub = FmtSubscriber::builder()