Add auth stuff
This commit is contained in:
parent
645416a04c
commit
6400c586ad
82
src/db/auth.rs
Normal file
82
src/db/auth.rs
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
use anyhow::bail;
|
||||||
|
use axum::{
|
||||||
|
http::{Response, StatusCode},
|
||||||
|
response::{Html, IntoResponse},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::db::{
|
||||||
|
Database,
|
||||||
|
tables::{sessions::Session, user::User},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub enum AuthResultApi<R: Default> {
|
||||||
|
Ok(User),
|
||||||
|
Error(Response<R>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum AuthResultWeb {
|
||||||
|
Ok(User),
|
||||||
|
Error((StatusCode, Html<String>)),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Auth;
|
||||||
|
|
||||||
|
impl Auth {
|
||||||
|
pub async fn check_auth(
|
||||||
|
db: &Database,
|
||||||
|
session: tower_sessions::Session,
|
||||||
|
token: &String,
|
||||||
|
) -> anyhow::Result<User> {
|
||||||
|
let mut user_session: crate::Session = session
|
||||||
|
.get(crate::Session::KEY)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Some(uid) = user_session.user_id {
|
||||||
|
return Ok(User::get_by_id(db.pool(), uid).await?);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(session_token) = Session::get_by_session_key(db.pool(), token).await {
|
||||||
|
user_session.user_id = Some(session_token.user_id);
|
||||||
|
session
|
||||||
|
.insert(crate::Session::KEY, user_session)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
return Ok(User::get_by_id(db.pool(), session_token.user_id).await?);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Unable to find user by this token")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_auth_api<R: Default>(
|
||||||
|
db: &Database,
|
||||||
|
session: tower_sessions::Session,
|
||||||
|
token: &String,
|
||||||
|
) -> AuthResultApi<R> {
|
||||||
|
match Self::check_auth(db, session, token).await {
|
||||||
|
Ok(r) => AuthResultApi::Ok(r),
|
||||||
|
Err(_) => AuthResultApi::Error(
|
||||||
|
Response::builder()
|
||||||
|
.status(StatusCode::UNAUTHORIZED)
|
||||||
|
.header("Location", "/")
|
||||||
|
.body(R::default())
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_auth_web<R: Default>(
|
||||||
|
db: &Database,
|
||||||
|
session: tower_sessions::Session,
|
||||||
|
token: &String,
|
||||||
|
) -> AuthResultWeb {
|
||||||
|
match Self::check_auth(db, session, token).await {
|
||||||
|
Ok(r) => AuthResultWeb::Ok(r),
|
||||||
|
Err(_) => {
|
||||||
|
AuthResultWeb::Error((StatusCode::UNAUTHORIZED, Html("UNAUTHORIZED".to_string())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,8 +1,8 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
|
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
|
||||||
|
|
||||||
|
pub mod auth;
|
||||||
pub mod tables;
|
pub mod tables;
|
||||||
|
|
||||||
pub type CurrDb = Postgres;
|
pub type CurrDb = Postgres;
|
||||||
pub type CurrPool = Pool<CurrDb>;
|
pub type CurrPool = Pool<CurrDb>;
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ impl Session {
|
||||||
.await?;
|
.await?;
|
||||||
Ok(session)
|
Ok(session)
|
||||||
}
|
}
|
||||||
pub async fn get_by_session_key(pool: &CurrPool, session_key: String) -> anyhow::Result<Self> {
|
pub async fn get_by_session_key(pool: &CurrPool, session_key: &String) -> anyhow::Result<Self> {
|
||||||
let session = sqlx::query_as!(
|
let session = sqlx::query_as!(
|
||||||
Session,
|
Session,
|
||||||
"SELECT * FROM sessions WHERE session_key = $1",
|
"SELECT * FROM sessions WHERE session_key = $1",
|
||||||
|
|
11
src/main.rs
11
src/main.rs
|
@ -1,4 +1,5 @@
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tower_cookies::cookie::time::Duration;
|
use tower_cookies::cookie::time::Duration;
|
||||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||||
|
@ -10,6 +11,16 @@ mod api;
|
||||||
mod db;
|
mod db;
|
||||||
mod pages;
|
mod pages;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||||
|
pub struct Session {
|
||||||
|
// None if unauthed yet
|
||||||
|
user_id: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
const KEY: &'static str = "user_session";
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let sub = FmtSubscriber::builder()
|
let sub = FmtSubscriber::builder()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user