based/src/auth/user.rs

275 lines
7.5 KiB
Rust
Raw Normal View History

2024-12-18 14:33:53 +01:00
use rocket::{Request, http::Status, outcome::Outcome, request::FromRequest};
2024-12-17 23:28:43 +01:00
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::FromRow;
2024-12-18 14:33:53 +01:00
use super::{Session, gen_token};
2024-12-17 23:28:43 +01:00
use crate::{get_pg, request::api::ToAPI};
2024-12-18 19:55:21 +01:00
/// User
///
/// # Example:
///
/// ```ignore
///
/// // Needs login
/// #[get("/myaccount")]
/// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse {
/// ...
/// }
/// ```
2024-12-17 23:28:43 +01:00
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct User {
/// The username chosen by the user
pub username: String,
/// The hashed password for the user
pub password: String,
/// The role of the user
pub user_role: UserRole,
}
/// Permission level attached to every [`User`].
///
/// In the database this maps to the enum type `user_role`, with variants stored under their
/// lowercase names (`regular`, `admin`). Serde uses the unrenamed variant names
/// (`"Regular"`, `"Admin"`).
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
    /// Ordinary account with limited permissions.
    Regular,
    /// Administrator with full system privileges.
    Admin,
}
impl User {
2024-12-18 14:33:53 +01:00
// Get a user from session ID
pub async fn from_session(session: &str) -> Option<Self> {
sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(session).fetch_optional(get_pg!()).await.unwrap()
}
2024-12-17 23:28:43 +01:00
/// Find a user by their username
pub async fn find(username: &str) -> Option<Self> {
sqlx::query_as("SELECT * FROM users WHERE username = $1")
.bind(username)
.fetch_optional(get_pg!())
.await
.unwrap()
}
/// Create a new user with the given details
///
/// Returns an Option containing the created user, or None if a user already exists with the same username
pub async fn create(username: &str, password: &str, role: UserRole) -> Option<Self> {
// Check if a user already exists with the same username
if Self::find(username).await.is_some() {
return None;
}
let u = Self {
username: username.to_string(),
password: bcrypt::hash(password, bcrypt::DEFAULT_COST).unwrap(),
user_role: role,
};
sqlx::query("INSERT INTO users (username, \"password\", user_role) VALUES ($1, $2, $3)")
.bind(&u.username)
.bind(&u.password)
.bind(&u.user_role)
.execute(get_pg!())
.await
.unwrap();
Some(u)
}
/// Login a user with the given username and password
pub async fn login(username: &str, password: &str) -> Option<(Session, UserRole)> {
let u = Self::find(username).await?;
if !u.verify_pw(password) {
return None;
}
Some((u.session().await, u.user_role))
}
/// Change the password of a User
///
/// Returns a Result indicating whether the password change was successful or not
pub async fn passwd(self, old: &str, new: &str) -> Result<(), ()> {
if self.verify_pw(old) {
sqlx::query("UPDATE users SET \"password\" = $1 WHERE username = $2;")
.bind(bcrypt::hash(new, bcrypt::DEFAULT_COST).unwrap())
.bind(&self.username)
.fetch_one(get_pg!())
.await
.unwrap();
return Ok(());
}
Err(())
}
/// Find all users in the system
pub async fn find_all() -> Vec<Self> {
sqlx::query_as("SELECT * FROM users")
.fetch_all(get_pg!())
.await
.unwrap()
}
/// Generate a new session token for the user
///
/// Returns a Session instance containing the generated token and associated user
pub async fn session(&self) -> Session {
sqlx::query_as(
"INSERT INTO user_session (token, \"user\") VALUES ($1, $2) RETURNING id, token, \"user\"",
)
.bind(gen_token(64))
.bind(&self.username)
.fetch_one(get_pg!())
.await
.unwrap()
}
/// Check if the user is an admin
pub const fn is_admin(&self) -> bool {
matches!(self.user_role, UserRole::Admin)
}
/// Verify that a provided password matches the hashed password for the user
///
/// Returns a boolean indicating whether the passwords match or not
pub fn verify_pw(&self, password: &str) -> bool {
bcrypt::verify(password, &self.password).unwrap()
}
}
impl ToAPI for User {
    /// Client-facing JSON for a user.
    ///
    /// Only the username and role are emitted; the password hash is deliberately left out.
    async fn api(&self) -> serde_json::Value {
        let username = &self.username;
        let role = &self.user_role;
        json!({
            "username": username,
            "role": role
        })
    }
}
2024-12-18 19:55:21 +01:00
async fn extract_user<'r>(request: &'r Request<'_>) -> Option<User> {
2024-12-18 20:18:59 +01:00
if let Some(session_id) = request.cookies().get("session") {
2024-12-18 19:55:21 +01:00
if let Some(user) = User::from_session(session_id.value()).await {
return Some(user);
} else {
return None;
}
}
match request.headers().get_one("token") {
Some(key) => {
if let Some(user) = User::from_session(key).await {
return Some(user);
} else {
return None;
}
}
None => None,
}
}
2024-12-17 23:28:43 +01:00
#[rocket::async_trait]
impl<'r> FromRequest<'r> for User {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
2024-12-18 19:55:21 +01:00
if let Some(user) = extract_user(request).await {
return Outcome::Success(user);
} else {
return Outcome::Error((Status::Unauthorized, ()));
2024-12-18 14:33:53 +01:00
}
2024-12-18 19:55:21 +01:00
}
}
2024-12-18 14:33:53 +01:00
2024-12-18 19:55:21 +01:00
/// Maybe User?
///
/// An optional-login request guard. It never rejects a request: it yields the logged-in
/// user when there is a valid session, and `Anonymous` otherwise.
///
/// # Example:
///
/// ```ignore
///
/// // Publicly accessible
/// #[get("/")]
/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse {
/// match user {
/// MaybeUser::User(user) => println!("You are {}", user.username),
/// MaybeUser::Anonymous => println!("Who are you?")
/// }
/// }
/// ```
pub enum MaybeUser {
    /// The request carried a valid session for this user.
    User(User),
    /// No valid session was found.
    Anonymous,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for MaybeUser {
    type Error = ();

    /// Always succeeds: `User` when a valid session is found, `Anonymous` otherwise.
    async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
        let visitor = extract_user(request)
            .await
            .map_or(MaybeUser::Anonymous, MaybeUser::User);
        Outcome::Success(visitor)
    }
}
2024-12-18 20:03:59 +01:00
impl From<MaybeUser> for Option<User> {
fn from(value: MaybeUser) -> Self {
2024-12-22 20:12:52 +01:00
value.take_user()
}
}
impl MaybeUser {
pub fn user(&self) -> Option<&User> {
match self {
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
pub fn take_user(self) -> Option<User> {
match self {
2024-12-18 20:03:59 +01:00
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
}
2024-12-18 19:55:21 +01:00
/// Admin User
///
/// This struct expects an Admin User and returns `Forbidden` otherwise.
///
/// # Example:
///
/// ```ignore
///
/// // Only admin users can access this route
/// #[get("/admin")]
/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse {
/// ...
/// }
/// ```
pub struct AdminUser(User);
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(user) = extract_user(request).await {
if user.is_admin() {
return Outcome::Success(AdminUser(user));
2024-12-17 23:28:43 +01:00
}
2024-12-18 19:55:21 +01:00
} else {
2024-12-17 23:28:43 +01:00
}
2024-12-18 19:55:21 +01:00
Outcome::Error((Status::Unauthorized, ()))
2024-12-17 23:28:43 +01:00
}
}