Compare commits

...

2 Commits

Author SHA1 Message Date
6400c586ad
Add auth stuff 2025-09-18 19:18:59 +03:00
645416a04c
cleanup 2025-09-15 22:01:23 +03:00
17 changed files with 112 additions and 38 deletions

View File

@ -1,7 +1,7 @@
use darling::FromDeriveInput; use darling::FromDeriveInput;
use proc_macro::{self, TokenStream}; use proc_macro::{self, TokenStream};
use quote::quote; use quote::quote;
use syn::{Data, DeriveInput, Expr, Fields, Type, parse_macro_input}; use syn::{Data, Expr, Fields, Type, parse_macro_input};
#[derive(FromDeriveInput, Default)] #[derive(FromDeriveInput, Default)]
#[darling(default, attributes(meta))] #[darling(default, attributes(meta))]

View File

@ -1,7 +1,6 @@
use axum::{ use axum::{
body::{Body, Bytes}, extract::State,
extract::{Path, State}, http::StatusCode,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use serde::Deserialize; use serde::Deserialize;
@ -14,7 +13,7 @@ pub struct Data {
} }
#[axum::debug_handler] #[axum::debug_handler]
pub async fn route(State(db): State<Database>) -> impl IntoResponse { pub async fn route(State(_db): State<Database>) -> impl IntoResponse {
if false { if false {
return Response::builder() return Response::builder()
.status(StatusCode::NOT_FOUND) .status(StatusCode::NOT_FOUND)

View File

@ -1,7 +1,6 @@
use axum::{ use axum::{
body::{Body, Bytes}, extract::State,
extract::{Path, State}, http::StatusCode,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use serde::Deserialize; use serde::Deserialize;
@ -27,4 +26,3 @@ pub async fn route(State(db): State<Database>) -> impl IntoResponse {
.body(String::new()) .body(String::new())
.unwrap() .unwrap()
} }

View File

@ -1,7 +1,6 @@
use axum::{ use axum::{
body::{Body, Bytes}, extract::State,
extract::{Path, State}, http::StatusCode,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use serde::Deserialize; use serde::Deserialize;
@ -27,4 +26,3 @@ pub async fn route(State(db): State<Database>) -> impl IntoResponse {
.body(String::new()) .body(String::new())
.unwrap() .unwrap()
} }

View File

@ -1,7 +1,6 @@
use axum::{ use axum::{
body::{Body, Bytes}, extract::State,
extract::{Path, State}, http::StatusCode,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use serde::Deserialize; use serde::Deserialize;
@ -27,4 +26,3 @@ pub async fn route(State(db): State<Database>) -> impl IntoResponse {
.body(String::new()) .body(String::new())
.unwrap() .unwrap()
} }

View File

@ -1,7 +1,6 @@
use axum::{ use axum::{
body::{Body, Bytes}, extract::State,
extract::{Path, State}, http::StatusCode,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use serde::Deserialize; use serde::Deserialize;
@ -27,4 +26,3 @@ pub async fn route(State(db): State<Database>) -> impl IntoResponse {
.body(String::new()) .body(String::new())
.unwrap() .unwrap()
} }

View File

@ -4,7 +4,7 @@ use axum::{Router, extract::State, http::StatusCode, routing::get};
pub mod files; pub mod files;
pub mod user; pub mod user;
async fn root(State(state): State<Database>) -> (StatusCode, &'static str) { async fn root(State(_state): State<Database>) -> (StatusCode, &'static str) {
(StatusCode::OK, "We Good twin :3c") (StatusCode::OK, "We Good twin :3c")
} }

View File

@ -1,9 +1,4 @@
use axum::{ use axum::{body::Body, extract::State, http::StatusCode, response::Response};
body::Body,
extract::State,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use crate::db::Database; use crate::db::Database;

View File

@ -1,7 +1,4 @@
use axum::{ use axum::{Router, routing::post};
Router,
routing::{get, post},
};
use crate::db::Database; use crate::db::Database;

82
src/db/auth.rs Normal file
View File

@ -0,0 +1,82 @@
use anyhow::bail;
use axum::{
http::{Response, StatusCode},
response::{Html, IntoResponse},
};
use crate::db::{
Database,
tables::{sessions::Session, user::User},
};
pub enum AuthResultApi<R: Default> {
Ok(User),
Error(Response<R>),
}
pub enum AuthResultWeb {
Ok(User),
Error((StatusCode, Html<String>)),
}
#[derive(Debug)]
pub struct Auth;
impl Auth {
pub async fn check_auth(
db: &Database,
session: tower_sessions::Session,
token: &String,
) -> anyhow::Result<User> {
let mut user_session: crate::Session = session
.get(crate::Session::KEY)
.await
.unwrap()
.unwrap_or_default();
if let Some(uid) = user_session.user_id {
return Ok(User::get_by_id(db.pool(), uid).await?);
}
if let Ok(session_token) = Session::get_by_session_key(db.pool(), token).await {
user_session.user_id = Some(session_token.user_id);
session
.insert(crate::Session::KEY, user_session)
.await
.unwrap();
return Ok(User::get_by_id(db.pool(), session_token.user_id).await?);
}
bail!("Unable to find user by this token")
}
pub async fn check_auth_api<R: Default>(
db: &Database,
session: tower_sessions::Session,
token: &String,
) -> AuthResultApi<R> {
match Self::check_auth(db, session, token).await {
Ok(r) => AuthResultApi::Ok(r),
Err(_) => AuthResultApi::Error(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("Location", "/")
.body(R::default())
.unwrap(),
),
}
}
pub async fn check_auth_web<R: Default>(
db: &Database,
session: tower_sessions::Session,
token: &String,
) -> AuthResultWeb {
match Self::check_auth(db, session, token).await {
Ok(r) => AuthResultWeb::Ok(r),
Err(_) => {
AuthResultWeb::Error((StatusCode::UNAUTHORIZED, Html("UNAUTHORIZED".to_string())))
}
}
}
}

View File

@ -1,8 +1,8 @@
use anyhow::Result; use anyhow::Result;
use sqlx::{Pool, Postgres, postgres::PgPoolOptions}; use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
pub mod auth;
pub mod tables; pub mod tables;
pub type CurrDb = Postgres; pub type CurrDb = Postgres;
pub type CurrPool = Pool<CurrDb>; pub type CurrPool = Pool<CurrDb>;

View File

@ -15,7 +15,6 @@ pub struct Mission {
pub created_at: i64, pub created_at: i64,
pub modified_at: i64, pub modified_at: i64,
} }
impl Mission { impl Mission {
pub async fn insert_new(&self, pool: &CurrPool) -> Result<Self> { pub async fn insert_new(&self, pool: &CurrPool) -> Result<Self> {
let session = sqlx::query_as!( let session = sqlx::query_as!(

View File

@ -4,7 +4,7 @@ use sqlx::prelude::FromRow;
use crate::db::{ use crate::db::{
CurrPool, CurrPool,
tables::{ForeignKey, TableMeta, assignables::missions::Mission, user::User}, tables::{ForeignKey, TableMeta, user::User},
}; };
#[derive(Debug, Default, Clone, FromRow, TableMeta)] #[derive(Debug, Default, Clone, FromRow, TableMeta)]

View File

@ -4,7 +4,7 @@ use std::{
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
}; };
use sqlx::{Decode, Encode, QueryBuilder, Type, postgres::PgRow}; use sqlx::{Decode, QueryBuilder, Type, postgres::PgRow};
pub mod assignables; pub mod assignables;
pub mod attendance; pub mod attendance;

View File

@ -41,7 +41,7 @@ impl Session {
.await?; .await?;
Ok(session) Ok(session)
} }
pub async fn get_by_session_key(pool: &CurrPool, session_key: String) -> anyhow::Result<Self> { pub async fn get_by_session_key(pool: &CurrPool, session_key: &String) -> anyhow::Result<Self> {
let session = sqlx::query_as!( let session = sqlx::query_as!(
Session, Session,
"SELECT * FROM sessions WHERE session_key = $1", "SELECT * FROM sessions WHERE session_key = $1",

View File

@ -1,4 +1,3 @@
use anyhow::bail;
use persmgr_derive::TableMeta; use persmgr_derive::TableMeta;
use sqlx::prelude::FromRow; use sqlx::prelude::FromRow;

View File

@ -1,4 +1,5 @@
use axum::Router; use axum::Router;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_cookies::cookie::time::Duration; use tower_cookies::cookie::time::Duration;
use tower_http::{services::ServeDir, trace::TraceLayer}; use tower_http::{services::ServeDir, trace::TraceLayer};
@ -10,6 +11,16 @@ mod api;
mod db; mod db;
mod pages; mod pages;
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Session {
// None if unauthed yet
user_id: Option<i64>,
}
impl Session {
const KEY: &'static str = "user_session";
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let sub = FmtSubscriber::builder() let sub = FmtSubscriber::builder()