From 75262802238d1f656204d5f3acf8504f00b24767 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Mon, 30 Dec 2024 21:25:40 +0100 Subject: [PATCH] add vector search --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/ai.rs | 26 +++++++++-- src/archive.rs | 6 ++- src/main.rs | 1 + src/pages/component.rs | 11 ++++- src/pages/mod.rs | 97 +++++++++++++++++++++++++++++++++--------- 7 files changed, 116 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d196c4..88425d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,7 +164,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "based" version = "0.1.0" -source = "git+https://git.hydrar.de/jmarya/based#d6555edc29de66ff5190b716a1f8ebac8dbb2110" +source = "git+https://git.hydrar.de/jmarya/based#00bb6f152d758252d62a511705ef35c8aa118168" dependencies = [ "bcrypt", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 1da2ff1..1385b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ tokio = { version = "1.35.1", features = ["full"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] } maud = "0.26.0" -based = { git = "https://git.hydrar.de/jmarya/based", features = [] } +based = { git = "https://git.hydrar.de/jmarya/based", features = ["htmx"] } url = "2.5.4" reqwest = "0.12.11" ollama-rs = "0.2.2" diff --git a/src/ai.rs b/src/ai.rs index 31a094b..7f103ce 100644 --- a/src/ai.rs +++ b/src/ai.rs @@ -1,8 +1,9 @@ use std::collections::VecDeque; -use based::get_pg; +use based::{get_pg, request::api::ToAPI}; use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; use serde::Serialize; +use serde_json::json; use sqlx::FromRow; use crate::archive::{Document, Domain, WebsiteArchive}; @@ -20,6 +21,17 @@ pub struct DocEmbedding { pub similarity: f64, } +impl ToAPI for DocEmbedding { + async fn api(&self) -> serde_json::Value { + json!({ + "domain": self.domain, + "path": self.path, + "ver": self.ver, + "similarity": self.similarity + }) + } +} + pub trait Embedding { fn embedding(&self, ver: Option) -> impl std::future::Future>>; @@ -41,7 +53,7 @@ impl Embedding for Document { } } -pub async fn generate_embedding(input: String) -> Option> { +pub async fn generate_embedding(mut input: String) -> Option> { if let Ok(ollama_url) = std::env::var("OLLAMA_URL") { let (host, port) = ollama_url.split_once(':')?; let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?); @@ -59,6 +71,10 @@ pub async fn generate_embedding(input: String) -> Option> { .ok()?; } + if input.is_empty() { + input = " ".to_string(); + } + let res = ollama .generate_embeddings(GenerateEmbeddingsRequest::new( "mxbai-embed-large".to_string(), @@ -129,11 +145,13 @@ impl EmbedStore { } } - pub async fn search_vector(v: &pgvector::Vector) -> Vec { + pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec { sqlx::query_as( - "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT 5", + "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3", ) .bind(v) + .bind(limit) + .bind(offset) .fetch_all(get_pg!()) .await .unwrap() diff --git a/src/archive.rs b/src/archive.rs index 05e6290..16475e1 100644 --- a/src/archive.rs +++ b/src/archive.rs @@ -151,7 +151,11 @@ impl Document { pub fn new(domain: &str, path: &str, base_dir: PathBuf) -> Self { Self { domain: domain.to_string(), - path: path.to_string(), + path: path + .split('/') + .filter(|x| !x.is_empty()) + .collect::>() + .join("/"), base_dir, } } diff --git a/src/main.rs b/src/main.rs index 5f0db31..ba26c6f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,6 +34,7 @@ async fn launch() -> _ { .mount( "/", routes![ + based::htmx::htmx_script_route, pages::index, pages::render_website, pages::domain_info_route, diff --git a/src/pages/component.rs b/src/pages/component.rs index a7931ba..d5e2f95 100644 --- a/src/pages/component.rs +++ b/src/pages/component.rs @@ -59,10 +59,18 @@ pub fn gen_path_link( /// /// # Returns /// A `PreEscaped` containing the HTML markup for the path header. -pub fn gen_path_header(path_seperations: Vec<&str>, domain: &str) -> PreEscaped { +pub fn gen_path_header( + path_seperations: Vec<&str>, + domain: &str, + link: bool, +) -> PreEscaped { html! { @for (index, path) in path_seperations.iter().enumerate() { + @if link { (gen_path_link(path, index, &path_seperations, domain)) + } @else { + p { (path) } + } @if index < path_seperations.len()-1 { (slash_seperator()) }; @@ -79,6 +87,7 @@ pub async fn render_page(content: PreEscaped, ctx: RequestContext) -> St html! { script src="https://cdn.tailwindcss.com" {}; meta name="viewport" content="width=device-width, initial-scale=1.0" {}; + script src="/assets/htmx.min.js" {}; }, html! {}, Some("bg-zinc-950 text-white min-h-screen flex pt-8 justify-center".to_string()), diff --git a/src/pages/mod.rs b/src/pages/mod.rs index 9a722f5..562281f 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -1,7 +1,12 @@ use std::{io::Read, path::PathBuf}; -use based::request::{assets::DataResponse, respond_json, RequestContext, StringResponse}; -use maud::html; +use based::{ + page::search::Search, + request::{ + api::GeneratedPager, assets::DataResponse, respond_json, RequestContext, StringResponse, + }, +}; +use maud::{html, PreEscaped}; use rocket::{get, State}; pub mod component; @@ -9,10 +14,12 @@ use component::*; use serde_json::json; use crate::{ - ai::{generate_embedding, EmbedStore}, + ai::{generate_embedding, DocEmbedding, EmbedStore}, archive::WebsiteArchive, }; +const SEARCH_BAR_STYLE: &'static str = "w-full px-4 mb-4 py-2 text-white bg-black border-2 border-neon-blue placeholder-neon-blue focus:ring-2 focus:ring-neon-pink focus:outline-none font-mono text-lg"; + /// Get the favicon of a domain #[get("/favicon/")] pub async fn favicon_route(domain: &str) -> Option { @@ -36,9 +43,17 @@ pub async fn index(ctx: RequestContext, arc: &State) -> StringRe let content = html! { div class="container mx-auto p-4" { + + div class="mb-4" { + input type="search" name="query" placeholder="Search..." class=(SEARCH_BAR_STYLE) + hx-get=("/vector_search") + hx-target="#website_grid" hx-push-url="true" hx-swap="outerHTML" {}; + }; + + + div id="website_grid" { h1 class="text-5xl font-bold text-center mb-10" { "Websites" }; div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-5 xl:grid-cols-6 2xl:grid-cols-8 gap-6" { - @for site in websites { a href=(format!("/d/{site}")) class="bg-neutral-900 shadow-md rounded-lg hover:bg-neutral-800 bg-gray-1 hover:cursor-pointer transition-all duration-300 flex flex-col items-center justify-center aspect-square max-w-60" { div class="bg-blue-500 text-white rounded-full p-4" { @@ -48,6 +63,7 @@ pub async fn index(ctx: RequestContext, arc: &State) -> StringRe }; }; }; + }; } }; @@ -75,7 +91,7 @@ pub async fn domain_info_route( img class="p-2" src=(format!("/favicon/{}", &domain.name)) {}; a href=(format!("/d/{}", &domain.name)) { (domain.name) }; (slash_seperator()) - (gen_path_header(path_seperations, &domain.name)) + (gen_path_header(path_seperations, &domain.name, true)) }; @if !versions.is_empty() { @@ -160,27 +176,66 @@ pub async fn render_website( None } -#[get("/vector_search?")] -pub async fn vector_search(query: &str) -> Option { +pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped { + html! { + div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer" + hx-get=(format!("/d/{}/{}", x.domain, x.path)) + hx-target="#main_content" hx-push-url="true" hx-swap="innerHTML" + { + img class="p-2" src=(format!("/favicon/{}", &x.domain)); + a { (x.domain) }; + (slash_seperator()); + (gen_path_header(x.path.split('/').collect(), &x.domain, false)); + p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) }; + }; + } +} + +#[get("/vector_search?&")] +pub async fn vector_search( + query: Option<&str>, + page: Option, + ctx: RequestContext, +) -> Option { if std::env::var("OLLAMA_URL").is_err() { return None; } - if query.ends_with(".json") { - let query = query.trim_end_matches(".json"); - let results = EmbedStore::search_vector(&pgvector::Vector::from( - generate_embedding(query.to_string()).await?, - )) + let page = page.unwrap_or(1); + + // Search + let search = + Search::new("/vector_search".to_string()).search_class(SEARCH_BAR_STYLE.to_string()); + + if let Some(query) = query { + // If we have query + let real_query = query.trim_end_matches(".json"); + + // Search Results + let vector = pgvector::Vector::from(generate_embedding(real_query.to_string()).await?); + + let results = GeneratedPager::new( + |input, offset, limit| { + Box::pin(async move { + EmbedStore::search_vector(&input, limit as i64, offset as i64).await + }) + }, + 5, + ) + .pager(page as u64, vector) .await; - return Some(respond_json(&json!(&results))); + + // API Route + if query.ends_with(".json") { + return Some(respond_json(&json!(&results.page(page as u64)))); + } + + let content = search.build_response(&ctx, results, page, real_query, gen_search_element); + + return Some(render_page(content, ctx).await); } - let results = EmbedStore::search_vector(&pgvector::Vector::from( - generate_embedding(query.to_string()).await?, - )) - .await; - - // TODO : Implement Search UI with HTMX - - None + // Return new search site + let content = search.build("", html! {}); + Some(render_page(content, ctx).await) }