use std::collections::VecDeque;

use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass};
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
use serde::Serialize;
use serde_json::json;
use sqlx::FromRow;

use crate::archive::{Document, Domain, WebsiteArchive};

#[derive(Debug, Clone, FromRow, Serialize)]
pub struct DocEmbedding {
    pub domain: String,
    pub path: String,
    pub ver: String,

    #[serde(skip)]
    embed_mxbai_embed_large: pgvector::Vector,

    #[sqlx(default)]
    pub similarity: f64,
}

impl ToAPI for DocEmbedding {
    async fn api(&self) -> serde_json::Value {
        json!({
            "domain": self.domain,
            "path": self.path,
            "ver": self.ver,
            "similarity": self.similarity
        })
    }
}

pub trait Embedding {
    fn embedding(
        &self,
        ver: Option<String>,
    ) -> impl std::future::Future<Output = Option<Vec<f32>>>;
}

impl Embedding for Document {
    async fn embedding(&self, ver: Option<String>) -> Option<Vec<f32>> {
        let latest = "latest".to_string();
        log::info!(
            "Generating vector embeddings for {} / {} @ {}",
            self.domain,
            self.path,
            ver.as_ref().unwrap_or(&latest)
        );

        // Render the document, convert it to Markdown and embed the result.
        let content_html = self.render_local(ver).await?;
        let content = html2md::parse_html(&content_html);
        generate_embedding(content).await
    }
}

pub async fn generate_embedding(mut input: String) -> Option<Vec<f32>> {
    if let Ok(ollama_url) = std::env::var("OLLAMA_URL") {
        let (host, port) = ollama_url.split_once(':')?;
        let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?);

        // Ensure the embedding model is available locally, pulling it if necessary.
        let models = ollama.list_local_models().await.ok()?;
        if !models
            .into_iter()
            .any(|x| x.name.starts_with("mxbai-embed-large"))
        {
            log::info!("Model not found. Pulling 'mxbai-embed-large'");
            ollama
                .pull_model("mxbai-embed-large".to_string(), false)
                .await
                .ok()?;
        }

        // Ollama rejects empty input, so embed a single space instead.
        if input.is_empty() {
            input = " ".to_string();
        }

        let res = ollama
            .generate_embeddings(GenerateEmbeddingsRequest::new(
                "mxbai-embed-large".to_string(),
                EmbeddingsInput::Single(input),
            ))
            .await
            .ok()?;

        let embed = res.embeddings.first()?;
        return Some(embed.clone());
    }

    None
}

pub struct EmbedStore;

impl EmbedStore {
    pub async fn get_embedding(doc: &Document, ver: Option<&str>) -> Option<DocEmbedding> {
        // Default to the most recent version if none was given.
        let use_ver = ver.map_or_else(
            || {
                let versions = doc.versions();
                versions.first().unwrap().clone()
            },
            |x| x.to_string(),
        );

        sqlx::query_as("SELECT * FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3")
            .bind(&doc.domain)
            .bind(&doc.path)
            .bind(use_ver)
            .fetch_optional(get_pg!())
            .await
            .unwrap()
    }

    pub async fn embed_document(doc: &Document, ver: &str) {
        if let Some(embed) = doc
            .embedding(Some(ver.to_string()))
            .await
            .log_warn_none_and_pass(|| {
                format!(
                    "No embeds could be generated for {} / {}",
                    doc.domain, doc.path
                )
            })
        {
            // Replace any existing embedding for this document version.
            let _ = sqlx::query(
                "DELETE FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3",
            )
            .bind(&doc.domain)
            .bind(&doc.path)
            .bind(ver)
            .execute(get_pg!())
            .await;

            sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)")
                .bind(&doc.domain)
                .bind(&doc.path)
                .bind(ver)
                .bind(embed)
                .execute(get_pg!())
                .await
                .unwrap();
        }
    }

    pub async fn ensure_embedding(doc: &Document) {
        for ver in doc.versions() {
            if Self::get_embedding(doc, Some(ver.as_str())).await.is_none() {
                Self::embed_document(doc, &ver).await;
            }
        }
    }

    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<DocEmbedding> {
        // Order by L2 distance and map it to a similarity score in (0, 1].
        sqlx::query_as(
            "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
        )
        .bind(v)
        .bind(limit)
        .bind(offset)
        .fetch_all(get_pg!())
        .await
        .unwrap()
    }

    pub async fn generate_embeddings_for(arc: &WebsiteArchive) {
log::info!("Generating embeddings"); for dom in arc.domains() { let dom = arc.get_domain(&dom); embed_path(&dom, "/").await; } log::info!("Done generating embeddings"); } } pub async fn embed_path(dom: &Domain, path: &str) { let (paths, is_doc) = dom.paths(path); // If the path is a document, process the root path. if is_doc { let doc = dom.path("/"); EmbedStore::ensure_embedding(&doc).await; } // Create a queue to process paths iteratively let mut queue = VecDeque::new(); // Add the initial paths to the queue queue.extend(paths); while let Some(next_path) = queue.pop_front() { let (next_paths, is_doc) = dom.paths(next_path.path()); if is_doc { let doc = dom.path(next_path.path()); EmbedStore::ensure_embedding(&doc).await; } queue.extend(next_paths); } }