From 609d52098632352239a7ef7c2e7d9ed45c32b874 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Mon, 30 Dec 2024 14:06:32 +0100 Subject: [PATCH] add vector db --- Cargo.lock | 304 +++++++++++++++++++++++++++++++--- Cargo.toml | 3 + docker-compose.yml | 13 ++ env | 6 + migrations/0001_embedding.sql | 10 ++ src/ai.rs | 179 ++++++++++++++++++++ src/archive.rs | 12 +- src/main.rs | 15 +- src/pages/mod.rs | 43 +++-- 9 files changed, 547 insertions(+), 38 deletions(-) create mode 100644 migrations/0001_embedding.sql create mode 100644 src/ai.rs diff --git a/Cargo.lock b/Cargo.lock index 10e9ca6..7d196c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,7 +78,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -89,7 +89,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -276,6 +276,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.0" @@ -307,6 +313,16 @@ dependencies = [ "inout", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -487,7 +503,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -510,7 +526,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -649,6 +665,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures" version = "0.3.31" @@ -716,7 +742,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -909,6 +935,34 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "html2md" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be92446e11d68f5d71367d571c229d09ced1f24ab6d08ea0bff329d5f6c0b2a3" +dependencies = [ + "html5ever", + "jni", + "lazy_static", + "markup5ever_rcdom", + "percent-encoding", + "regex", +] + +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "http" version = "0.2.12" @@ -1230,7 +1284,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1303,6 +1357,26 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +[[package]] +name = "jni" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +dependencies = [ + "cesu8", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" version = "0.3.76" @@ -1388,6 +1462,38 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + +[[package]] +name = "markup5ever_rcdom" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9521dd6750f8e80ee6c53d65e2e4656d7de37064f3a7a5d2d11d05df93839c2" +dependencies = [ + "html5ever", + "markup5ever", + "tendril", + "xml5ever", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1416,7 +1522,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1503,6 +1609,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "nom" version = "7.1.3" @@ -1595,6 +1707,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763afb01db2dced00e656cc2cdcd875659fc3fac4c449e6337a4f04f9e3d9efc" +dependencies = [ + "async-stream", + "async-trait", + "log", + "reqwest 0.12.11", + "serde", + "serde_json", + "url", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1624,7 +1751,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1706,7 +1833,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -1724,6 +1851,53 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pgvector" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6" +dependencies = [ + "sqlx", +] + +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1778,6 +1952,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1818,7 +1998,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "version_check", "yansi", ] @@ -1908,7 +2088,7 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2119,7 +2299,7 @@ dependencies = [ "proc-macro2", "quote", "rocket_http", - "syn", + "syn 2.0.93", "unicode-xid", "version_check", ] @@ -2337,7 +2517,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2429,6 +2609,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -2557,7 +2743,7 @@ dependencies = [ "quote", "sqlx-core", "sqlx-macros-core", - "syn", + "syn 2.0.93", ] [[package]] @@ -2580,7 +2766,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", - "syn", + "syn 2.0.93", "tempfile", "tokio", "url", @@ -2719,6 +2905,32 @@ dependencies = [ "loom", ] +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", +] + [[package]] name = "stringprep" version = "0.1.5" @@ -2736,6 +2948,17 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.93" @@ -2770,7 +2993,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2828,6 +3051,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + [[package]] name = "termcolor" version = "1.4.1" @@ -2854,7 +3088,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -2949,7 +3183,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3077,7 +3311,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3228,6 +3462,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" @@ -3320,7 +3560,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.93", "wasm-bindgen-shared", ] @@ -3355,7 +3595,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3384,8 +3624,11 @@ dependencies = [ "chrono", "env_logger", "futures", + "html2md", "log", "maud", + "ollama-rs", + "pgvector", "regex", "reqwest 0.12.11", "rocket", @@ -3665,6 +3908,17 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "xml5ever" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4034e1d05af98b51ad7214527730626f019682d797ba38b51689212118d8e650" +dependencies = [ + "log", + "mac", + "markup5ever", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3694,7 +3948,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "synstructure", ] @@ -3716,7 +3970,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] [[package]] @@ -3736,7 +3990,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", "synstructure", ] @@ -3765,5 +4019,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.93", ] diff --git a/Cargo.toml b/Cargo.toml index cc5a0f8..1da2ff1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,6 @@ maud = "0.26.0" based = { git = "https://git.hydrar.de/jmarya/based", features = [] } url = "2.5.4" reqwest = "0.12.11" +ollama-rs = "0.2.2" +pgvector = { version = "0.4", features = ["sqlx"] } +html2md = "0.2.14" diff --git a/docker-compose.yml b/docker-compose.yml index ef00b97..00690e8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,3 +7,16 @@ services: - ./websites:/websites - ./favicon:/favicon env_file: env + + postgres: + # Any Postgres with support for pgvector + image: git.hydrar.de/hydra/postgres:latest + restart: always + ports: + - 5432:5432 + volumes: + - ./db:/var/lib/postgresql/data/ + environment: + - POSTGRES_USER=user + - POSTGRES_PASSWORD=pass + - POSTGRES_DB=webarc diff --git a/env b/env index 61f6371..c2354bd 100644 --- a/env +++ b/env @@ -10,3 +10,9 @@ DOWNLOAD_ON_DEMAND=true # Blacklisted domains (Comma-seperated regex) BLACKLIST_DOMAINS="google.com,.*.youtube.com" + +# Database +DATABASE_URL=postgres://user:pass@postgres/webarc + +# Ollama URL (Enables vector search) +OLLAMA_URL=127.0.0.1:11434 diff --git a/migrations/0001_embedding.sql b/migrations/0001_embedding.sql new file mode 100644 index 0000000..1826b2b --- /dev/null +++ b/migrations/0001_embedding.sql @@ -0,0 +1,10 @@ + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE doc_embedding ( + domain VARCHAR(500) NOT NULL, + path VARCHAR(1000) NOT NULL, + ver VARCHAR(10) NOT NULL, + embed_mxbai_embed_large vector(1024) NOT NULL, + PRIMARY KEY (domain, path, ver) +) diff --git a/src/ai.rs b/src/ai.rs new file mode 100644 index 0000000..31a094b --- /dev/null +++ b/src/ai.rs @@ -0,0 +1,179 @@ +use std::collections::VecDeque; + +use based::get_pg; +use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; +use serde::Serialize; +use sqlx::FromRow; + +use crate::archive::{Document, Domain, WebsiteArchive}; + +#[derive(Debug, Clone, FromRow, Serialize)] +pub struct DocEmbedding { + pub domain: String, + pub path: String, + pub ver: String, + + #[serde(skip)] + embed_mxbai_embed_large: pgvector::Vector, + + #[sqlx(default)] + pub similarity: f64, +} + +pub trait Embedding { + fn embedding(&self, ver: Option) + -> impl std::future::Future>>; +} + +impl Embedding for Document { + async fn embedding(&self, ver: Option) -> Option> { + let latest = "latest".to_string(); + log::info!( + "Generating Vector embeddings for {} / {} @ {}", + self.domain, + self.path, + ver.as_ref().unwrap_or(&latest) + ); + + let content_html = self.render_local(ver).await.unwrap(); + let content = html2md::parse_html(&content_html); + generate_embedding(content).await + } +} + +pub async fn generate_embedding(input: String) -> Option> { + if let Ok(ollama_url) = std::env::var("OLLAMA_URL") { + let (host, port) = ollama_url.split_once(':')?; + let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?); + + let models = ollama.list_local_models().await.ok()?; + + if !models + .into_iter() + .any(|x| x.name.starts_with("mxbai-embed-large")) + { + log::info!("Model not found. Pulling 'mxbai-embed-large'"); + ollama + .pull_model("mxbai-embed-large".to_string(), false) + .await + .ok()?; + } + + let res = ollama + .generate_embeddings(GenerateEmbeddingsRequest::new( + "mxbai-embed-large".to_string(), + EmbeddingsInput::Single(input), + )) + .await + .ok()?; + let embed = res.embeddings.first()?; + return Some(embed.clone()); + } + + None +} + +pub struct EmbedStore; + +impl EmbedStore { + pub async fn get_embedding(doc: &Document, ver: Option<&str>) -> Option { + let use_ver = ver.map_or_else( + || { + let version = doc.versions(); + version.first().unwrap().clone() + }, + |x| x.to_string(), + ); + sqlx::query_as("SELECT * FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3") + .bind(&doc.domain) + .bind(&doc.path) + .bind(use_ver) + .fetch_optional(get_pg!()) + .await + .unwrap() + } + + pub async fn embed_document(doc: &Document, ver: &str) { + if let Some(embed) = doc.embedding(Some(ver.to_string())).await { + let _ = sqlx::query( + "DELETE FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3", + ) + .bind(&doc.domain) + .bind(&doc.path) + .bind(ver) + .execute(get_pg!()) + .await; + + sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)") + .bind(&doc.domain) + .bind(&doc.path) + .bind(ver) + .bind(embed) + .execute(get_pg!()) + .await + .unwrap(); + } else { + log::warn!( + "No embeds could be generated for {} / {}", + doc.domain, + doc.path + ); + } + } + + pub async fn ensure_embedding(doc: &Document) { + for ver in doc.versions() { + if Self::get_embedding(doc, Some(ver.as_str())).await.is_none() { + Self::embed_document(doc, &ver).await; + } + } + } + + pub async fn search_vector(v: &pgvector::Vector) -> Vec { + sqlx::query_as( + "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT 5", + ) + .bind(v) + .fetch_all(get_pg!()) + .await + .unwrap() + } + + pub async fn generate_embeddings_for(arc: &WebsiteArchive) { + log::info!("Generating embeddings"); + + for dom in arc.domains() { + let dom = arc.get_domain(&dom); + embed_path(&dom, "/").await; + } + + log::info!("Done generating embeddings"); + } +} + +pub async fn embed_path(dom: &Domain, path: &str) { + let (paths, is_doc) = dom.paths(path); + + // If the path is a document, process the root path. + if is_doc { + let doc = dom.path("/"); + EmbedStore::ensure_embedding(&doc).await; + } + + // Create a queue to process paths iteratively + let mut queue = VecDeque::new(); + + // Add the initial paths to the queue + queue.extend(paths); + + while let Some(next_path) = queue.pop_front() { + let (next_paths, is_doc) = dom.paths(next_path.path()); + + if is_doc { + let doc = dom.path(next_path.path()); + EmbedStore::ensure_embedding(&doc).await; + } + + queue.extend(next_paths); + } +} diff --git a/src/archive.rs b/src/archive.rs index 7fc566c..05e6290 100644 --- a/src/archive.rs +++ b/src/archive.rs @@ -214,7 +214,17 @@ impl Document { pub fn versions(&self) -> Vec { let mut res: Vec = read_dir(&self.doc_dir()) .into_iter() - .filter(|x| x.starts_with("index_") && x.ends_with(".html")) + .filter_map(|x| { + if x.starts_with("index_") && x.ends_with(".html") { + return Some( + x.trim_start_matches("index_") + .trim_end_matches(".html") + .to_string(), + ); + } + + None + }) .collect(); res.sort(); res.reverse(); diff --git a/src/main.rs b/src/main.rs index df8ddb4..5f0db31 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,9 @@ +use ai::EmbedStore; use archive::WebsiteArchive; +use based::get_pg; use rocket::routes; +mod ai; mod archive; mod blacklist; mod favicon; @@ -12,6 +15,15 @@ async fn launch() -> _ { let arc = WebsiteArchive::new("./websites"); + if std::env::var("DATABASE_URL").is_ok() { + let pg = get_pg!(); + sqlx::migrate!("./migrations").run(pg).await.unwrap(); + } + + if std::env::var("OLLAMA_URL").is_ok() { + EmbedStore::generate_embeddings_for(&arc).await; + } + let archive = arc.clone(); tokio::spawn(async move { @@ -25,7 +37,8 @@ async fn launch() -> _ { pages::index, pages::render_website, pages::domain_info_route, - pages::favicon_route + pages::favicon_route, + pages::vector_search ], ) .manage(arc) diff --git a/src/pages/mod.rs b/src/pages/mod.rs index 4b1b502..9a722f5 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -1,13 +1,17 @@ use std::{io::Read, path::PathBuf}; -use based::request::{assets::DataResponse, RequestContext, StringResponse}; +use based::request::{assets::DataResponse, respond_json, RequestContext, StringResponse}; use maud::html; use rocket::{get, State}; pub mod component; use component::*; +use serde_json::json; -use crate::archive::WebsiteArchive; +use crate::{ + ai::{generate_embedding, EmbedStore}, + archive::WebsiteArchive, +}; /// Get the favicon of a domain #[get("/favicon/")] @@ -60,15 +64,7 @@ pub async fn domain_info_route( ) -> StringResponse { let domain = arc.get_domain(domain); let document = domain.path(paths.to_str().unwrap()); - let versions: Vec = document - .versions() - .into_iter() - .map(|x| { - x.trim_start_matches("index_") - .trim_end_matches(".html") - .to_string() - }) - .collect(); + let versions: Vec = document.versions(); let (path_entries, is_doc) = domain.paths(paths.to_str().unwrap()); let path_seperations: Vec<&str> = paths.to_str().unwrap().split('/').collect(); @@ -163,3 +159,28 @@ pub async fn render_website( None } + +#[get("/vector_search?")] +pub async fn vector_search(query: &str) -> Option { + if std::env::var("OLLAMA_URL").is_err() { + return None; + } + + if query.ends_with(".json") { + let query = query.trim_end_matches(".json"); + let results = EmbedStore::search_vector(&pgvector::Vector::from( + generate_embedding(query.to_string()).await?, + )) + .await; + return Some(respond_json(&json!(&results))); + } + + let results = EmbedStore::search_vector(&pgvector::Vector::from( + generate_embedding(query.to_string()).await?, + )) + .await; + + // TODO : Implement Search UI with HTMX + + None +}