From 397369ae1e5095e78e1099b38adb72d1969969d2 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Fri, 27 Dec 2024 00:07:01 +0100 Subject: [PATCH] awfully cursed --- Cargo.toml | 2 +- examples/sqlite.rs | 21 ++++++++++++++++++++ src/auth/user.rs | 24 +++++++++++++++++++++-- src/db.rs | 47 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 examples/sqlite.rs create mode 100644 src/db.rs diff --git a/Cargo.toml b/Cargo.toml index b506a28..1e61599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ serde = { version = "1.0.195", features = ["derive"] } serde_json = "1.0.111" tokio = { version = "1.35.1", features = ["full"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } -sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] } +sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json", "sqlite"] } maud = "0.26.0" rand = "0.8.5" data-encoding = "2.6.0" diff --git a/examples/sqlite.rs b/examples/sqlite.rs new file mode 100644 index 0000000..11beb4a --- /dev/null +++ b/examples/sqlite.rs @@ -0,0 +1,21 @@ +use based::{get_sqlite, get_sqlite_or_create}; +use based::request::{RequestContext, StringResponse}; +use based::{ + get_pg, + page::{Shell, render_page}, +}; +use maud::html; +use rocket::get; +use rocket::routes; + +#[rocket::launch] +async fn launch() -> _ { + // Logging + env_logger::init(); + + // Database + unsafe { std::env::set_var("DATABASE_URL", "test.db"); } + let db = based::db::Database::new().await; + + rocket::build().mount("/", routes![]) +} diff --git a/src/auth/user.rs b/src/auth/user.rs index ad85fb0..fba95be 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -4,7 +4,7 @@ use serde_json::json; use sqlx::FromRow; use super::{Session, gen_token}; -use crate::{get_pg, request::api::ToAPI}; +use crate::{db::Database, get_pg, request::api::ToAPI, with}; /// User /// @@ -40,7 +40,27 @@ pub enum UserRole { impl User { // Get a user from session ID pub async fn from_session(session: &str) -> Option { - sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(session).fetch_optional(get_pg!()).await.unwrap() + with!( + Database::new().await, + { + sqlx::query_as::<_, Self>( + "SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)" + ) + .bind(session) + .fetch_optional(&pg) + .await + .ok() + }, + { + sqlx::query_as::<_, Self>( + "SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)" + ) + .bind(session) + .fetch_optional(&sqlite) + .await + .ok() + } + ) } /// Find a user by their username diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..6aecbdf --- /dev/null +++ b/src/db.rs @@ -0,0 +1,47 @@ +use crate::{get_pg, get_sqlite, get_sqlite_or_create}; + + +pub enum DatabaseBackend { + Sqlite, + Postgres +} + +pub struct Database { + pub in_use: DatabaseBackend, + pub sqlite: Option>, + pub postgres: Option> +} + +impl Database { + pub async fn new() -> Self { + if std::env::var("DATABASE_URL").unwrap().contains("postgres") { + return Self { + in_use: DatabaseBackend::Postgres, + sqlite: None, + postgres: Some(get_pg!().clone()) + }; + } else { + return Self { + in_use: DatabaseBackend::Sqlite, + sqlite: Some(get_sqlite_or_create!().clone()), + postgres: None + }; + } + } +} + +#[macro_export] +macro_rules! with { + ($db:expr, $pg:block, $sqlite:block) => {{ + match $db.in_use { + crate::db::DatabaseBackend::Postgres => { + let pg = $db.postgres.clone().expect("Postgres connection not available"); + $pg + } + crate::db::DatabaseBackend::Sqlite => { + let sqlite = $db.sqlite.clone().expect("Sqlite connection not available"); + $sqlite + } + } + }}; +} diff --git a/src/lib.rs b/src/lib.rs index 6cd0577..01367f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,14 +7,22 @@ pub mod htmx; pub mod page; pub mod request; pub mod result; +pub mod db; // TODO : API Pagination? // TODO : CORS? // TODO : CSRF? + +// TODO : Implement SQLITE backend +// Code sharing with macro +// Seperate migrations +// which db by $DB env var "postgres" "sqlite" + // Postgres pub static PG: OnceCell = OnceCell::const_new(); +pub static SQ: OnceCell = OnceCell::const_new(); /// A macro to retrieve or initialize the PostgreSQL connection pool. /// @@ -43,3 +51,44 @@ macro_rules! get_pg { } }; } + +#[macro_export] +macro_rules! get_sqlite { + () => { + if let Some(client) = $crate::SQ.get() { + client + } else { + + let client = sqlx::sqlite::SqlitePoolOptions::new() + .max_connections(5) + .connect(&std::env::var("DATABASE_URL").unwrap()) + .await + .unwrap(); + $crate::SQ.set(client).unwrap(); + $crate::SQ.get().unwrap() + } + }; +} + +#[macro_export] +macro_rules! get_sqlite_or_create { + () => { + if let Some(client) = $crate::SQ.get() { + client + } else { + let db_url = std::env::var("DATABASE_URL").unwrap(); + + if !std::fs::exists(&db_url).unwrap_or(false) { + std::fs::File::create(&db_url).unwrap(); + } + + let client = sqlx::sqlite::SqlitePoolOptions::new() + .max_connections(5) + .connect(&db_url) + .await + .unwrap(); + $crate::SQ.set(client).unwrap(); + $crate::SQ.get().unwrap() + } + }; +}