From 609d52098632352239a7ef7c2e7d9ed45c32b874 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Mon, 30 Dec 2024 14:06:32 +0100 Subject: [PATCH 1/2] add vector db --- Cargo.lock | 304 +++++++++++++++++++++++++++++++--- Cargo.toml | 3 + docker-compose.yml | 13 ++ env | 6 + migrations/0001_embedding.sql | 10 ++ src/ai.rs | 179 ++++++++++++++++++++ src/archive.rs | 12 +- src/main.rs | 15 +- src/pages/mod.rs | 43 +++-- 9 files changed, 547 insertions(+), 38 deletions(-) create mode 100644 migrations/0001_embedding.sql create mode 100644 src/ai.rs diff --git a/Cargo.lock b/Cargo.lock index 10e9ca6..7d196c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,7 +78,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -89,7 +89,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -276,6 +276,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.0" @@ -307,6 +313,16 @@ dependencies = [ "inout", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -487,7 +503,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -510,7 +526,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -649,6 +665,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures" version = "0.3.31" @@ -716,7 +742,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -909,6 +935,34 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "html2md" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be92446e11d68f5d71367d571c229d09ced1f24ab6d08ea0bff329d5f6c0b2a3" +dependencies = [ + "html5ever", + "jni", + "lazy_static", + "markup5ever_rcdom", + "percent-encoding", + "regex", +] + +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "http" version = "0.2.12" @@ -1230,7 +1284,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1303,6 +1357,26 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +[[package]] +name = "jni" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +dependencies = [ + "cesu8", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" version = "0.3.76" @@ -1388,6 +1462,38 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + +[[package]] +name = "markup5ever_rcdom" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9521dd6750f8e80ee6c53d65e2e4656d7de37064f3a7a5d2d11d05df93839c2" +dependencies = [ + "html5ever", + "markup5ever", + "tendril", + "xml5ever", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1416,7 +1522,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1503,6 +1609,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "nom" version = "7.1.3" @@ -1595,6 +1707,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763afb01db2dced00e656cc2cdcd875659fc3fac4c449e6337a4f04f9e3d9efc" +dependencies = [ + "async-stream", + "async-trait", + "log", + "reqwest 0.12.11", + "serde", + "serde_json", + "url", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1624,7 +1751,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1706,7 +1833,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1724,6 +1851,53 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pgvector" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6" +dependencies = [ + "sqlx", +] + +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1778,6 +1952,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1818,7 +1998,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "version_check", "yansi", ] @@ -1908,7 +2088,7 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2119,7 +2299,7 @@ dependencies = [ "proc-macro2", "quote", "rocket_http", - "syn", + "syn 2.0.93", "unicode-xid", "version_check", ] @@ -2337,7 +2517,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2429,6 +2609,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -2557,7 +2743,7 @@ dependencies = [ "quote", "sqlx-core", "sqlx-macros-core", - "syn", + "syn 2.0.93", ] [[package]] @@ -2580,7 +2766,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", - "syn", + "syn 2.0.93", "tempfile", "tokio", "url", @@ -2719,6 +2905,32 @@ dependencies = [ "loom", ] +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", +] + [[package]] name = "stringprep" version = "0.1.5" @@ -2736,6 +2948,17 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.93" @@ -2770,7 +2993,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2828,6 +3051,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + [[package]] name = "termcolor" version = "1.4.1" @@ -2854,7 +3088,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2949,7 +3183,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3077,7 +3311,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3228,6 +3462,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" @@ -3320,7 +3560,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.93", "wasm-bindgen-shared", ] @@ -3355,7 +3595,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3384,8 +3624,11 @@ dependencies = [ "chrono", "env_logger", "futures", + "html2md", "log", "maud", + "ollama-rs", + "pgvector", "regex", "reqwest 0.12.11", "rocket", @@ -3665,6 +3908,17 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "xml5ever" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4034e1d05af98b51ad7214527730626f019682d797ba38b51689212118d8e650" +dependencies = [ + "log", + "mac", + "markup5ever", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3694,7 +3948,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "synstructure", ] @@ -3716,7 +3970,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3736,7 +3990,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "synstructure", ] @@ -3765,5 +4019,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] diff --git a/Cargo.toml b/Cargo.toml index cc5a0f8..1da2ff1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,6 @@ maud = "0.26.0" based = { git = "https://git.hydrar.de/jmarya/based", features = [] } url = "2.5.4" reqwest = "0.12.11" +ollama-rs = "0.2.2" +pgvector = { version = "0.4", features = ["sqlx"] } +html2md = "0.2.14" diff --git a/docker-compose.yml b/docker-compose.yml index ef00b97..00690e8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,3 +7,16 @@ services: - ./websites:/websites - ./favicon:/favicon env_file: env + + postgres: + # Any Postgres with support for pgvector + image: git.hydrar.de/hydra/postgres:latest + restart: always + ports: + - 5432:5432 + volumes: + - ./db:/var/lib/postgresql/data/ + environment: + - POSTGRES_USER=user + - POSTGRES_PASSWORD=pass + - POSTGRES_DB=webarc diff --git a/env b/env index 61f6371..c2354bd 100644 --- a/env +++ b/env @@ -10,3 +10,9 @@ DOWNLOAD_ON_DEMAND=true # Blacklisted domains (Comma-seperated regex) BLACKLIST_DOMAINS="google.com,.*.youtube.com" + +# Database +DATABASE_URL=postgres://user:pass@postgres/webarc + +# Ollama URL (Enables vector search) +OLLAMA_URL=127.0.0.1:11434 diff --git a/migrations/0001_embedding.sql b/migrations/0001_embedding.sql new file mode 100644 index 0000000..1826b2b --- /dev/null +++ b/migrations/0001_embedding.sql @@ -0,0 +1,10 @@ + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE doc_embedding ( + domain VARCHAR(500) NOT NULL, + path VARCHAR(1000) NOT NULL, + ver VARCHAR(10) NOT NULL, + embed_mxbai_embed_large vector(1024) NOT NULL, + PRIMARY KEY (domain, path, ver) +) diff --git a/src/ai.rs b/src/ai.rs new file mode 100644 index 0000000..31a094b --- /dev/null +++ b/src/ai.rs @@ -0,0 +1,179 @@ +use std::collections::VecDeque; + +use based::get_pg; +use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; +use serde::Serialize; +use sqlx::FromRow; + +use crate::archive::{Document, Domain, WebsiteArchive}; + +#[derive(Debug, Clone, FromRow, Serialize)] +pub struct DocEmbedding { + pub domain: String, + pub path: String, + pub ver: String, + + #[serde(skip)] + embed_mxbai_embed_large: pgvector::Vector, + + #[sqlx(default)] + pub similarity: f64, +} + +pub trait Embedding { + fn embedding(&self, ver: Option) + -> impl std::future::Future>>; +} + +impl Embedding for Document { + async fn embedding(&self, ver: Option) -> Option> { + let latest = "latest".to_string(); + log::info!( + "Generating Vector embeddings for {} / {} @ {}", + self.domain, + self.path, + ver.as_ref().unwrap_or(&latest) + ); + + let content_html = self.render_local(ver).await.unwrap(); + let content = html2md::parse_html(&content_html); + generate_embedding(content).await + } +} + +pub async fn generate_embedding(input: String) -> Option> { + if let Ok(ollama_url) = std::env::var("OLLAMA_URL") { + let (host, port) = ollama_url.split_once(':')?; + let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?); + + let models = ollama.list_local_models().await.ok()?; + + if !models + .into_iter() + .any(|x| x.name.starts_with("mxbai-embed-large")) + { + log::info!("Model not found. Pulling 'mxbai-embed-large'"); + ollama + .pull_model("mxbai-embed-large".to_string(), false) + .await + .ok()?; + } + + let res = ollama + .generate_embeddings(GenerateEmbeddingsRequest::new( + "mxbai-embed-large".to_string(), + EmbeddingsInput::Single(input), + )) + .await + .ok()?; + let embed = res.embeddings.first()?; + return Some(embed.clone()); + } + + None +} + +pub struct EmbedStore; + +impl EmbedStore { + pub async fn get_embedding(doc: &Document, ver: Option<&str>) -> Option { + let use_ver = ver.map_or_else( + || { + let version = doc.versions(); + version.first().unwrap().clone() + }, + |x| x.to_string(), + ); + sqlx::query_as("SELECT * FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3") + .bind(&doc.domain) + .bind(&doc.path) + .bind(use_ver) + .fetch_optional(get_pg!()) + .await + .unwrap() + } + + pub async fn embed_document(doc: &Document, ver: &str) { + if let Some(embed) = doc.embedding(Some(ver.to_string())).await { + let _ = sqlx::query( + "DELETE FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3", + ) + .bind(&doc.domain) + .bind(&doc.path) + .bind(ver) + .execute(get_pg!()) + .await; + + sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)") + .bind(&doc.domain) + .bind(&doc.path) + .bind(ver) + .bind(embed) + .execute(get_pg!()) + .await + .unwrap(); + } else { + log::warn!( + "No embeds could be generated for {} / {}", + doc.domain, + doc.path + ); + } + } + + pub async fn ensure_embedding(doc: &Document) { + for ver in doc.versions() { + if Self::get_embedding(doc, Some(ver.as_str())).await.is_none() { + Self::embed_document(doc, &ver).await; + } + } + } + + pub async fn search_vector(v: &pgvector::Vector) -> Vec { + sqlx::query_as( + "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT 5", + ) + .bind(v) + .fetch_all(get_pg!()) + .await + .unwrap() + } + + pub async fn generate_embeddings_for(arc: &WebsiteArchive) { + log::info!("Generating embeddings"); + + for dom in arc.domains() { + let dom = arc.get_domain(&dom); + embed_path(&dom, "/").await; + } + + log::info!("Done generating embeddings"); + } +} + +pub async fn embed_path(dom: &Domain, path: &str) { + let (paths, is_doc) = dom.paths(path); + + // If the path is a document, process the root path. + if is_doc { + let doc = dom.path("/"); + EmbedStore::ensure_embedding(&doc).await; + } + + // Create a queue to process paths iteratively + let mut queue = VecDeque::new(); + + // Add the initial paths to the queue + queue.extend(paths); + + while let Some(next_path) = queue.pop_front() { + let (next_paths, is_doc) = dom.paths(next_path.path()); + + if is_doc { + let doc = dom.path(next_path.path()); + EmbedStore::ensure_embedding(&doc).await; + } + + queue.extend(next_paths); + } +} diff --git a/src/archive.rs b/src/archive.rs index 7fc566c..05e6290 100644 --- a/src/archive.rs +++ b/src/archive.rs @@ -214,7 +214,17 @@ impl Document { pub fn versions(&self) -> Vec { let mut res: Vec = read_dir(&self.doc_dir()) .into_iter() - .filter(|x| x.starts_with("index_") && x.ends_with(".html")) + .filter_map(|x| { + if x.starts_with("index_") && x.ends_with(".html") { + return Some( + x.trim_start_matches("index_") + .trim_end_matches(".html") + .to_string(), + ); + } + + None + }) .collect(); res.sort(); res.reverse(); diff --git a/src/main.rs b/src/main.rs index df8ddb4..5f0db31 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,9 @@ +use ai::EmbedStore; use archive::WebsiteArchive; +use based::get_pg; use rocket::routes; +mod ai; mod archive; mod blacklist; mod favicon; @@ -12,6 +15,15 @@ async fn launch() -> _ { let arc = WebsiteArchive::new("./websites"); + if std::env::var("DATABASE_URL").is_ok() { + let pg = get_pg!(); + sqlx::migrate!("./migrations").run(pg).await.unwrap(); + } + + if std::env::var("OLLAMA_URL").is_ok() { + EmbedStore::generate_embeddings_for(&arc).await; + } + let archive = arc.clone(); tokio::spawn(async move { @@ -25,7 +37,8 @@ async fn launch() -> _ { pages::index, pages::render_website, pages::domain_info_route, - pages::favicon_route + pages::favicon_route, + pages::vector_search ], ) .manage(arc) diff --git a/src/pages/mod.rs b/src/pages/mod.rs index 4b1b502..9a722f5 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -1,13 +1,17 @@ use std::{io::Read, path::PathBuf}; -use based::request::{assets::DataResponse, RequestContext, StringResponse}; +use based::request::{assets::DataResponse, respond_json, RequestContext, StringResponse}; use maud::html; use rocket::{get, State}; pub mod component; use component::*; +use serde_json::json; -use crate::archive::WebsiteArchive; +use crate::{ + ai::{generate_embedding, EmbedStore}, + archive::WebsiteArchive, +}; /// Get the favicon of a domain #[get("/favicon/")] @@ -60,15 +64,7 @@ pub async fn domain_info_route( ) -> StringResponse { let domain = arc.get_domain(domain); let document = domain.path(paths.to_str().unwrap()); - let versions: Vec = document - .versions() - .into_iter() - .map(|x| { - x.trim_start_matches("index_") - .trim_end_matches(".html") - .to_string() - }) - .collect(); + let versions: Vec = document.versions(); let (path_entries, is_doc) = domain.paths(paths.to_str().unwrap()); let path_seperations: Vec<&str> = paths.to_str().unwrap().split('/').collect(); @@ -163,3 +159,28 @@ pub async fn render_website( None } + +#[get("/vector_search?")] +pub async fn vector_search(query: &str) -> Option { + if std::env::var("OLLAMA_URL").is_err() { + return None; + } + + if query.ends_with(".json") { + let query = query.trim_end_matches(".json"); + let results = EmbedStore::search_vector(&pgvector::Vector::from( + generate_embedding(query.to_string()).await?, + )) + .await; + return Some(respond_json(&json!(&results))); + } + + let results = EmbedStore::search_vector(&pgvector::Vector::from( + generate_embedding(query.to_string()).await?, + )) + .await; + + // TODO : Implement Search UI with HTMX + + None +} From 75262802238d1f656204d5f3acf8504f00b24767 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Mon, 30 Dec 2024 21:25:40 +0100 Subject: [PATCH 2/2] add vector search --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/ai.rs | 26 +++++++++-- src/archive.rs | 6 ++- src/main.rs | 1 + src/pages/component.rs | 11 ++++- src/pages/mod.rs | 97 +++++++++++++++++++++++++++++++++--------- 7 files changed, 116 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d196c4..88425d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,7 +164,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "based" version = "0.1.0" -source = "git+https://git.hydrar.de/jmarya/based#d6555edc29de66ff5190b716a1f8ebac8dbb2110" +source = "git+https://git.hydrar.de/jmarya/based#00bb6f152d758252d62a511705ef35c8aa118168" dependencies = [ "bcrypt", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 1da2ff1..1385b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ tokio = { version = "1.35.1", features = ["full"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] } maud = "0.26.0" -based = { git = "https://git.hydrar.de/jmarya/based", features = [] } +based = { git = "https://git.hydrar.de/jmarya/based", features = ["htmx"] } url = "2.5.4" reqwest = "0.12.11" ollama-rs = "0.2.2" diff --git a/src/ai.rs b/src/ai.rs index 31a094b..7f103ce 100644 --- a/src/ai.rs +++ b/src/ai.rs @@ -1,8 +1,9 @@ use std::collections::VecDeque; -use based::get_pg; +use based::{get_pg, request::api::ToAPI}; use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; use serde::Serialize; +use serde_json::json; use sqlx::FromRow; use crate::archive::{Document, Domain, WebsiteArchive}; @@ -20,6 +21,17 @@ pub struct DocEmbedding { pub similarity: f64, } +impl ToAPI for DocEmbedding { + async fn api(&self) -> serde_json::Value { + json!({ + "domain": self.domain, + "path": self.path, + "ver": self.ver, + "similarity": self.similarity + }) + } +} + pub trait Embedding { fn embedding(&self, ver: Option) -> impl std::future::Future>>; @@ -41,7 +53,7 @@ impl Embedding for Document { } } -pub async fn generate_embedding(input: String) -> Option> { +pub async fn generate_embedding(mut input: String) -> Option> { if let Ok(ollama_url) = std::env::var("OLLAMA_URL") { let (host, port) = ollama_url.split_once(':')?; let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?); @@ -59,6 +71,10 @@ pub async fn generate_embedding(input: String) -> Option> { .ok()?; } + if input.is_empty() { + input = " ".to_string(); + } + let res = ollama .generate_embeddings(GenerateEmbeddingsRequest::new( "mxbai-embed-large".to_string(), @@ -129,11 +145,13 @@ impl EmbedStore { } } - pub async fn search_vector(v: &pgvector::Vector) -> Vec { + pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec { sqlx::query_as( - "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT 5", + "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3", ) .bind(v) + .bind(limit) + .bind(offset) .fetch_all(get_pg!()) .await .unwrap() diff --git a/src/archive.rs b/src/archive.rs index 05e6290..16475e1 100644 --- a/src/archive.rs +++ b/src/archive.rs @@ -151,7 +151,11 @@ impl Document { pub fn new(domain: &str, path: &str, base_dir: PathBuf) -> Self { Self { domain: domain.to_string(), - path: path.to_string(), + path: path + .split('/') + .filter(|x| !x.is_empty()) + .collect::>() + .join("/"), base_dir, } } diff --git a/src/main.rs b/src/main.rs index 5f0db31..ba26c6f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,6 +34,7 @@ async fn launch() -> _ { .mount( "/", routes![ + based::htmx::htmx_script_route, pages::index, pages::render_website, pages::domain_info_route, diff --git a/src/pages/component.rs b/src/pages/component.rs index a7931ba..d5e2f95 100644 --- a/src/pages/component.rs +++ b/src/pages/component.rs @@ -59,10 +59,18 @@ pub fn gen_path_link( /// /// # Returns /// A `PreEscaped` containing the HTML markup for the path header. -pub fn gen_path_header(path_seperations: Vec<&str>, domain: &str) -> PreEscaped { +pub fn gen_path_header( + path_seperations: Vec<&str>, + domain: &str, + link: bool, +) -> PreEscaped { html! { @for (index, path) in path_seperations.iter().enumerate() { + @if link { (gen_path_link(path, index, &path_seperations, domain)) + } @else { + p { (path) } + } @if index < path_seperations.len()-1 { (slash_seperator()) }; @@ -79,6 +87,7 @@ pub async fn render_page(content: PreEscaped, ctx: RequestContext) -> St html! { script src="https://cdn.tailwindcss.com" {}; meta name="viewport" content="width=device-width, initial-scale=1.0" {}; + script src="/assets/htmx.min.js" {}; }, html! {}, Some("bg-zinc-950 text-white min-h-screen flex pt-8 justify-center".to_string()), diff --git a/src/pages/mod.rs b/src/pages/mod.rs index 9a722f5..562281f 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -1,7 +1,12 @@ use std::{io::Read, path::PathBuf}; -use based::request::{assets::DataResponse, respond_json, RequestContext, StringResponse}; -use maud::html; +use based::{ + page::search::Search, + request::{ + api::GeneratedPager, assets::DataResponse, respond_json, RequestContext, StringResponse, + }, +}; +use maud::{html, PreEscaped}; use rocket::{get, State}; pub mod component; @@ -9,10 +14,12 @@ use component::*; use serde_json::json; use crate::{ - ai::{generate_embedding, EmbedStore}, + ai::{generate_embedding, DocEmbedding, EmbedStore}, archive::WebsiteArchive, }; +const SEARCH_BAR_STYLE: &'static str = "w-full px-4 mb-4 py-2 text-white bg-black border-2 border-neon-blue placeholder-neon-blue focus:ring-2 focus:ring-neon-pink focus:outline-none font-mono text-lg"; + /// Get the favicon of a domain #[get("/favicon/")] pub async fn favicon_route(domain: &str) -> Option { @@ -36,9 +43,17 @@ pub async fn index(ctx: RequestContext, arc: &State) -> StringRe let content = html! { div class="container mx-auto p-4" { + + div class="mb-4" { + input type="search" name="query" placeholder="Search..." class=(SEARCH_BAR_STYLE) + hx-get=("/vector_search") + hx-target="#website_grid" hx-push-url="true" hx-swap="outerHTML" {}; + }; + + + div id="website_grid" { h1 class="text-5xl font-bold text-center mb-10" { "Websites" }; div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-5 xl:grid-cols-6 2xl:grid-cols-8 gap-6" { - @for site in websites { a href=(format!("/d/{site}")) class="bg-neutral-900 shadow-md rounded-lg hover:bg-neutral-800 bg-gray-1 hover:cursor-pointer transition-all duration-300 flex flex-col items-center justify-center aspect-square max-w-60" { div class="bg-blue-500 text-white rounded-full p-4" { @@ -48,6 +63,7 @@ pub async fn index(ctx: RequestContext, arc: &State) -> StringRe }; }; }; + }; } }; @@ -75,7 +91,7 @@ pub async fn domain_info_route( img class="p-2" src=(format!("/favicon/{}", &domain.name)) {}; a href=(format!("/d/{}", &domain.name)) { (domain.name) }; (slash_seperator()) - (gen_path_header(path_seperations, &domain.name)) + (gen_path_header(path_seperations, &domain.name, true)) }; @if !versions.is_empty() { @@ -160,27 +176,66 @@ pub async fn render_website( None } -#[get("/vector_search?")] -pub async fn vector_search(query: &str) -> Option { +pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped { + html! { + div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer" + hx-get=(format!("/d/{}/{}", x.domain, x.path)) + hx-target="#main_content" hx-push-url="true" hx-swap="innerHTML" + { + img class="p-2" src=(format!("/favicon/{}", &x.domain)); + a { (x.domain) }; + (slash_seperator()); + (gen_path_header(x.path.split('/').collect(), &x.domain, false)); + p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) }; + }; + } +} + +#[get("/vector_search?&")] +pub async fn vector_search( + query: Option<&str>, + page: Option, + ctx: RequestContext, +) -> Option { if std::env::var("OLLAMA_URL").is_err() { return None; } - if query.ends_with(".json") { - let query = query.trim_end_matches(".json"); - let results = EmbedStore::search_vector(&pgvector::Vector::from( - generate_embedding(query.to_string()).await?, - )) + let page = page.unwrap_or(1); + + // Search + let search = + Search::new("/vector_search".to_string()).search_class(SEARCH_BAR_STYLE.to_string()); + + if let Some(query) = query { + // If we have query + let real_query = query.trim_end_matches(".json"); + + // Search Results + let vector = pgvector::Vector::from(generate_embedding(real_query.to_string()).await?); + + let results = GeneratedPager::new( + |input, offset, limit| { + Box::pin(async move { + EmbedStore::search_vector(&input, limit as i64, offset as i64).await + }) + }, + 5, + ) + .pager(page as u64, vector) .await; - return Some(respond_json(&json!(&results))); + + // API Route + if query.ends_with(".json") { + return Some(respond_json(&json!(&results.page(page as u64)))); + } + + let content = search.build_response(&ctx, results, page, real_query, gen_search_element); + + return Some(render_page(content, ctx).await); } - let results = EmbedStore::search_vector(&pgvector::Vector::from( - generate_embedding(query.to_string()).await?, - )) - .await; - - // TODO : Implement Search UI with HTMX - - None + // Return new search site + let content = search.build("", html! {}); + Some(render_page(content, ctx).await) }