init
This commit is contained in:
commit
3299d3cc4c
14 changed files with 3920 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
/target
|
3276
Cargo.lock
generated
Normal file
3276
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
26
Cargo.toml
Normal file
26
Cargo.toml
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
[package]
|
||||||
|
name = "based"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
env_logger = "0.10.0"
|
||||||
|
hex = "0.4.3"
|
||||||
|
rayon = "1.7.0"
|
||||||
|
regex = "1.9.5"
|
||||||
|
ring = "0.16.20"
|
||||||
|
walkdir = "2.4.0"
|
||||||
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
|
futures = "0.3.30"
|
||||||
|
log = "0.4.20"
|
||||||
|
rocket = { version = "0.5.1", features = ["json"] }
|
||||||
|
rocket_cors = "0.6.0"
|
||||||
|
serde = { version = "1.0.195", features = ["derive"] }
|
||||||
|
serde_json = "1.0.111"
|
||||||
|
tokio = { version = "1.35.1", features = ["full"] }
|
||||||
|
uuid = { version = "1.8.0", features = ["v4", "serde"] }
|
||||||
|
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] }
|
||||||
|
maud = "0.26.0"
|
||||||
|
rand = "0.8.5"
|
||||||
|
data-encoding = "2.6.0"
|
||||||
|
bcrypt = "0.16.0"
|
29
README.md
Normal file
29
README.md
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Based
|
||||||
|
Based is a micro framework providing web dev primitives.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- User Auth
|
||||||
|
- PostgresDB Connection
|
||||||
|
- Logging
|
||||||
|
- Request Contexts
|
||||||
|
- Templates (Shell)
|
||||||
|
|
||||||
|
## User Auth
|
||||||
|
To use the user auth feature, make sure a migration has added the following to your PostgresDB:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TYPE user_role AS ENUM ('regular', 'admin');
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
username VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||||
|
"password" text NOT NULL,
|
||||||
|
user_role user_role NOT NULL DEFAULT 'regular'
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_session (
|
||||||
|
id UUID NOT NULL PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
token text NOT NULL,
|
||||||
|
"user" varchar(255) NOT NULL,
|
||||||
|
FOREIGN KEY("user") REFERENCES users(username)
|
||||||
|
);
|
||||||
|
```
|
24
src/format.rs
Normal file
24
src/format.rs
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
pub fn format_date(date: &chrono::NaiveDate) -> String {
|
||||||
|
// TODO : Implement
|
||||||
|
date.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn format_number(num: i32) -> String {
|
||||||
|
// TODO : Implement
|
||||||
|
num.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_seconds_to_hhmmss(seconds: f64) -> String {
|
||||||
|
let total_seconds = seconds as u64;
|
||||||
|
let hours = total_seconds / 3600;
|
||||||
|
let minutes = (total_seconds % 3600) / 60;
|
||||||
|
let seconds = total_seconds % 60;
|
||||||
|
if hours != 0 {
|
||||||
|
format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
|
||||||
|
} else {
|
||||||
|
format!("{:02}:{:02}", minutes, seconds)
|
||||||
|
}
|
||||||
|
}
|
27
src/lib.rs
Normal file
27
src/lib.rs
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
use tokio::sync::OnceCell;
|
||||||
|
|
||||||
|
pub mod result;
|
||||||
|
pub mod request;
|
||||||
|
pub mod user;
|
||||||
|
pub mod page;
|
||||||
|
|
||||||
|
// Postgres
|
||||||
|
|
||||||
|
pub static PG: OnceCell<sqlx::PgPool> = OnceCell::const_new();
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! get_pg {
|
||||||
|
() => {
|
||||||
|
if let Some(client) = $crate::PG.get() {
|
||||||
|
client
|
||||||
|
} else {
|
||||||
|
let client = sqlx::postgres::PgPoolOptions::new()
|
||||||
|
.max_connections(5)
|
||||||
|
.connect(&std::env::var("DATABASE_URL").unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
$crate::PG.set(client).unwrap();
|
||||||
|
$crate::PG.get().unwrap()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
84
src/page/mod.rs
Normal file
84
src/page/mod.rs
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
use core::num;
|
||||||
|
|
||||||
|
use maud::{html, PreEscaped};
|
||||||
|
|
||||||
|
use crate::{request::context::RequestContext, user::User};
|
||||||
|
|
||||||
|
use rocket::{
|
||||||
|
http::{ContentType, Status},
|
||||||
|
request::{self, FromRequest, Request},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn htmx_link(
|
||||||
|
url: &str,
|
||||||
|
class: &str,
|
||||||
|
onclick: &str,
|
||||||
|
content: PreEscaped<String>,
|
||||||
|
) -> PreEscaped<String> {
|
||||||
|
html!(
|
||||||
|
a class=(class) onclick=(onclick) href=(url) hx-get=(url) hx-target="#main_content" hx-push-url="true" hx-swap="innerHTML" {
|
||||||
|
(content);
|
||||||
|
};
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn script(script: &str) -> PreEscaped<String> {
|
||||||
|
html!(
|
||||||
|
script {
|
||||||
|
(PreEscaped(script))
|
||||||
|
};
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shell(content: PreEscaped<String>, title: &str, user: Option<User>) -> PreEscaped<String> {
|
||||||
|
html! {
|
||||||
|
html {
|
||||||
|
head {
|
||||||
|
title { (title) };
|
||||||
|
script src="https://cdn.tailwindcss.com" {};
|
||||||
|
script src="https://unpkg.com/htmx.org@2.0.3" integrity="sha384-0895/pl2MU10Hqc6jd4RvrthNlDiE9U1tWmX7WRESftEDRosgxNsQG/Ze9YMRzHq" crossorigin="anonymous" {};
|
||||||
|
meta name="viewport" content="width=device-width, initial-scale=1.0";
|
||||||
|
};
|
||||||
|
body class="bg-black text-white" {
|
||||||
|
header class="bg-gray-800 text-white shadow-md py-2" {
|
||||||
|
|
||||||
|
div class="flex justify-between px-6" {
|
||||||
|
|
||||||
|
a href="/" class="flex items-center space-x-2" {
|
||||||
|
img src="/favicon" alt="Logo" class="w-10 h-10 rounded-md";
|
||||||
|
span class="font-semibold text-xl" { "WatchDogs" };
|
||||||
|
};
|
||||||
|
|
||||||
|
@if user.is_some() {
|
||||||
|
p { (user.unwrap().username) };
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
div id="main_content" {
|
||||||
|
(content)
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn render_page(
|
||||||
|
htmx: RequestContext,
|
||||||
|
content: PreEscaped<String>,
|
||||||
|
title: &str,
|
||||||
|
user: Option<User>,
|
||||||
|
) -> (Status, (ContentType, String)) {
|
||||||
|
if !htmx.is_htmx {
|
||||||
|
(
|
||||||
|
Status::Ok,
|
||||||
|
(ContentType::HTML, shell(content, title, user).into_string()),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(Status::Ok, (ContentType::HTML, content.into_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
25
src/request/api.rs
Normal file
25
src/request/api.rs
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
use rocket::http::{ContentType, Status};
|
||||||
|
|
||||||
|
/// A trait to generate a Model API representation in JSON format.
|
||||||
|
pub trait ToAPI: Sized {
|
||||||
|
/// Generate public API JSON
|
||||||
|
fn api(&self) -> impl std::future::Future<Output = serde_json::Value>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a slice of items implementing the `ToAPI` trait into a `Vec` of JSON values.
|
||||||
|
pub async fn vec_to_api(items: &[impl ToAPI]) -> Vec<serde_json::Value> {
|
||||||
|
let mut ret = Vec::with_capacity(items.len());
|
||||||
|
|
||||||
|
for e in items {
|
||||||
|
ret.push(e.api().await);
|
||||||
|
}
|
||||||
|
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_response(json: &serde_json::Value) -> (Status, (ContentType, String)) {
|
||||||
|
(
|
||||||
|
Status::Ok,
|
||||||
|
(ContentType::JSON, serde_json::to_string(json).unwrap()),
|
||||||
|
)
|
||||||
|
}
|
56
src/request/assets.rs
Normal file
56
src/request/assets.rs
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
use rocket::{
|
||||||
|
fs::NamedFile,
|
||||||
|
get,
|
||||||
|
http::{ContentType, Status},
|
||||||
|
State,
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio::{fs::File, io::AsyncReadExt};
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
#[get("/video/raw?<v>")]
|
||||||
|
pub async fn video_file(
|
||||||
|
v: &str,
|
||||||
|
library: &State<Library>,
|
||||||
|
) -> Option<(Status, (ContentType, Vec<u8>))> {
|
||||||
|
let video = if let Some(video) = library.get_video_by_id(v).await {
|
||||||
|
video
|
||||||
|
} else {
|
||||||
|
library.get_video_by_youtube_id(v).await.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Ok(mut file) = File::open(&video.path).await {
|
||||||
|
let mut buf = Vec::with_capacity(51200);
|
||||||
|
file.read_to_end(&mut buf).await.ok()?;
|
||||||
|
let content_type = if video.path.ends_with("mp4") {
|
||||||
|
ContentType::new("video", "mp4")
|
||||||
|
} else {
|
||||||
|
ContentType::new("video", "webm")
|
||||||
|
};
|
||||||
|
|
||||||
|
return Some((Status::Ok, (content_type, buf)));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/video/thumbnail?<v>")]
|
||||||
|
pub async fn video_thumbnail(
|
||||||
|
v: &str,
|
||||||
|
library: &State<Library>,
|
||||||
|
) -> Option<(Status, (ContentType, Vec<u8>))> {
|
||||||
|
let video = if let Some(video) = library.get_video_by_id(v).await {
|
||||||
|
video
|
||||||
|
} else {
|
||||||
|
library.get_video_by_youtube_id(v).await.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(data) = library.get_thumbnail(&video).await {
|
||||||
|
return Some((Status::Ok, (ContentType::PNG, data)));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
121
src/request/cache.rs
Normal file
121
src/request/cache.rs
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use rocket::tokio::sync::RwLock;
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! use_api_cache {
|
||||||
|
($route:literal, $id:ident, $cache:ident) => {
|
||||||
|
if let Some(ret) = $cache.get_only($route, $id).await {
|
||||||
|
return Ok(serde_json::from_str(&ret).unwrap());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
($route:literal, $id:literal, $cache:ident) => {
|
||||||
|
if let Some(ret) = $cache.get_only($route, $id).await {
|
||||||
|
return Ok(serde_json::from_str(&ret).unwrap());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RouteCache {
|
||||||
|
inner: RwLock<HashMap<String, HashMap<String, Option<String>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouteCache {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get<F, Fut>(&self, route: &str, id: &str, generator: F) -> String
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Fut,
|
||||||
|
Fut: std::future::Future<Output = String>,
|
||||||
|
{
|
||||||
|
{
|
||||||
|
// Try to get a read lock first.
|
||||||
|
let lock = self.inner.read().await;
|
||||||
|
if let Some(inner_map) = lock.get(route) {
|
||||||
|
if let Some(cached_value) = inner_map.get(id) {
|
||||||
|
log::info!("Using cached value for {route} / {id}");
|
||||||
|
return cached_value.clone().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value was not found, acquire a write lock to insert the computed value.
|
||||||
|
let mut lock = self.inner.write().await;
|
||||||
|
|
||||||
|
log::info!("Computing value for {route} / {id}");
|
||||||
|
let computed = generator().await;
|
||||||
|
|
||||||
|
lock.entry(route.to_string())
|
||||||
|
.or_insert_with(HashMap::new)
|
||||||
|
.insert(id.to_string(), Some(computed.clone()));
|
||||||
|
|
||||||
|
computed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_only(&self, route: &str, id: &str) -> Option<String> {
|
||||||
|
let lock = self.inner.read().await;
|
||||||
|
if let Some(inner_map) = lock.get(route) {
|
||||||
|
if let Some(cached_value) = inner_map.get(id) {
|
||||||
|
log::info!("Using cached value for {route} / {id}");
|
||||||
|
return cached_value.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_option<F, Fut>(&self, route: &str, id: &str, generator: F) -> Option<String>
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Fut,
|
||||||
|
Fut: std::future::Future<Output = Option<String>>,
|
||||||
|
{
|
||||||
|
{
|
||||||
|
// Try to get a read lock first.
|
||||||
|
let lock = self.inner.read().await;
|
||||||
|
if let Some(inner_map) = lock.get(route) {
|
||||||
|
if let Some(cached_value) = inner_map.get(id) {
|
||||||
|
log::info!("Using cached value for {route} / {id}");
|
||||||
|
return cached_value.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value was not found, acquire a write lock to insert the computed value.
|
||||||
|
let mut lock = self.inner.write().await;
|
||||||
|
|
||||||
|
log::info!("Computing value for {route} / {id}");
|
||||||
|
let computed = generator().await;
|
||||||
|
|
||||||
|
lock.entry(route.to_string())
|
||||||
|
.or_insert_with(HashMap::new)
|
||||||
|
.insert(id.to_string(), computed.clone());
|
||||||
|
|
||||||
|
computed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert(&self, route: &str, id: &str, value: String) {
|
||||||
|
let mut lock = self.inner.write().await;
|
||||||
|
|
||||||
|
log::info!("Inserting value for {route} / {id}");
|
||||||
|
|
||||||
|
lock.entry(route.to_string())
|
||||||
|
.or_insert_with(HashMap::new)
|
||||||
|
.insert(id.to_string(), Some(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn invalidate(&self, route: &str, id: &str) {
|
||||||
|
let mut lock = self.inner.write().await;
|
||||||
|
if let Some(inner_map) = lock.get_mut(route) {
|
||||||
|
inner_map.remove(id);
|
||||||
|
|
||||||
|
// If the inner map is empty, remove the route entry as well.
|
||||||
|
if inner_map.is_empty() {
|
||||||
|
lock.remove(route);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
22
src/request/context.rs
Normal file
22
src/request/context.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
use rocket::{request::{self, FromRequest}, Request};
|
||||||
|
|
||||||
|
use crate::user::{Session, User};
|
||||||
|
|
||||||
|
pub struct RequestContext {
|
||||||
|
pub is_htmx: bool
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for RequestContext {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||||
|
rocket::outcome::Outcome::Success(RequestContext {
|
||||||
|
is_htmx: !req
|
||||||
|
.headers()
|
||||||
|
.get("HX-Request")
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.is_empty()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
27
src/request/mod.rs
Normal file
27
src/request/mod.rs
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use rocket::response::status::BadRequest;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
pub mod context;
|
||||||
|
pub mod assets;
|
||||||
|
pub mod cache;
|
||||||
|
pub mod api;
|
||||||
|
|
||||||
|
|
||||||
|
pub fn to_uuid(id: &str) -> Result<uuid::Uuid, ApiError> {
|
||||||
|
uuid::Uuid::from_str(id).map_err(|_| no_uuid_error())
|
||||||
|
}
|
||||||
|
|
||||||
|
type ApiError = BadRequest<serde_json::Value>;
|
||||||
|
type FallibleApiResponse = Result<serde_json::Value, ApiError>;
|
||||||
|
|
||||||
|
pub fn no_uuid_error() -> ApiError {
|
||||||
|
api_error("No valid UUID")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_error(msg: &str) -> ApiError {
|
||||||
|
BadRequest(json!({
|
||||||
|
"error": msg
|
||||||
|
}))
|
||||||
|
}
|
19
src/result.rs
Normal file
19
src/result.rs
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
pub trait LogAndIgnore {
|
||||||
|
fn log_and_ignore(self, msg: &str);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E: std::fmt::Debug> LogAndIgnore for Result<T, E> {
|
||||||
|
/// Handles the result by ignoring and logging it if it contains an error.
|
||||||
|
///
|
||||||
|
/// If the result is `Ok`, does nothing.
|
||||||
|
/// If the result is `Err(e)`
|
||||||
|
/// logs the message provided (`msg`) along with the error.
|
||||||
|
fn log_and_ignore(self, msg: &str) {
|
||||||
|
match self {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("{msg} : {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
183
src/user/mod.rs
Normal file
183
src/user/mod.rs
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
use data_encoding::HEXUPPER;
|
||||||
|
use rand::RngCore;
|
||||||
|
use rocket::{http::Status, outcome::Outcome, request::FromRequest, Request};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
use sqlx::FromRow;
|
||||||
|
|
||||||
|
use crate::{get_pg, request::api::ToAPI};
|
||||||
|
|
||||||
|
fn gen_token(token_length: usize) -> String {
|
||||||
|
let mut token_bytes = vec![0u8; token_length];
|
||||||
|
|
||||||
|
rand::thread_rng().fill_bytes(&mut token_bytes);
|
||||||
|
|
||||||
|
HEXUPPER.encode(&token_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||||
|
pub struct User {
|
||||||
|
/// The username chosen by the user
|
||||||
|
pub username: String,
|
||||||
|
/// The hashed password for the user
|
||||||
|
pub password: String,
|
||||||
|
/// The role of the user
|
||||||
|
pub user_role: UserRole,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
|
||||||
|
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
|
||||||
|
pub enum UserRole {
|
||||||
|
/// A regular user with limited permissions
|
||||||
|
Regular,
|
||||||
|
/// An admin user with full system privileges
|
||||||
|
Admin,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl User {
|
||||||
|
/// Find a user by their username
|
||||||
|
pub async fn find(username: &str) -> Option<Self> {
|
||||||
|
sqlx::query_as("SELECT * FROM users WHERE username = $1")
|
||||||
|
.bind(username)
|
||||||
|
.fetch_optional(get_pg!())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new user with the given details
|
||||||
|
///
|
||||||
|
/// Returns an Option containing the created user, or None if a user already exists with the same username
|
||||||
|
pub async fn create(username: &str, password: &str, role: UserRole) -> Option<Self> {
|
||||||
|
// Check if a user already exists with the same username
|
||||||
|
if Self::find(username).await.is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let u = Self {
|
||||||
|
username: username.to_string(),
|
||||||
|
password: bcrypt::hash(password, bcrypt::DEFAULT_COST).unwrap(),
|
||||||
|
user_role: role,
|
||||||
|
};
|
||||||
|
|
||||||
|
sqlx::query("INSERT INTO users (username, \"password\", user_role) VALUES ($1, $2, $3)")
|
||||||
|
.bind(&u.username)
|
||||||
|
.bind(&u.password)
|
||||||
|
.bind(&u.user_role)
|
||||||
|
.execute(get_pg!())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Some(u)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Login a user with the given username and password
|
||||||
|
pub async fn login(username: &str, password: &str) -> Option<(Session, UserRole)> {
|
||||||
|
let u = Self::find(username).await?;
|
||||||
|
|
||||||
|
if !u.verify_pw(password) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some((u.session().await, u.user_role))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Change the password of a User
|
||||||
|
///
|
||||||
|
/// Returns a Result indicating whether the password change was successful or not
|
||||||
|
pub async fn passwd(self, old: &str, new: &str) -> Result<(), ()> {
|
||||||
|
if self.verify_pw(old) {
|
||||||
|
sqlx::query("UPDATE users SET \"password\" = $1 WHERE username = $2;")
|
||||||
|
.bind(bcrypt::hash(new, bcrypt::DEFAULT_COST).unwrap())
|
||||||
|
.bind(&self.username)
|
||||||
|
.fetch_one(get_pg!())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find all users in the system
|
||||||
|
pub async fn find_all() -> Vec<Self> {
|
||||||
|
sqlx::query_as("SELECT * FROM users")
|
||||||
|
.fetch_all(get_pg!())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a new session token for the user
|
||||||
|
///
|
||||||
|
/// Returns a Session instance containing the generated token and associated user
|
||||||
|
pub async fn session(&self) -> Session {
|
||||||
|
sqlx::query_as(
|
||||||
|
"INSERT INTO user_session (token, \"user\") VALUES ($1, $2) RETURNING id, token, \"user\"",
|
||||||
|
)
|
||||||
|
.bind(gen_token(64))
|
||||||
|
.bind(&self.username)
|
||||||
|
.fetch_one(get_pg!())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the user is an admin
|
||||||
|
pub const fn is_admin(&self) -> bool {
|
||||||
|
matches!(self.user_role, UserRole::Admin)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that a provided password matches the hashed password for the user
|
||||||
|
///
|
||||||
|
/// Returns a boolean indicating whether the passwords match or not
|
||||||
|
pub fn verify_pw(&self, password: &str) -> bool {
|
||||||
|
bcrypt::verify(password, &self.password).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToAPI for User {
|
||||||
|
async fn api(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"username": self.username,
|
||||||
|
"role": self.user_role
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||||
|
pub struct Session {
|
||||||
|
/// The unique ID of the session token
|
||||||
|
pub id: uuid::Uuid,
|
||||||
|
/// The generated session token
|
||||||
|
pub token: String,
|
||||||
|
/// The username associated with the session token
|
||||||
|
pub user: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! check_admin {
|
||||||
|
($u:ident) => {
|
||||||
|
if !$u.is_admin() {
|
||||||
|
return Err(api_error("Forbidden"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for User {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
|
||||||
|
// todo : cookie auth
|
||||||
|
match request.headers().get_one("token") {
|
||||||
|
Some(key) => {
|
||||||
|
if let Some(user) = sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(key).fetch_optional(get_pg!()).await.unwrap() {
|
||||||
|
Outcome::Success(user)
|
||||||
|
} else {
|
||||||
|
Outcome::Error((Status::Unauthorized, ()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Outcome::Error((Status::Unauthorized, ())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue