Improve extractors and extract some funcs

This commit is contained in:
Sofía Aritz 2024-10-14 23:02:27 +02:00
parent 546e883a9c
commit fe62b28a03
Signed by: sofia
GPG key ID: 90B5116E3542B28F
7 changed files with 103 additions and 58 deletions

View file

@ -0,0 +1,19 @@
use crate::env;
use jsonwebtoken::{TokenData, Header, Validation};
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtUser {
pub uid: String,
pub email: String,
pub name: String,
pub exp: u64,
}
pub fn encode_jwt(claims: &JwtUser) -> jsonwebtoken::errors::Result<String> {
jsonwebtoken::encode(&Header::new(*env::jwt_alg()), claims, &env::jwt_secret().0)
}
pub fn decode_jwt(jwt: &str) -> jsonwebtoken::errors::Result<TokenData<JwtUser>> {
jsonwebtoken::decode::<JwtUser>(jwt, &env::jwt_secret().1, &Validation::new(*env::jwt_alg()))
}

View file

@ -0,0 +1,20 @@
use diesel::{SqliteConnection, r2d2::{ConnectionManager, PooledConnection}, RunQueryDsl, QueryDsl, SelectableHelper, ExpressionMethods, OptionalExtension};
use crate::database::models::User;
pub fn user(user_id: &str, conn: &mut PooledConnection<ConnectionManager<SqliteConnection>>) -> diesel::result::QueryResult<User> {
use crate::database::schema::users::dsl::users;
users
.find(user_id)
.select(User::as_select())
.first(conn)
}
pub fn user_by_email(email: &str, conn: &mut PooledConnection<ConnectionManager<SqliteConnection>>) -> diesel::result::QueryResult<Option<User>> {
use crate::database::schema::users::dsl as users;
users::users
.filter(users::email.eq(email))
.limit(1)
.select(User::as_select())
.first(conn)
.optional()
}

View file

@ -23,6 +23,7 @@ use crate::env;
pub mod models; pub mod models;
pub mod schema; pub mod schema;
pub mod list; pub mod list;
pub mod actions;
pub fn create_connection_pool() -> Result<Pool<ConnectionManager<SqliteConnection>>, r2d2::Error> { pub fn create_connection_pool() -> Result<Pool<ConnectionManager<SqliteConnection>>, r2d2::Error> {

View file

@ -117,12 +117,12 @@ pub struct SessionKey {
#[diesel(table_name = schema::users)] #[diesel(table_name = schema::users)]
#[diesel(check_for_backend(diesel::sqlite::Sqlite))] #[diesel(check_for_backend(diesel::sqlite::Sqlite))]
pub struct User { pub struct User {
id: String, pub id: String,
created_at: NaiveDateTime, pub created_at: NaiveDateTime,
last_connected_at: NaiveDateTime, pub last_connected_at: NaiveDateTime,
email: String, pub email: String,
password: String, pub password: String,
name: String, pub name: String,
limits: String, pub limits: String,
assets: String, pub assets: String,
} }

View file

@ -1,21 +1,12 @@
use axum::{async_trait, extract::FromRequestParts, http::{header::AUTHORIZATION, request::Parts, StatusCode}}; use axum::{async_trait, extract::FromRequestParts, http::{header::AUTHORIZATION, request::Parts, StatusCode}};
use jsonwebtoken::Validation; use tracing::{warn, error};
use serde::{Serialize, Deserialize}; use crate::database::{actions, models::User};
use tracing::warn; use crate::AppState;
use crate::auth::JwtUser;
use crate::env; pub struct ExtractJwtUser(pub JwtUser);
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtUser {
pub uid: String,
pub email: String,
pub name: String,
pub exp: u64,
}
pub struct ExtractUser(pub JwtUser);
#[async_trait] #[async_trait]
impl<S> FromRequestParts<S> for ExtractUser impl<S> FromRequestParts<S> for ExtractJwtUser
where where
S: Send + Sync, S: Send + Sync,
{ {
@ -25,7 +16,7 @@ where
if let Some(authorization) = parts.headers.get(AUTHORIZATION) { if let Some(authorization) = parts.headers.get(AUTHORIZATION) {
if let Ok(authorization) = authorization.to_str() { if let Ok(authorization) = authorization.to_str() {
let token = authorization.replacen("Bearer ", "", 1); let token = authorization.replacen("Bearer ", "", 1);
match jsonwebtoken::decode::<JwtUser>(&token, &env::jwt_secret().1, &Validation::new(*env::jwt_alg())) { match crate::auth::decode_jwt(&token) {
Ok(claims) => Ok(Self(claims.claims)), Ok(claims) => Ok(Self(claims.claims)),
Err(err) => { Err(err) => {
warn!("token couldn't be decoded: {:?}", err); warn!("token couldn't be decoded: {:?}", err);
@ -37,8 +28,32 @@ where
Err((StatusCode::BAD_REQUEST, "Invalid `AUTHORIZATION` header")) Err((StatusCode::BAD_REQUEST, "Invalid `AUTHORIZATION` header"))
} }
} else { } else {
warn!("missin authorization header"); warn!("missing authorization header");
Err((StatusCode::BAD_REQUEST, "Missing `AUTHORIZATION` header")) Err((StatusCode::BAD_REQUEST, "Missing `AUTHORIZATION` header"))
} }
} }
}
pub struct ExtractUser(pub User);
#[async_trait]
impl FromRequestParts<AppState> for ExtractUser
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
let jwt_user = ExtractJwtUser::from_request_parts(parts, state).await?;
if let Ok(mut conn) = state.pool.get() {
if let Ok(user) = actions::user(&jwt_user.0.uid, &mut conn) {
Ok(Self(user))
} else {
error!("JWT user does not exist in database");
Err((StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))
}
} else {
error!("failed to obtain pooled connection");
Err((StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))
}
}
} }

View file

@ -1,42 +1,36 @@
use std::time::SystemTime; use std::time::SystemTime;
use argon2::{Argon2, PasswordHasher, password_hash::{rand_core::OsRng, SaltString}}; use argon2::{Argon2, PasswordHasher, password_hash::{rand_core::OsRng, SaltString}};
use axum::{extract::State, http::StatusCode, routing::{get, post}, Json, Router}; use axum::{extract::State, http::StatusCode, routing::{get, post}, Json, Router};
use chrono::Utc; use chrono::{Utc, NaiveDateTime};
use diesel::{/*query_dsl::methods::{FindDsl, SelectDsl, FilterDsl},*/ SelectableHelper, RunQueryDsl, ExpressionMethods, QueryDsl, OptionalExtension}; use diesel::{RunQueryDsl, ExpressionMethods};
use jsonwebtoken::Header;
use tracing::{error, info}; use tracing::{error, info};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use uuid::Uuid; use uuid::Uuid;
use crate::{database::models::User, env, http::extractors::auth::{ExtractUser, JwtUser}, AppState}; use crate::{database::actions, http::extractors::auth::ExtractUser, auth::JwtUser, AppState};
pub fn auth_router() -> Router<AppState> { pub fn auth_router() -> Router<AppState> {
Router::new() Router::new()
.route("/auth/account", get(account)) .route("/auth/account", get(account))
.route("/auth/register", post(register)) .route("/auth/register", post(register))
} }
async fn account(ExtractUser(jwt_user): ExtractUser, State(state): State<AppState>) -> Result<Json<User>, StatusCode> { #[derive(Debug, Serialize)]
use crate::database::schema::users::dsl::users; struct AccountResponse {
id: String,
created_at: NaiveDateTime,
last_connected_at: NaiveDateTime,
email: String,
name: String,
}
if let Ok(mut conn) = state.pool.get() { async fn account(ExtractUser(user): ExtractUser) -> Result<Json<AccountResponse>, StatusCode> {
let user = users Ok(Json(AccountResponse {
.find(jwt_user.uid) id: user.id,
.select(User::as_select()) created_at: user.created_at,
.first(&mut conn); last_connected_at: user.last_connected_at,
email: user.email,
if let Ok(user) = user { name: user.name,
Ok(Json(user)) }))
} else {
error!("JWT user does not exist in database");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
} else {
error!("failed to obtain pooled connection");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -56,12 +50,7 @@ async fn register(State(state): State<AppState>, Json(req): Json<RegisterRequest
use crate::database::schema::limits::dsl as limits; use crate::database::schema::limits::dsl as limits;
if let Ok(mut conn) = state.pool.get() { if let Ok(mut conn) = state.pool.get() {
let user = users::users let user = actions::user_by_email(&req.email, &mut conn);
.filter(users::email.eq(&req.email))
.limit(1)
.select(User::as_select())
.first(&mut conn)
.optional();
if user.is_err() { if user.is_err() {
error!("failed to retrieve potential existing user from database: {}, error: {:?}", &req.email, user.err()); error!("failed to retrieve potential existing user from database: {}, error: {:?}", &req.email, user.err());
@ -112,7 +101,7 @@ async fn register(State(state): State<AppState>, Json(req): Json<RegisterRequest
return Err(StatusCode::INTERNAL_SERVER_ERROR); return Err(StatusCode::INTERNAL_SERVER_ERROR);
} }
match jsonwebtoken::encode(&Header::new(*env::jwt_alg()), &JwtUser { match crate::auth::encode_jwt(&JwtUser {
uid: user_id, uid: user_id,
email: req.email, email: req.email,
name: req.name, name: req.name,
@ -120,7 +109,7 @@ async fn register(State(state): State<AppState>, Json(req): Json<RegisterRequest
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.expect("time went backwards") .expect("time went backwards")
.as_secs() + 180 * 24 * 3600, .as_secs() + 180 * 24 * 3600,
}, &env::jwt_secret().0) { }) {
Ok(token) => Ok(Json(RegisterResponse { token })), Ok(token) => Ok(Json(RegisterResponse { token })),
Err(err) => { Err(err) => {
error!("token couldn't be encoded: {:?}", err); error!("token couldn't be encoded: {:?}", err);

View file

@ -28,6 +28,7 @@ use tokio::time::Duration;
mod database; mod database;
mod env; mod env;
mod http; mod http;
mod auth;
#[derive(Clone)] #[derive(Clone)]
struct AppState { struct AppState {