diff --git a/.gitignore b/.gitignore index 4fc2701..f7379f7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ *.db *.db-wal *.db-shm -/.direnv \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index d7869b6..a229ab3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1930,6 +1930,7 @@ checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" dependencies = [ "base64", "bytes", + "chrono", "crc", "crossbeam-queue", "either", @@ -1986,6 +1987,8 @@ dependencies = [ "serde_json", "sha2", "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", "sqlx-sqlite", "syn 2.0.111", "tokio", @@ -2003,6 +2006,7 @@ dependencies = [ "bitflags", "byteorder", "bytes", + "chrono", "crc", "digest", "dotenvy", @@ -2023,6 +2027,7 @@ dependencies = [ "percent-encoding", "rand 0.8.5", "rsa", + "serde", "sha1", "sha2", "smallvec", @@ -2043,6 +2048,7 @@ dependencies = [ "base64", "bitflags", "byteorder", + "chrono", "crc", "dotenvy", "etcetera", @@ -2077,6 +2083,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", + "chrono", "flume", "futures-channel", "futures-core", diff --git a/database/Cargo.toml b/database/Cargo.toml index e8141ea..374258e 100644 --- a/database/Cargo.toml +++ b/database/Cargo.toml @@ -17,6 +17,7 @@ sqlx = { version = "0.8.6", default-features = false, features = [ "sqlite", "migrate", "macros", + "chrono" ] } rand = {version = "0.9.2", default-features = false} base-62 = {version = "0.1.1", default-features = false} diff --git a/database/src/models/achievement.rs b/database/src/models/achievement.rs index 012f2b0..bc53d23 100644 --- a/database/src/models/achievement.rs +++ b/database/src/models/achievement.rs @@ -35,7 +35,7 @@ pub struct AchievementGoalUnlock { pub goal_id: i32, pub goal_description: String, pub goal_sequence: i32, - pub unlocked_at: Option>, + pub time: DateTime, } #[derive(Serialize, Deserialize)] diff --git a/database/src/repos/achievement.rs b/database/src/repos/achievement.rs index 877636c..66fba61 100644 --- a/database/src/repos/achievement.rs +++ b/database/src/repos/achievement.rs @@ -1,8 +1,8 @@ -use sqlx::{SqlitePool, query, query_as}; +use sqlx::{SqlitePool, query, query_as, query_scalar}; use crate::{ error::DatabaseError, - models::achievement::{Achievement, AchievementCreate, AchievementGoal}, + models::achievement::{Achievement, AchievementCreate, AchievementGoal, AchievementGoalUnlock}, }; pub struct AchievementRepo<'a> { @@ -40,6 +40,36 @@ impl<'a> AchievementRepo<'a> { .await?) } + pub async fn by_unlocked_goal_id( + &self, + user_id: u32, + goal_id: u32, + ) -> Result, DatabaseError> { + Ok(query_as( + " + SELECT + achievement.id as achievement_id, + achievement.name as achievement_name, + service_id, + goal.id as goal_id, + goal.description as goal_description, + goal.sequence as goal_sequence, + time + FROM + goal as goal1 + INNER JOIN achievement on achievement.id = goal1.achievement_id + INNER JOIN goal on goal.achievement_id = achievement.id + INNER JOIN unlock on goal1.id = unlock.goal_id + WHERE + goal1.id = ? AND user_id = ?; + ", + ) + .bind(goal_id) + .bind(user_id) + .fetch_all(self.db) + .await?) + } + pub async fn for_service( &self, service_id: u32, @@ -119,4 +149,87 @@ impl<'a> AchievementRepo<'a> { tx.commit().await?; self.by_id(db_achievement.id).await } + + pub async fn unlock_goal( + &self, + user_id: u32, + goal_id: u32, + ) -> Result, DatabaseError> { + query( + " + INSERT INTO + unlock (user_id, goal_id) + VALUES + (?,?); + ", + ) + .bind(user_id) + .bind(goal_id) + .execute(self.db) + .await?; + + self.by_unlocked_goal_id(user_id, goal_id).await + } + + pub async fn goal_exist(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + goal + WHERE + goal.id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } + + pub async fn goal_unlocked(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + unlock + WHERE + goal_id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } + + pub async fn unlocked_for_user( + &self, + user_id: u32, + ) -> Result, DatabaseError> { + Ok(query_as( + "SELECT + achievement.id as achievement_id, + name as achievement_name, + service_id, + goal.id as goal_id, + description as goal_description, + sequence as goal_sequence, + time + FROM + unlock + INNER JOIN goal ON goal_id = goal.id + INNER JOIN achievement ON achievement_id = achievement.id + WHERE + user_id = ? + ORDER BY + achievement_id, goal_sequence; + ", + ) + .bind(user_id) + .fetch_all(self.db) + .await?) + } } diff --git a/database/src/repos/service.rs b/database/src/repos/service.rs index d44b0e2..297a782 100644 --- a/database/src/repos/service.rs +++ b/database/src/repos/service.rs @@ -74,4 +74,12 @@ impl<'a> ServiceRepo<'a> { .await? .ok_or(DatabaseError::NotFound) } + + pub async fn by_id(&self, id: u32) -> Result { + sqlx::query_as("SELECT id, name, api_key FROM service WHERE id == ? LIMIT 1;") + .bind(id) + .fetch_optional(self.db) + .await? + .ok_or(DatabaseError::NotFound) + } } diff --git a/src/dto/achievement.rs b/src/dto/achievement.rs index 9046bfb..1ef6ab8 100644 --- a/src/dto/achievement.rs +++ b/src/dto/achievement.rs @@ -1,8 +1,8 @@ -use std::iter::Peekable; +use std::iter::from_fn; use database::{ Database, - models::achievement::{AchievementCreate, AchievementGoal}, + models::achievement::{AchievementCreate, AchievementGoal, AchievementGoalUnlock, GoalCreate}, }; use serde::{Deserialize, Serialize}; @@ -18,21 +18,27 @@ pub struct AchievementPayload { pub goals: Vec, } +impl From for AchievementPayload { + fn from(row: AchievementGoal) -> Self { + Self { + id: row.achievement_id, + name: row.achievement_name, + goals: vec![GoalPayload { + id: row.goal_id, + description: row.goal_description, + sequence: row.goal_sequence, + }], + } + } +} + impl AchievementPayload { pub async fn for_service( db: &Database, service_id: u32, ) -> Result, AppError> { let rows = db.achievements().for_service(service_id).await?; - - let mut rows = rows.into_iter().peekable(); - - let mut achievements = Vec::new(); - while let Some(achievement) = unpack_next_achievement(&mut rows) { - achievements.push(achievement); - } - - Ok(achievements) + Ok(unpack_achievements(rows).collect()) } } @@ -43,8 +49,54 @@ pub struct AchievementUnlockedPayload { pub goals: Vec, } +impl From for AchievementUnlockedPayload { + fn from(row: AchievementGoalUnlock) -> Self { + Self { + id: row.achievement_id, + name: row.achievement_name, + goals: vec![GoalUnlockedPayload { + id: row.goal_id, + description: row.goal_description, + sequence: row.goal_sequence, + time: row.time, + }], + } + } +} + impl AchievementUnlockedPayload { - // TODO unlock goal + pub async fn for_user( + db: &Database, + user_id: u32, + ) -> Result, AppError> { + let rows = db.achievements().unlocked_for_user(user_id).await?; + + let rows = rows.into_iter().peekable(); + + Ok(unpack_achievements(rows).collect()) + } + + pub async fn unlock_goal( + db: &Database, + user_id: u32, + goal_id: u32, + ) -> Result { + if !db.achievements().goal_exist(goal_id).await? { + return Err(AppError::NotFound); + } + + // FIXME improve + let rows = if db.achievements().goal_unlocked(goal_id).await? { + // goal already unlocked + db.achievements() + .by_unlocked_goal_id(user_id, goal_id) + .await? + } else { + db.achievements().unlock_goal(user_id, goal_id).await? + }; + + unpack_achievements(rows).next().ok_or(AppError::NotFound) + } } #[derive(Serialize, Deserialize)] @@ -64,7 +116,7 @@ impl AchievementCreatePayload { } self.goals.sort_by_key(|x| x.sequence); - let ordered_1_seperated = self + let ordered_1_separated = self .goals .iter() .map(|x| x.sequence) @@ -75,7 +127,7 @@ impl AchievementCreatePayload { _ => false, }); if let Some(goal) = self.goals.first() - && (goal.sequence != 0 || !ordered_1_seperated) + && (goal.sequence != 0 || !ordered_1_separated) { return Err(AppError::PayloadError( "Sequence should start with 0 and count up by 1".into(), @@ -88,51 +140,74 @@ impl AchievementCreatePayload { service_id, AchievementCreate { name: self.name, - goals: self.goals.into_iter().map(|x| x.into()).collect(), + goals: self.goals.into_iter().map(GoalCreate::from).collect(), }, ) .await?; - // pack rows into an achievement payload - let mut rows = rows.into_iter().peekable(); - let achievement = unpack_next_achievement(&mut rows).ok_or(AppError::NotFound)?; - Ok(achievement) + unpack_achievements(rows).next().ok_or(AppError::NotFound) } } -/// unpacks an achievement from database rows into a payload -fn unpack_next_achievement(rows: &mut Peekable) -> Option +pub trait AchievementRow: Sized { + type Payload: From; + + fn achievement_id(&self) -> i32; + + fn push_into(self, payload: &mut Self::Payload); +} + +impl AchievementRow for AchievementGoal { + type Payload = AchievementPayload; + + fn achievement_id(&self) -> i32 { + self.achievement_id + } + + fn push_into(self, payload: &mut Self::Payload) { + payload.goals.push(GoalPayload { + id: self.goal_id, + description: self.goal_description, + sequence: self.goal_sequence, + }); + } +} + +impl AchievementRow for AchievementGoalUnlock { + type Payload = AchievementUnlockedPayload; + + fn achievement_id(&self) -> i32 { + self.achievement_id + } + + fn push_into(self, payload: &mut Self::Payload) { + payload.goals.push(GoalUnlockedPayload { + id: self.goal_id, + description: self.goal_description, + sequence: self.goal_sequence, + time: self.time, + }); + } +} + +// group rows by achievement id and return an iterator of achievements +fn unpack_achievements(rows: I) -> impl Iterator where - I: Iterator, + I: IntoIterator, + R: AchievementRow, { - // get first row - let row = rows.next()?; - - // make a new achievement with the first goal - let mut achievement = AchievementPayload { - id: row.achievement_id, - name: row.achievement_name, - goals: vec![GoalPayload { - id: row.goal_id, - description: row.goal_description, - sequence: row.goal_sequence, - }], - }; - - // add all following goals for the same achievement - while let Some(next_row) = rows.peek() { - if next_row.achievement_id != achievement.id { - break; - } + let mut iter = rows.into_iter().peekable(); - if let Some(next_goal) = rows.next() { - achievement.goals.push(GoalPayload { - id: next_goal.goal_id, - description: next_goal.goal_description, - sequence: next_goal.goal_sequence, - }); + from_fn(move || { + let first_row = iter.next()?; + let current_id = first_row.achievement_id(); + + let mut achievement: R::Payload = first_row.into(); + // pack all goals for this achievement into the achievement + while let Some(next_row) = iter.next_if(|r| r.achievement_id() == current_id) { + next_row.push_into(&mut achievement); } - } - Some(achievement) + Some(achievement) + }) } diff --git a/src/dto/goal.rs b/src/dto/goal.rs index 10428da..d726198 100644 --- a/src/dto/goal.rs +++ b/src/dto/goal.rs @@ -14,7 +14,7 @@ pub struct GoalUnlockedPayload { pub id: i32, pub description: String, pub sequence: i32, - pub unlocked_at: DateTime, + pub time: DateTime, } #[derive(Serialize, Deserialize)] diff --git a/src/dto/service.rs b/src/dto/service.rs index f12ca7e..a277b6d 100644 --- a/src/dto/service.rs +++ b/src/dto/service.rs @@ -30,7 +30,7 @@ impl ServicePayloadAdmin { .all() .await? .into_iter() - .map(|service| service.into()) + .map(Self::from) .collect()) } @@ -63,7 +63,7 @@ impl ServicePayloadUser { .all() .await? .into_iter() - .map(|service| service.into()) + .map(Self::from) .collect()) } } diff --git a/src/dto/user.rs b/src/dto/user.rs index 2608e7d..902e696 100644 --- a/src/dto/user.rs +++ b/src/dto/user.rs @@ -1,10 +1,11 @@ use database::{ Database, - error::DatabaseError, models::{tag::Tag, user::UserPatch}, }; use serde::{Deserialize, Serialize}; +use crate::{dto::achievement::AchievementUnlockedPayload, error::AppError}; + #[derive(Debug, Serialize, Deserialize)] pub struct UserPatchPayload { pub about: String, @@ -22,6 +23,7 @@ pub struct UserProfile { pub username: String, pub about: String, pub tags: Vec, + pub achievements: Vec, } pub enum UserId { @@ -45,18 +47,20 @@ impl From for UserId { } impl UserProfile { - pub async fn get(db: &Database, user_id: UserId) -> Result { + pub async fn get(db: &Database, user_id: UserId) -> Result { let user = match user_id { UserId::Username(username) => db.users().by_username(username).await?, UserId::Id(id) => db.users().by_id(id).await?, }; let tags = db.tags().for_user(user.id).await?; + let achievements = AchievementUnlockedPayload::for_user(db, user.id).await?; Ok(UserProfile { id: user.id, username: user.username, about: user.about, tags, + achievements, }) } } diff --git a/src/error.rs b/src/error.rs index 6088190..af3f6b7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,7 +48,7 @@ pub enum AppError { #[error("Submitted image resolution was too large")] ImageResTooLarge, - #[error("The requested image was not found")] + #[error("Not found")] NotFound, #[error("Submitted file had an incorrect type")] @@ -60,6 +60,9 @@ pub enum AppError { #[error("User was not logged in")] NotLoggedIn, + #[error("Wrong api key")] + BadApiKey, + #[error("Forbidden")] Forbidden, @@ -80,6 +83,7 @@ impl AppError { let (status, msg) = match self { Self::PayloadError(_) => (StatusCode::BAD_REQUEST, "Payload error"), Self::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Not logged in."), + Self::BadApiKey => (StatusCode::UNAUTHORIZED, "Bad api key."), Self::Forbidden => (StatusCode::FORBIDDEN, "Forbidden."), Self::NoFile => ( StatusCode::BAD_REQUEST, diff --git a/src/extractors/api_key.rs b/src/extractors/api_key.rs new file mode 100644 index 0000000..26af31b --- /dev/null +++ b/src/extractors/api_key.rs @@ -0,0 +1,24 @@ +use axum::{extract::FromRequestParts, http::request::Parts}; +use axum_extra::TypedHeader; +use headers::{Authorization, authorization::Bearer}; + +use crate::error::AppError; + +#[derive(Debug)] +pub struct ApiKey(pub String); + +impl FromRequestParts for ApiKey +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let header = TypedHeader::>::from_request_parts(parts, state).await; + + match header { + Ok(TypedHeader(Authorization(bearer))) => Ok(ApiKey(bearer.token().to_string())), + _ => Err(AppError::BadApiKey), + } + } +} diff --git a/src/extractors/mod.rs b/src/extractors/mod.rs index 128e91e..d8be7bf 100644 --- a/src/extractors/mod.rs +++ b/src/extractors/mod.rs @@ -1,4 +1,5 @@ pub mod admin; +pub mod api_key; pub mod authenticated_user; pub mod config; pub mod database; diff --git a/src/handlers/service.rs b/src/handlers/service.rs index 89e47c1..7617831 100644 --- a/src/handlers/service.rs +++ b/src/handlers/service.rs @@ -2,10 +2,14 @@ use axum::{Json, extract::Path}; use database::Database; use crate::{ - dto::service::{ - ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + dto::{ + achievement::AchievementUnlockedPayload, + service::{ + ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + }, }, error::AppError, + extractors::api_key::ApiKey, }; pub struct ServiceHandler; @@ -42,4 +46,19 @@ impl ServiceHandler { ServicePayloadAdmin::regenerate_api_key(&db, service_id).await?, )) } + + pub async fn unlock_goal( + db: Database, + Path((user_id, service_id, goal_id)): Path<(u32, u32, u32)>, + ApiKey(api_key): ApiKey, + ) -> Result, AppError> { + let expected_api_key = db.services().by_id(service_id).await?.api_key; + if api_key != expected_api_key { + return Err(AppError::BadApiKey); + } + + Ok(Json( + AchievementUnlockedPayload::unlock_goal(&db, user_id, goal_id).await?, + )) + } } diff --git a/src/lib.rs b/src/lib.rs index 507b052..2b069c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,10 @@ fn open_routes() -> Router { .route("/oauth/callback", get(AuthHandler::callback)) .route("/image/{id}", get(ImageHandler::get)) .route("/version", get(VersionHandler::get)) + .route( + "/users/{id}/unlock/{service_id}/{goal_id}", + post(ServiceHandler::unlock_goal), + ) } fn authenticated_routes() -> Router { diff --git a/tests/achievement.rs b/tests/achievement.rs index 8f087b3..125051f 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -5,22 +5,26 @@ use zpi::dto::{ goal::GoalCreatePayload, }; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services", "achievements"))] #[test_log::test] -async fn get_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/admin/services/1/achievements").await; +async fn get_achievements_for_service(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let admin = TestRouter::as_admin(db).await; + let response = admin.get("/admin/services/1/achievements").await; assert_eq!(response.status(), StatusCode::OK); let data: Vec = response.into_struct().await; - assert_eq!( data, vec![TestObjects::achievement_1(), TestObjects::achievement_2()] @@ -29,8 +33,7 @@ async fn get_achievements_for_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; +async fn post_achievements_for_service(db: SqlitePool) { let body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -44,19 +47,26 @@ async fn post_achievements_for_service(db_pool: SqlitePool) { }, ], }; - let response = router.post("/admin/services/1/achievements", body).await; + let none = TestRouter::new(db.clone()); + let response = none.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let admin = TestRouter::as_admin(db).await; + let response = admin.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::OK); let data: AchievementPayload = response.into_struct().await; - assert_eq!(data, TestObjects::achievement_1()); } #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; +async fn post_achievements_wrong_sequence(db: SqlitePool) { let mut body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -71,13 +81,54 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { ], }; - let response = router - .clone() - .post("/admin/services/1/achievements", &body) - .await; + let router = TestRouter::as_admin(db.clone()).await; + let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); body.goals[1].sequence = 1; let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + +#[sqlx::test(fixtures("services", "achievements", "users"))] +#[test_log::test] +async fn unlock_goal(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + assert_eq!(data, TestObjects::achievement_1()); +} + +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { + let router = TestRouter::with_api_key(db_pool, "wrongapikey"); + + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_404(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/3", None::<()>).await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[sqlx::test(fixtures("services", "users", "achievements", "unlocks"))] +#[test_log::test] +async fn unlock_goal_already_unlocked(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + assert_eq!(data, TestObjects::achievement_1()); +} diff --git a/tests/common/router.rs b/tests/common/router.rs index e64f07d..c3b9a52 100644 --- a/tests/common/router.rs +++ b/tests/common/router.rs @@ -1,12 +1,13 @@ use std::{path::PathBuf, sync::Arc}; use axum::{ - Json, Router, + Json, body::Body, http::Request, response::{IntoResponse, Response}, }; use database::Database; +use dotenvy::dotenv; use reqwest::{Method, header}; use serde::Serialize; use sqlx::SqlitePool; @@ -16,35 +17,19 @@ use zpi::{ AppState, api_router, config::AppConfig, extractors::authenticated_user::AuthenticatedUser, }; -#[derive(Clone)] -pub struct AuthenticatedRouter { - router: Router, - cookie: String, +pub struct TestRouter { + router: axum::Router, + store: MemoryStore, + cookie: Option, + api_key: Option, } -impl AuthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = Arc::new(MemoryStore::default()); - - let session_id = { - let session = Session::new(Some(Id(1)), store.clone(), None); - session - .insert( - "user", - AuthenticatedUser { - id: 1, - username: "cheese".to_string(), - admin: true, - }, - ) - .await - .unwrap(); - session.save().await.unwrap(); - session.id().unwrap() - }; +impl TestRouter { + pub fn new(db: SqlitePool) -> Self { + let _ = dotenv(); + let store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(Arc::into_inner(store).unwrap()) + let session_layer = SessionManagerLayer::new(store.clone()) .with_secure(false) .with_same_site(tower_sessions::cookie::SameSite::Lax); @@ -58,28 +43,64 @@ impl AuthenticatedRouter { Self { router: api_router().layer(session_layer).with_state(state), - cookie: format!("id={}", session_id), + store: store, + cookie: None, + api_key: None, } } + pub async fn as_user(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: false, + }) + .await + } + + pub async fn as_admin(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: true, + }) + .await + } + + async fn add_to_store(mut self, user: AuthenticatedUser) -> Self { + let session = Session::new(Some(Id(1)), Arc::new(self.store.clone()), None); + session.insert("user", user).await.unwrap(); + session.save().await.unwrap(); + self.cookie.replace(format!("id={}", session.id().unwrap())); + self + } + + pub fn with_api_key(db: SqlitePool, api_key: &str) -> Self { + let mut router = Self::new(db); + router.api_key = Some("Bearer ".to_string() + api_key); + router + } + /// send a request to an endpoint on this router /// /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { + pub async fn get(&self, path: &str) -> Response { self.request(Method::GET, path, None::<()>).await } /// send a patch request to an endpoint on this router /// /// must have a leading "/" - pub async fn patch(self, path: &str, body: T) -> Response { + pub async fn patch(&self, path: &str, body: T) -> Response { self.request(Method::PATCH, path, Some(body)).await } - /// send a patch request to an endpoint on this router + /// send a post request to an endpoint on this router /// /// must have a leading "/" - pub async fn post(self, path: &str, body: T) -> Response { + pub async fn post(&self, path: &str, body: T) -> Response { self.request(Method::POST, path, Some(body)).await } @@ -87,15 +108,20 @@ impl AuthenticatedRouter { /// /// must have a leading "/" async fn request( - self, + &self, method: Method, path: &str, body: Option, ) -> Response { - let request_builder = Request::builder() - .method(method) - .uri(path) - .header(header::COOKIE, &self.cookie); + let mut request_builder = Request::builder().method(method).uri(path); + + if let Some(api_key) = &self.api_key { + request_builder = request_builder.header(header::AUTHORIZATION, api_key); + } + + if let Some(cookie) = &self.cookie { + request_builder = request_builder.header(header::COOKIE, cookie); + } let request = match body { Some(body) => request_builder @@ -103,42 +129,7 @@ impl AuthenticatedRouter { .body(Json(body).into_response().into_body()), None => request_builder.body(Body::empty()), }; - self.router.oneshot(request.unwrap()).await.unwrap() - } -} -pub struct UnauthenticatedRouter { - router: Router, -} - -impl UnauthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = MemoryStore::default(); - - let session_layer = SessionManagerLayer::new(store) - .with_secure(false) - .with_same_site(tower_sessions::cookie::SameSite::Lax); - - let mut config = AppConfig::load().unwrap(); - config.image_path = PathBuf::from("./tests/test_images"); - - let state = AppState { - db: Database::new(db), - config, - }; - - Self { - router: api_router().layer(session_layer).with_state(state), - } - } - /// send a request to an endpoint on this router - /// - /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { - self.router - .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) - .await - .unwrap() + self.router.clone().oneshot(request.unwrap()).await.unwrap() } } diff --git a/tests/common/test_objects.rs b/tests/common/test_objects.rs index d304790..9b39faf 100644 --- a/tests/common/test_objects.rs +++ b/tests/common/test_objects.rs @@ -1,8 +1,9 @@ +use chrono::{Local, NaiveDateTime, TimeZone}; use database::models::{tag::Tag, user::User}; use zpi::{ dto::{ - achievement::AchievementPayload, - goal::GoalPayload, + achievement::{AchievementPayload, AchievementUnlockedPayload}, + goal::{GoalPayload, GoalUnlockedPayload}, service::{ServicePayloadAdmin, ServicePayloadUser}, user::UserProfile, }, @@ -13,6 +14,14 @@ pub struct TestObjects; impl TestObjects { pub fn authenticated_user_1() -> AuthenticatedUser { + AuthenticatedUser { + id: 1, + username: "cheese".into(), + admin: false, + } + } + + pub fn admin_user_1() -> AuthenticatedUser { AuthenticatedUser { id: 1, username: "cheese".into(), @@ -37,20 +46,51 @@ impl TestObjects { } pub fn user_profile_1() -> UserProfile { + let naive = + NaiveDateTime::parse_from_str("2025-01-01 19:19:20", "%Y-%m-%d %H:%M:%S").unwrap(); + let dt_local = Local.from_local_datetime(&naive).single().unwrap(); + + let naive2 = + NaiveDateTime::parse_from_str("2025-09-16 12:59:21", "%Y-%m-%d %H:%M:%S").unwrap(); + let dt_local2 = Local.from_local_datetime(&naive2).single().unwrap(); + UserProfile { id: 1, username: "cheese".into(), about: "Just a test user, doing its job... and fantasizing about a life outside the test environment.".to_string(), tags: Vec::new(), + achievements: vec![ + AchievementUnlockedPayload { + id: 1, + name: "Achievements".into(), + goals: vec![GoalUnlockedPayload {id : 1, description: String::from("Get 1 achievement"), sequence: 0, time: dt_local}], + }, AchievementUnlockedPayload { + id: 3, + name: "Votes".into(), + goals: vec![GoalUnlockedPayload {id : 4, description: String::from("Vote 1 time"), sequence: 0, time: dt_local2}], + } ], } } pub fn user_profile_2() -> UserProfile { + let naive = + NaiveDateTime::parse_from_str("2025-05-05 12:11:12", "%Y-%m-%d %H:%M:%S").unwrap(); + let dt_local = Local.from_local_datetime(&naive).single().unwrap(); UserProfile { id: 2, username: "wafel".into(), about: "I like cheese.".into(), tags: Self::tags(), + achievements: vec![AchievementUnlockedPayload { + id: 2, + name: "Profile Picture".into(), + goals: vec![GoalUnlockedPayload { + id: 3, + description: "Upload a profile picture".into(), + sequence: 0, + time: dt_local, + }], + }], } } diff --git a/tests/image.rs b/tests/image.rs index b90042d..33894d1 100644 --- a/tests/image.rs +++ b/tests/image.rs @@ -1,13 +1,13 @@ use reqwest::StatusCode; use sqlx::SqlitePool; -use crate::common::router::UnauthenticatedRouter; +use crate::common::router::TestRouter; mod common; #[sqlx::test] async fn get_image_default(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1").await; assert_eq!(response.status(), StatusCode::OK); @@ -15,7 +15,7 @@ async fn get_image_default(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=true").await; assert_eq!(response.status(), StatusCode::OK); @@ -23,7 +23,7 @@ async fn get_image_placeholder(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder_404(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=false").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); @@ -31,7 +31,7 @@ async fn get_image_no_placeholder_404(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/2?placeholder=false").await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/service.rs b/tests/service.rs index 02374d2..5e6f8cf 100644 --- a/tests/service.rs +++ b/tests/service.rs @@ -5,16 +5,14 @@ use zpi::dto::service::{ ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, }; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services_as_admin(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.get("/admin/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -27,7 +25,7 @@ async fn get_all_services_as_admin(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -45,7 +43,7 @@ struct ApiKey { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn users_dont_see_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -58,7 +56,7 @@ async fn users_dont_see_api_key(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn create_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServiceCreatePayload { name: "zpi".to_string(), }; @@ -77,7 +75,7 @@ async fn create_service(db_pool: SqlitePool) { #[test_log::test] async fn patch_service(db_pool: SqlitePool) { let new_name = "gamification2"; - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServicePatchPayload { name: new_name.to_string(), }; @@ -95,7 +93,7 @@ async fn patch_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn regenerate_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.post("/admin/services/1/apikey", "").await; // empty body assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/tags.rs b/tests/tags.rs deleted file mode 100644 index b3d4b75..0000000 --- a/tests/tags.rs +++ /dev/null @@ -1,22 +0,0 @@ -use reqwest::StatusCode; -use sqlx::SqlitePool; -use zpi::dto::user::UserProfile; - -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; - -mod common; - -#[sqlx::test(fixtures("users", "tags"))] -#[test_log::test] -async fn get_user_with_tags(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/users/2").await; - - assert_eq!(response.status(), StatusCode::OK); - - let data: UserProfile = response.into_struct().await; - - assert_eq!(data, TestObjects::user_profile_2()) -} diff --git a/tests/user.rs b/tests/user.rs index dabd533..5fabef2 100644 --- a/tests/user.rs +++ b/tests/user.rs @@ -1,20 +1,16 @@ use database::models::user::{User, UserPatch}; use reqwest::StatusCode; use sqlx::SqlitePool; -use zpi::{dto::user::UserProfile, extractors::AuthenticatedUser}; +use zpi::extractors::AuthenticatedUser; -use crate::common::{ - into_struct::IntoStruct, - router::{AuthenticatedRouter, UnauthenticatedRouter}, - test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test] #[test_log::test] async fn get_users_me(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::OK); @@ -25,7 +21,7 @@ async fn get_users_me(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_users_me_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -33,7 +29,7 @@ async fn get_users_me_unauthenticated(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn patch_user(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let body = UserPatch { about: "Changed about".to_string(), }; @@ -48,47 +44,3 @@ async fn patch_user(db_pool: SqlitePool) { assert_eq!(user_response, expected_user); } - -#[sqlx::test(fixtures("users"))] -#[test_log::test] -async fn get_profile_by_id(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/users/1").await; - assert_eq!(response.status(), StatusCode::OK); - - let user_response: UserProfile = response.into_struct().await; - assert_eq!(user_response, TestObjects::user_profile_1()); -} - -#[sqlx::test] -#[test_log::test] -async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; - let response = router.get("/users/1").await; - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); -} - -#[sqlx::test] -#[test_log::test] -async fn get_profile_404(db_pool: SqlitePool) { - // test getting by id - let router = AuthenticatedRouter::new(db_pool.clone()).await; - let response = router.get("/users/1").await; - assert_eq!(response.status(), StatusCode::NOT_FOUND); - - // test getting by username - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/users/cheese").await; - assert_eq!(response.status(), StatusCode::NOT_FOUND); -} - -#[sqlx::test(fixtures("users"))] -#[test_log::test] -async fn get_profile_by_name(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/users/cheese").await; - assert_eq!(response.status(), StatusCode::OK); - - let user_response: UserProfile = response.into_struct().await; - assert_eq!(user_response, TestObjects::user_profile_1()); -} diff --git a/tests/user_profile.rs b/tests/user_profile.rs new file mode 100644 index 0000000..71c50ea --- /dev/null +++ b/tests/user_profile.rs @@ -0,0 +1,62 @@ +use reqwest::StatusCode; +use sqlx::SqlitePool; +use zpi::dto::user::UserProfile; + +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; + +mod common; + +#[sqlx::test(fixtures("users", "services", "achievements", "unlocks", "tags"))] +#[test_log::test] +async fn get_profile_by_id(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + let response = router.get("/users/1").await; + assert_eq!(response.status(), StatusCode::OK); + + let user_response: UserProfile = response.into_struct().await; + assert_eq!(user_response, TestObjects::user_profile_1()); +} + +#[sqlx::test(fixtures("users", "services", "achievements", "unlocks", "tags"))] +#[test_log::test] +async fn get_profile_by_id_with_tags(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + let response = router.get("/users/2").await; + assert_eq!(response.status(), StatusCode::OK); + + let user_response: UserProfile = response.into_struct().await; + assert_eq!(user_response, TestObjects::user_profile_2()); +} + +#[sqlx::test] +#[test_log::test] +async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { + let router = TestRouter::new(db_pool); + let response = router.get("/users/1").await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[sqlx::test] +#[test_log::test] +async fn get_profile_404(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + + // test getting by id + let response = router.get("/users/1").await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // test getting by username + let response = router.get("/users/cheese").await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[sqlx::test(fixtures("users", "services", "achievements", "unlocks"))] +#[test_log::test] +async fn get_profile_by_name(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + let response = router.get("/users/cheese").await; + assert_eq!(response.status(), StatusCode::OK); + + let user_response: UserProfile = response.into_struct().await; + assert_eq!(user_response, TestObjects::user_profile_1()); +}