From d3ec3ed422115ebfdfbb30fa460718fd88146cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sof=C3=ADa=20Aritz?= Date: Mon, 14 Oct 2024 21:10:31 +0200 Subject: [PATCH] implement /auth/account and /auth/register --- identity-api-rs/Cargo.toml | 13 +- identity-api-rs/src/database/mod.rs | 14 +- identity-api-rs/src/database/models.rs | 16 +-- identity-api-rs/src/env.rs | 21 +-- identity-api-rs/src/http/extractors/auth.rs | 44 +++++++ identity-api-rs/src/http/extractors/mod.rs | 1 + identity-api-rs/src/http/mod.rs | 2 + identity-api-rs/src/http/routes/auth.rs | 139 ++++++++++++++++++++ identity-api-rs/src/http/routes/mod.rs | 1 + identity-api-rs/src/main.rs | 73 +++++++++- 10 files changed, 298 insertions(+), 26 deletions(-) create mode 100644 identity-api-rs/src/http/extractors/auth.rs create mode 100644 identity-api-rs/src/http/extractors/mod.rs create mode 100644 identity-api-rs/src/http/mod.rs create mode 100644 identity-api-rs/src/http/routes/auth.rs create mode 100644 identity-api-rs/src/http/routes/mod.rs diff --git a/identity-api-rs/Cargo.toml b/identity-api-rs/Cargo.toml index 793ad75..3ba8ac9 100644 --- a/identity-api-rs/Cargo.toml +++ b/identity-api-rs/Cargo.toml @@ -4,10 +4,17 @@ version = "0.1.0" edition = "2021" [dependencies] -axum = { version = "0.7" } -chrono = "0.4" -diesel = { version = "2.2", features = ["sqlite", "returning_clauses_for_sqlite_3_35", "chrono"] } +argon2 = "0.5.3" +axum = { version = "0.7", features = ["macros", "tracing"] } +tower-http = { version = "0.6", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +chrono = { version = "0.4", features = ["serde"] } +diesel = { version = "2.2", features = ["sqlite", "returning_clauses_for_sqlite_3_35", "chrono", "r2d2"] } dotenvy = "0.15" serde = { version = "1", features = ["derive"] } serde_json = "1" +r2d2 = "0.8" +jsonwebtoken = "9" +uuid = { version = "1.10", features = ["v4", "fast-rng"] } tokio = { version = "1", features = ["full"] } \ No newline at end of file diff --git a/identity-api-rs/src/database/mod.rs b/identity-api-rs/src/database/mod.rs index 8dd10b0..22465a7 100644 --- a/identity-api-rs/src/database/mod.rs +++ b/identity-api-rs/src/database/mod.rs @@ -14,15 +14,19 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use diesel::{Connection, SqliteConnection}; +use diesel::prelude::*; +use diesel::r2d2::ConnectionManager; +use diesel::r2d2::Pool; -use crate::env::database_url; +use crate::env; pub mod models; pub mod schema; pub mod list; -pub fn establish_connection() -> SqliteConnection { - let url = database_url(); - SqliteConnection::establish(url).unwrap_or_else(|_| panic!("failed to connect to {}", url)) + +pub fn create_connection_pool() -> Result>, r2d2::Error> { + let url = env::database_url(); + let manager = ConnectionManager::::new(url); + Pool::builder().build(manager) } diff --git a/identity-api-rs/src/database/models.rs b/identity-api-rs/src/database/models.rs index f506f04..b24cd29 100644 --- a/identity-api-rs/src/database/models.rs +++ b/identity-api-rs/src/database/models.rs @@ -33,7 +33,7 @@ pub struct LocationCoordinates { longitude: f64, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::date_entries)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct DateEntry { @@ -41,7 +41,7 @@ pub struct DateEntry { referenced_date: NaiveDateTime, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::entries)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct Entry { @@ -58,7 +58,7 @@ pub struct Entry { date_entry: Option, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::heirs)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct Heir { @@ -69,7 +69,7 @@ pub struct Heir { email: Option, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::limits)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct Limit { @@ -78,7 +78,7 @@ pub struct Limit { max_asset_count: i32, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::location_entries)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct LocationEntry { @@ -94,7 +94,7 @@ impl LocationEntry { } } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::music_entries)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct MusicEntry { @@ -105,7 +105,7 @@ pub struct MusicEntry { universal_ids: List, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::session_keys)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct SessionKey { @@ -113,7 +113,7 @@ pub struct SessionKey { user_id: String, } -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Serialize, Deserialize)] #[diesel(table_name = schema::users)] #[diesel(check_for_backend(diesel::sqlite::Sqlite))] pub struct User { diff --git a/identity-api-rs/src/env.rs b/identity-api-rs/src/env.rs index cfeb51b..1eb7b19 100644 --- a/identity-api-rs/src/env.rs +++ b/identity-api-rs/src/env.rs @@ -14,10 +14,12 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use std::env; +use std::{env, str::FromStr}; use std::sync::OnceLock; use std::time::Duration; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey}; + const REQUIRED_ENV_VARIABLES: [&str; 4] = [ "IDENTITY_API_JWT_SECRET", "IDENTITY_API_ASSET_API_ENDPOINT", @@ -81,18 +83,21 @@ pub fn listen_port() -> &'static u16 { }) } -pub fn jwt_secret() -> &'static str { - static IDENTITY_API_JWT_SECRET: OnceLock = OnceLock::new(); +pub fn jwt_secret() -> &'static (EncodingKey, DecodingKey) { + static IDENTITY_API_JWT_SECRET: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new(); IDENTITY_API_JWT_SECRET.get_or_init(|| { - env::var("IDENTITY_API_JWT_SECRET") - .expect("environment variables were not loaded correctly") + let secret = env::var("IDENTITY_API_JWT_SECRET") + .expect("environment variables were not loaded correctly"); + + (EncodingKey::from_secret(secret.as_bytes()), DecodingKey::from_secret(secret.as_bytes())) }) } -pub fn jwt_alg() -> &'static str { - static IDENTITY_API_JWT_ALG: OnceLock = OnceLock::new(); +pub fn jwt_alg() -> &'static Algorithm { + static IDENTITY_API_JWT_ALG: OnceLock = OnceLock::new(); IDENTITY_API_JWT_ALG.get_or_init(|| { - env::var("IDENTITY_API_JWT_ALG").expect("environment variables were not loaded correctly") + let algo = env::var("IDENTITY_API_JWT_ALG").expect("environment variables were not loaded correctly"); + Algorithm::from_str(&algo).expect("invalid JWT algorithm") }) } diff --git a/identity-api-rs/src/http/extractors/auth.rs b/identity-api-rs/src/http/extractors/auth.rs new file mode 100644 index 0000000..db3d1ec --- /dev/null +++ b/identity-api-rs/src/http/extractors/auth.rs @@ -0,0 +1,44 @@ +use axum::{async_trait, extract::FromRequestParts, http::{header::AUTHORIZATION, request::Parts, StatusCode}}; +use jsonwebtoken::Validation; +use serde::{Serialize, Deserialize}; +use tracing::{warn, info}; + +use crate::env; + +#[derive(Debug, Serialize, Deserialize)] +pub struct JwtUser { + pub uid: String, + pub email: String, + pub name: String, + pub exp: u64, +} +pub struct ExtractUser(pub JwtUser); + +#[async_trait] +impl FromRequestParts for ExtractUser +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + if let Some(authorization) = parts.headers.get(AUTHORIZATION) { + if let Ok(authorization) = authorization.to_str() { + let token = authorization.replacen("Bearer ", "", 1); + match jsonwebtoken::decode::(&token, &env::jwt_secret().1, &Validation::new(*env::jwt_alg())) { + Ok(claims) => Ok(Self(claims.claims)), + Err(err) => { + warn!("token couldn't be decoded: {:?}", err); + Err((StatusCode::UNAUTHORIZED, "Invalid token")) + } + } + } else { + warn!("invalid authorization header: {:?}", authorization); + Err((StatusCode::BAD_REQUEST, "Invalid `AUTHORIZATION` header")) + } + } else { + warn!("missin authorization header"); + Err((StatusCode::BAD_REQUEST, "Missing `AUTHORIZATION` header")) + } + } +} \ No newline at end of file diff --git a/identity-api-rs/src/http/extractors/mod.rs b/identity-api-rs/src/http/extractors/mod.rs new file mode 100644 index 0000000..5696e21 --- /dev/null +++ b/identity-api-rs/src/http/extractors/mod.rs @@ -0,0 +1 @@ +pub mod auth; \ No newline at end of file diff --git a/identity-api-rs/src/http/mod.rs b/identity-api-rs/src/http/mod.rs new file mode 100644 index 0000000..d8d73f8 --- /dev/null +++ b/identity-api-rs/src/http/mod.rs @@ -0,0 +1,2 @@ +pub mod extractors; +pub mod routes; \ No newline at end of file diff --git a/identity-api-rs/src/http/routes/auth.rs b/identity-api-rs/src/http/routes/auth.rs new file mode 100644 index 0000000..f435f18 --- /dev/null +++ b/identity-api-rs/src/http/routes/auth.rs @@ -0,0 +1,139 @@ +use std::time::SystemTime; + +use argon2::{Argon2, PasswordHasher, password_hash::{rand_core::OsRng, SaltString}}; +use axum::{extract::State, http::StatusCode, routing::{get, post}, Json, Router}; +use chrono::Utc; +use diesel::{/*query_dsl::methods::{FindDsl, SelectDsl, FilterDsl},*/ SelectableHelper, RunQueryDsl, ExpressionMethods, QueryDsl, OptionalExtension}; +use jsonwebtoken::Header; +use tracing::{error, info}; +use serde::{Serialize, Deserialize}; +use uuid::Uuid; +use crate::{database::models::User, database::list::List, env, http::extractors::auth::{ExtractUser, JwtUser}, AppState}; + +pub fn auth_router() -> Router { + let router = Router::new() + .route("/auth/account", get(account)) + .route("/auth/register", post(register)); + + router +} + +async fn account(ExtractUser(jwt_user): ExtractUser, State(state): State) -> Result, StatusCode> { + use crate::database::schema::users::dsl::users; + + if let Ok(mut conn) = state.pool.get() { + let user = users + .find(jwt_user.uid) + .select(User::as_select()) + .first(&mut conn); + + if let Ok(user) = user { + Ok(Json(user)) + } else { + error!("JWT user does not exist in database"); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + } else { + error!("failed to obtain pooled connection"); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } +} + +#[derive(Debug, Deserialize)] +struct RegisterRequest { + email: String, + password: String, + name: String, +} + +#[derive(Debug, Serialize)] +struct RegisterResponse { + token: String, +} + +async fn register(State(state): State, Json(req): Json) -> Result, StatusCode> { + use crate::database::schema::users::dsl as users; + use crate::database::schema::limits::dsl as limits; + + if let Ok(mut conn) = state.pool.get() { + let user = users::users + .filter(users::email.eq(&req.email)) + .limit(1) + .select(User::as_select()) + .first(&mut conn) + .optional(); + + if user.is_err() { + error!("failed to retrieve potential existing user from database: {}, error: {:?}", &req.email, user.err()); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + + if user.is_ok_and(|v| v.is_some()) { + info!("tried to register existing user: {}", &req.email); + return Err(StatusCode::BAD_REQUEST); + } + + let limit_id = Uuid::new_v4().to_string(); + let result = diesel::insert_into(limits::limits) + .values(( + limits::id.eq(&limit_id), + limits::current_asset_count.eq(0), + limits::max_asset_count.eq(10), + )) + .execute(&mut conn); + + if result.is_err() { + error!("failed to insert into limits: {}, error: {:?}", &req.email, result.err()); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let password_hash = argon2.hash_password(req.password.as_bytes(), &salt); + + if let Ok(password_hash) = password_hash { + let user_id = Uuid::new_v4().to_string(); + let result = diesel::insert_into(users::users) + .values(( + users::id.eq(&user_id), + users::created_at.eq(Utc::now().naive_utc()), + users::last_connected_at.eq(Utc::now().naive_utc()), + users::email.eq(&req.email), + users::password.eq(password_hash.to_string()), + users::name.eq(&req.name), + users::limits.eq(&limit_id), + // TODO: Implement diesel::Expression for List + users::assets.eq("[]"), + )) + .execute(&mut conn); + + if result.is_err() { + error!("failed to insert into users: {}, error: {:?}", req.email, result.err()); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + + match jsonwebtoken::encode(&Header::new(*env::jwt_alg()), &JwtUser { + uid: user_id, + email: req.email, + name: req.name, + exp: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("time went backwards") + .as_secs() + 180 * 24 * 3600, + }, &env::jwt_secret().0) { + Ok(token) => Ok(Json(RegisterResponse { token })), + Err(err) => { + error!("token couldn't be encoded: {:?}", err); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } else { + error!("failed to hash password: {:?}", password_hash.err()); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + + } else { + error!("failed to obtain pooled connection"); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } +} \ No newline at end of file diff --git a/identity-api-rs/src/http/routes/mod.rs b/identity-api-rs/src/http/routes/mod.rs new file mode 100644 index 0000000..5696e21 --- /dev/null +++ b/identity-api-rs/src/http/routes/mod.rs @@ -0,0 +1 @@ +pub mod auth; \ No newline at end of file diff --git a/identity-api-rs/src/main.rs b/identity-api-rs/src/main.rs index d50c277..bdbefdc 100644 --- a/identity-api-rs/src/main.rs +++ b/identity-api-rs/src/main.rs @@ -14,11 +14,25 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use axum::{routing::get, Router}; +use axum::{extract::{MatchedPath, Request}, response::Response, routing::get, Router}; +use database::create_connection_pool; +use diesel::{r2d2::ConnectionManager, SqliteConnection}; use env::LoadEnvError; +use http::routes::auth::auth_router; +use r2d2::Pool; +use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; +use tracing::{info, info_span, warn, error, Span}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use tokio::time::Duration; mod database; mod env; +mod http; + +#[derive(Clone)] +struct AppState { + pool: Pool>, +} #[tokio::main] async fn main() { @@ -29,7 +43,62 @@ async fn main() { } }); - let app = Router::new().route("/", get(landing)); + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + // axum logs rejections from built-in extractors with the `axum::rejection` + // target, at `TRACE` level. `axum::rejection=trace` enables showing those events + format!( + "{}=debug,tower_http=debug,axum::rejection=trace", + env!("CARGO_CRATE_NAME") + ) + .into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let state = AppState { + pool: create_connection_pool().expect("failed to create database connection pool"), + }; + + let app = Router::new() + .route("/", get(landing)) + .merge(auth_router()) + .with_state(state) + .layer( + TraceLayer::new_for_http() + .make_span_with(|request: &Request<_>| { + let matched_path = request + .extensions() + .get::() + .map(MatchedPath::as_str); + + info_span!( + "http_request", + method = ?request.method(), + matched_path, + ) + }) + .on_response(|response: &Response, _latency: Duration, _span: &Span| { + if response.status().is_client_error() { + warn!( + "client error: {}", + response.status().to_string() + ); + } else { + info!("finished processing request"); + } + }) + .on_failure( + |error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { + error!( + "internal server error: {}", + error.to_string(), + ); + }, + ), + ); let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await