commit 6aea22576c (parent e50d31479c)
5 changed files with 136 additions and 25 deletions

env (3 changes)
@@ -9,7 +9,8 @@ ROUTE_INTERNAL=true
 DOWNLOAD_ON_DEMAND=true
 
 # Blacklisted domains (Comma-separated regex)
-BLACKLIST_DOMAINS="google.com,.*.youtube.com"
+# You can blacklist sites which won't work well
+BLACKLIST_DOMAINS="^gitlab"
 
 # Database
 DATABASE_URL=postgres://user:pass@postgres/webarc
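For context, this variable feeds crate::blacklist::check_blacklist (used in the archive changes below). A minimal sketch of how a comma-separated regex blacklist can be evaluated, assuming the regex crate; the function name is hypothetical and the real implementation is not part of this commit:

    use regex::Regex;

    // Sketch: treat BLACKLIST_DOMAINS as comma-separated regex patterns and
    // report whether any of them matches the given domain.
    fn is_blacklisted(domain: &str) -> bool {
        std::env::var("BLACKLIST_DOMAINS")
            .unwrap_or_default()
            .split(',')
            .filter(|p| !p.is_empty())
            .filter_map(|p| Regex::new(p).ok())
            .any(|re| re.is_match(domain))
    }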
@@ -5,6 +5,7 @@ CREATE TABLE doc_embedding (
     domain VARCHAR(500) NOT NULL,
     path VARCHAR(1000) NOT NULL,
     ver VARCHAR(10) NOT NULL,
+    chunk INTEGER NOT NULL,
     embed_mxbai_embed_large vector(1024) NOT NULL,
-    PRIMARY KEY (domain, path, ver)
+    PRIMARY KEY (domain, path, ver, chunk)
 )
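With chunk added to the primary key, a document version is now stored as one row per chunk instead of a single row. A hedged sqlx sketch of reading those rows back in order, using DocEmbedding and get_pg! from src/ai.rs below; the 0.0 AS similarity alias only exists to satisfy DocEmbedding's derived FromRow, and the bound values are invented for illustration:

    // Sketch: fetch every chunk row of one archived document version, in order.
    let chunks: Vec<DocEmbedding> = sqlx::query_as(
        "SELECT *, 0.0::float8 AS similarity FROM doc_embedding
         WHERE domain = $1 AND path = $2 AND ver = $3 ORDER BY chunk",
    )
    .bind("example.org")
    .bind("index.html")
    .bind("latest")
    .fetch_all(get_pg!())
    .await
    .unwrap();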
src/ai.rs (141 changes)
@@ -1,4 +1,4 @@
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
 
 use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass};
 use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
@@ -8,11 +8,15 @@ use sqlx::FromRow;
 
 use crate::archive::{Document, Domain, WebsiteArchive};
 
+// TODO : Chunked embeddings + better search + ranking
+// TODO : Real citese embeddings + search
+
 #[derive(Debug, Clone, FromRow, Serialize)]
 pub struct DocEmbedding {
     pub domain: String,
     pub path: String,
     pub ver: String,
+    pub chunk: i32,
 
     #[serde(skip)]
     embed_mxbai_embed_large: pgvector::Vector,
@@ -21,6 +25,53 @@ pub struct DocEmbedding {
     pub similarity: f64,
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct SearchResult {
+    pub domain: String,
+    pub path: String,
+    pub chunks: Vec<DocEmbedding>,
+}
+
+impl SearchResult {
+    pub fn new(domain: String, path: String) -> Self {
+        Self {
+            domain,
+            path,
+            chunks: vec![],
+        }
+    }
+
+    pub fn similarity(&self) -> f64 {
+        total_score(&self.chunks)
+    }
+}
+
+pub fn avg_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+
+    for e in e {
+        score += e.similarity;
+    }
+
+    score / e.len() as f64
+}
+
+pub fn max_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+
+    for e in e {
+        if e.similarity > score {
+            score = e.similarity;
+        }
+    }
+
+    score
+}
+
+pub fn total_score(e: &[DocEmbedding]) -> f64 {
+    (avg_sim(e) + max_sim(e)) / 2.0
+}
+
 impl ToAPI for DocEmbedding {
     async fn api(&self) -> serde_json::Value {
         json!({
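To make the ranking concrete: a result whose chunks scored 0.9, 0.5 and 0.4 gets avg_sim = (0.9 + 0.5 + 0.4) / 3 = 0.6, max_sim = 0.9, and therefore total_score = (0.6 + 0.9) / 2 = 0.75, so a single strong chunk lifts a document above one with uniformly mediocre chunks. Note that avg_sim divides by the slice length, so an empty chunk list yields 0.0 / 0.0 = NaN; SearchResult::similarity is only meaningful once at least one chunk has been pushed.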
@@ -33,12 +84,23 @@ impl ToAPI for DocEmbedding {
 }
 
 pub trait Embedding {
-    fn embedding(&self, ver: Option<String>)
-        -> impl std::future::Future<Output = Option<Vec<f32>>>;
+    fn embedding(
+        &self,
+        ver: Option<String>,
+    ) -> impl std::future::Future<Output = Option<Vec<Vec<f32>>>>;
+}
+
+pub fn chunked(s: &str) -> Vec<String> {
+    const CHUNK_SIZE: usize = 500;
+    s.chars()
+        .collect::<Vec<char>>()
+        .chunks(CHUNK_SIZE)
+        .map(|chunk| chunk.iter().collect())
+        .collect()
 }
 
 impl Embedding for Document {
-    async fn embedding(&self, ver: Option<String>) -> Option<Vec<f32>> {
+    async fn embedding(&self, ver: Option<String>) -> Option<Vec<Vec<f32>>> {
         let latest = "latest".to_string();
         log::info!(
             "Generating Vector embeddings for {} / {} @ {}",
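chunked splits on chars rather than bytes, so multi-byte UTF-8 text can never be cut mid-character (byte slicing could split a code point and panic on reassembly). A quick usage check, assuming chunked is in scope:

    // 1,200 characters -> three chunks of 500, 500 and 200 characters.
    let text = "ä".repeat(1_200);
    let chunks = chunked(&text);
    assert_eq!(chunks.len(), 3);
    assert_eq!(chunks[0].chars().count(), 500);
    assert_eq!(chunks[2].chars().count(), 200);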
@@ -47,9 +109,26 @@ impl Embedding for Document {
             ver.as_ref().unwrap_or(&latest)
         );
 
-        let content_html = self.render_local(ver).await?;
+        let content_html = self.render_local(ver.clone()).await?;
         let content = html2md::parse_html(&content_html);
-        generate_embedding(content).await
+
+        let mut embeddings = Vec::new();
+        let content = chunked(&content);
+        let len = content.len();
+
+        for (index, c) in content.into_iter().enumerate() {
+            log::info!(
+                "Generating Vector embeddings for {} / {} @ {} [ {} / {} ]",
+                self.domain,
+                self.path,
+                ver.as_ref().unwrap_or(&latest),
+                index + 1,
+                len
+            );
+            embeddings.push(generate_embedding(c).await?);
+        }
+
+        Some(embeddings)
     }
 }
 
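A hypothetical call site for the new trait method, given some doc: Document; note the ? inside the loop above, so one failed chunk embedding makes the whole call return None rather than a partial list:

    // Sketch: embed the latest stored version of a document.
    if let Some(chunks) = doc.embedding(None).await {
        log::info!("got {} chunk embeddings", chunks.len());
    }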
@@ -129,14 +208,17 @@ impl EmbedStore {
         .execute(get_pg!())
         .await;
 
-        sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)")
-            .bind(&doc.domain)
-            .bind(&doc.path)
-            .bind(ver)
-            .bind(embed)
-            .execute(get_pg!())
-            .await
-            .unwrap();
+        for (index, embed) in embed.iter().enumerate() {
+            sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)")
+                .bind(&doc.domain)
+                .bind(&doc.path)
+                .bind(ver)
+                .bind(index as i64)
+                .bind(embed)
+                .execute(get_pg!())
+                .await
+                .unwrap();
+        }
     }
 }
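Small type note: the chunk column is INTEGER (and i32 in DocEmbedding), while the loop binds index as i64. Since the inserts evidently succeed, Postgres coerces the wider parameter down on assignment, but binding index as i32 would keep the Rust field, the bind, and the column aligned end to end.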
@@ -148,8 +230,12 @@ impl EmbedStore {
         }
     }
 
-    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<DocEmbedding> {
-        sqlx::query_as(
+    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<SearchResult> {
+        // TODO : fix search
+        // + new ranked algorithm
+        // + better repr
+
+        let results: Vec<DocEmbedding> = sqlx::query_as(
             "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
         )
         .bind(v)
@@ -157,7 +243,28 @@ impl EmbedStore {
         .bind(offset)
         .fetch_all(get_pg!())
         .await
-        .unwrap()
+        .unwrap();
+
+        let mut search_res: HashMap<String, HashMap<String, SearchResult>> = HashMap::new();
+
+        for res in results {
+            let domain = search_res
+                .entry(res.domain.clone())
+                .or_insert(HashMap::new());
+            let doc = domain
+                .entry(res.path.clone())
+                .or_insert(SearchResult::new(res.domain.clone(), res.path.clone()));
+            doc.chunks.push(res);
+        }
+
+        let mut flat = search_res
+            .into_values()
+            .map(|x| x.into_values().collect::<Vec<SearchResult>>())
+            .flatten()
+            .collect::<Vec<SearchResult>>();
+
+        flat.sort_by(|a, b| b.chunks.len().cmp(&a.chunks.len()));
+        flat
     }
 
     pub async fn generate_embeddings_for(arc: &WebsiteArchive) {
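The grouping pass buckets chunk hits by domain and then by path. A slightly more idiomatic equivalent (not in the commit) initializes the fallback values lazily instead of allocating them on every lookup:

    // Equivalent grouping using lazy entry initialization.
    let doc = search_res
        .entry(res.domain.clone())
        .or_default()
        .entry(res.path.clone())
        .or_insert_with(|| SearchResult::new(res.domain.clone(), res.path.clone()));
    doc.chunks.push(res);

Also worth noting: the final ordering sorts by the number of matching chunks, while the UI below renders SearchResult::similarity(); the two rankings can disagree.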
@@ -1,6 +1,6 @@
 use std::{io::Read, path::PathBuf};
 
-use based::request::RequestContext;
+use based::{request::RequestContext, result::LogAndIgnore};
 use maud::html;
 
 use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page};
@@ -62,7 +62,8 @@ impl Domain {
     /// A new `Domain` instance.
     pub fn new(name: &str, dir: PathBuf) -> Self {
         if !check_blacklist(name) {
-            std::fs::create_dir_all(&dir).unwrap();
+            std::fs::create_dir_all(&dir)
+                .log_err_and_ignore(&format!("Could not create domain dir {name}"));
         }
         Self {
             name: name.to_string(),
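log_err_and_ignore comes from the author's based crate, and its exact signature is not shown in this diff. A plausible shape for such a helper, offered purely as an assumption for readers without that crate:

    // Assumed sketch of the helper's shape; not the actual `based` source.
    pub trait LogAndIgnore {
        fn log_err_and_ignore(self, msg: &str);
    }

    impl<T, E: std::fmt::Display> LogAndIgnore for Result<T, E> {
        fn log_err_and_ignore(self, msg: &str) {
            if let Err(e) = self {
                log::error!("{msg}: {e}");
            }
        }
    }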
@@ -14,7 +14,7 @@ use component::*;
 use serde_json::json;
 
 use crate::{
-    ai::{generate_embedding, DocEmbedding, EmbedStore},
+    ai::{generate_embedding, EmbedStore, SearchResult},
     archive::WebsiteArchive,
 };
 
@@ -176,7 +176,7 @@ pub async fn render_website(
     None
 }
 
-pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
+pub fn gen_search_element(x: &SearchResult) -> PreEscaped<String> {
     html! {
         div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer"
             hx-get=(format!("/d/{}/{}", x.domain, x.path))
@@ -186,7 +186,7 @@ pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
         a { (x.domain) };
         (slash_seperator());
         (gen_path_header(x.path.split('/').collect(), &x.domain, false));
-        p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) };
+        p class="font-bold p-2 text-stone-400" { (format!("{:.2} % [{} matches]", x.similarity() * 100.0, x.chunks.len())) };
     };
 }
@@ -220,13 +220,14 @@ pub async fn vector_search(
                 EmbedStore::search_vector(&input, limit as i64, offset as i64).await
             })
         },
-        5,
+        50,
     )
    .pager(page as u64, vector)
    .await;
 
     // API Route
     if query.ends_with(".json") {
+        // TODO : Better search API
         return Some(respond_json(&json!(&results.page(page as u64))));
     }
 