Add auth stuff
This commit is contained in:
parent
645416a04c
commit
6400c586ad
82
src/db/auth.rs
Normal file
82
src/db/auth.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use anyhow::bail;
|
||||
use axum::{
|
||||
http::{Response, StatusCode},
|
||||
response::{Html, IntoResponse},
|
||||
};
|
||||
|
||||
use crate::db::{
|
||||
Database,
|
||||
tables::{sessions::Session, user::User},
|
||||
};
|
||||
|
||||
pub enum AuthResultApi<R: Default> {
|
||||
Ok(User),
|
||||
Error(Response<R>),
|
||||
}
|
||||
|
||||
pub enum AuthResultWeb {
|
||||
Ok(User),
|
||||
Error((StatusCode, Html<String>)),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Auth;
|
||||
|
||||
impl Auth {
|
||||
pub async fn check_auth(
|
||||
db: &Database,
|
||||
session: tower_sessions::Session,
|
||||
token: &String,
|
||||
) -> anyhow::Result<User> {
|
||||
let mut user_session: crate::Session = session
|
||||
.get(crate::Session::KEY)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(uid) = user_session.user_id {
|
||||
return Ok(User::get_by_id(db.pool(), uid).await?);
|
||||
}
|
||||
|
||||
if let Ok(session_token) = Session::get_by_session_key(db.pool(), token).await {
|
||||
user_session.user_id = Some(session_token.user_id);
|
||||
session
|
||||
.insert(crate::Session::KEY, user_session)
|
||||
.await
|
||||
.unwrap();
|
||||
return Ok(User::get_by_id(db.pool(), session_token.user_id).await?);
|
||||
}
|
||||
|
||||
bail!("Unable to find user by this token")
|
||||
}
|
||||
|
||||
pub async fn check_auth_api<R: Default>(
|
||||
db: &Database,
|
||||
session: tower_sessions::Session,
|
||||
token: &String,
|
||||
) -> AuthResultApi<R> {
|
||||
match Self::check_auth(db, session, token).await {
|
||||
Ok(r) => AuthResultApi::Ok(r),
|
||||
Err(_) => AuthResultApi::Error(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("Location", "/")
|
||||
.body(R::default())
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_auth_web<R: Default>(
|
||||
db: &Database,
|
||||
session: tower_sessions::Session,
|
||||
token: &String,
|
||||
) -> AuthResultWeb {
|
||||
match Self::check_auth(db, session, token).await {
|
||||
Ok(r) => AuthResultWeb::Ok(r),
|
||||
Err(_) => {
|
||||
AuthResultWeb::Error((StatusCode::UNAUTHORIZED, Html("UNAUTHORIZED".to_string())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
use anyhow::Result;
|
||||
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
|
||||
|
||||
pub mod auth;
|
||||
pub mod tables;
|
||||
|
||||
pub type CurrDb = Postgres;
|
||||
pub type CurrPool = Pool<CurrDb>;
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ impl Session {
|
|||
.await?;
|
||||
Ok(session)
|
||||
}
|
||||
pub async fn get_by_session_key(pool: &CurrPool, session_key: String) -> anyhow::Result<Self> {
|
||||
pub async fn get_by_session_key(pool: &CurrPool, session_key: &String) -> anyhow::Result<Self> {
|
||||
let session = sqlx::query_as!(
|
||||
Session,
|
||||
"SELECT * FROM sessions WHERE session_key = $1",
|
||||
|
|
11
src/main.rs
11
src/main.rs
|
@ -1,4 +1,5 @@
|
|||
use axum::Router;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_cookies::cookie::time::Duration;
|
||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||
|
@ -10,6 +11,16 @@ mod api;
|
|||
mod db;
|
||||
mod pages;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct Session {
|
||||
// None if unauthed yet
|
||||
user_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
const KEY: &'static str = "user_session";
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let sub = FmtSubscriber::builder()
|
||||
|
|
Loading…
Reference in New Issue
Block a user