Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 121 additions & 4 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,51 @@ pub struct Collection {
}

impl Collection {
pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec<SimilarityResult> {
/// Returns the identifier of every embedding held by this collection.
///
/// Identifiers are produced in the order the embeddings are stored.
pub fn list(&self) -> Vec<String> {
    let mut ids = Vec::with_capacity(self.embeddings.len());
    for embedding in self.embeddings.iter() {
        ids.push(embedding.id.to_owned());
    }
    ids
}

/// Looks up an embedding by identifier.
///
/// Returns a reference to the first stored embedding whose `id` equals
/// `id`, or `None` when no embedding carries that identifier.
pub fn get(&self, id: &str) -> Option<&Embedding> {
    for candidate in self.embeddings.iter() {
        if candidate.id == id {
            return Some(candidate);
        }
    }
    None
}

/// Returns clones of at most `k` embeddings accepted by `filter`.
///
/// Embeddings are visited in storage order and the scan stops as soon as
/// `k` matches have been gathered. Matching is delegated to
/// `match_embedding`.
pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
    self.embeddings
        .iter()
        .filter(|candidate| match_embedding(candidate, filter))
        .take(k)
        // Clone only the survivors; rejected embeddings are never copied.
        .cloned()
        .collect()
}

pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap<String, String>], query: &[f32], k: usize) -> Vec<SimilarityResult> {
let memo_attr = get_cache_attr(self.distance, query);
let distance_fn = get_distance_fn(self.distance);

let scores = self
.embeddings
.par_iter()
.enumerate()
.map(|(index, embedding)| {
let score = distance_fn(&embedding.vector, query, memo_attr);
ScoreIndex { score, index }
.filter_map(|(index, embedding)| {
if match_embedding(embedding, filter) {
let score = distance_fn(&embedding.vector, query, memo_attr);
Some(ScoreIndex { score, index })
} else {
None
}
})
.collect::<Vec<_>>();

Expand All @@ -88,6 +122,77 @@ impl Collection {
})
.collect()
}

/// Removes the first embedding whose identifier equals `id`.
///
/// Returns `true` when an embedding was removed and `false` when no
/// embedding with that identifier exists.
pub fn delete(&mut self, id: &str) -> bool {
    let position = self
        .embeddings
        .iter()
        .position(|candidate| candidate.id == id);

    if let Some(position) = position {
        self.embeddings.remove(position);
        true
    } else {
        false
    }
}

/// Removes every embedding accepted by `filter`.
///
/// An empty `filter` matches everything, so the collection is emptied.
/// Otherwise each embedding is tested with `match_embedding` and the
/// matching ones are dropped; the relative order of the remaining
/// embeddings is preserved.
pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
    if filter.is_empty() {
        self.embeddings.clear();
        return;
    }

    // Remove in a single pass. Collecting the matching indexes first and
    // calling `remove(index)` on each in ascending order is incorrect: every
    // removal shifts the following elements left, so later indexes point at
    // the wrong embeddings and may even be out of bounds.
    self.embeddings
        .retain(|embedding| !match_embedding(embedding, filter));
}
}

/// Tells whether `embedding` is accepted by the metadata `filter`.
///
/// * An empty `filter` accepts every embedding.
/// * An embedding without metadata is rejected by any non-empty `filter`.
/// * Otherwise the criteria in `filter` are combined with OR semantics, and
///   the entries inside one criteria map with AND semantics: the embedding
///   is accepted as soon as one map has all of its key/value pairs present
///   in the embedding's metadata.
fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
    // No constraint at all: everything passes.
    if filter.is_empty() {
        return true;
    }

    match &embedding.metadata {
        // Nothing to compare against, so a non-empty filter cannot match.
        None => false,
        Some(metadata) => filter.iter().any(|criteria| {
            // Every entry of this criteria map must be found with the same value.
            criteria.iter().all(|(key, expected)| {
                metadata
                    .get(key)
                    .map_or(false, |actual| actual == expected)
            })
        }),
    }
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
Expand Down Expand Up @@ -171,6 +276,18 @@ impl Db {
self.collections.get(name)
}

/// Returns a mutable reference to the collection registered under `name`,
/// or `None` when no such collection exists.
pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
    let collections = &mut self.collections;
    collections.get_mut(name)
}

/// Returns the names of all collections currently held by the database.
///
/// Names come out in the iteration order of the underlying `collections`
/// map.
pub fn list(&self) -> Vec<String> {
    self.collections.keys().cloned().collect()
}

fn load_from_store() -> anyhow::Result<Self> {
if !STORE_PATH.exists() {
tracing::debug!("Creating database store");
Expand Down
140 changes: 137 additions & 3 deletions src/routes/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use aide::axum::{
use axum::{extract::Path, http::StatusCode, Extension};
use axum_jsonschema::Json;
use schemars::JsonSchema;
use std::time::Instant;
use std::{
collections::HashMap,
time::Instant,
};

use crate::{
db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult},
Expand All @@ -17,14 +20,33 @@ pub fn handler() -> ApiRouter {
ApiRouter::new().nest(
"/collections",
ApiRouter::new()
.api_route("/", get(get_collections))
.api_route("/:collection_name", put(create_collection))
.api_route("/:collection_name", post(query_collection))
.api_route("/:collection_name", get(get_collection_info))
.api_route("/:collection_name", delete(delete_collection))
.api_route("/:collection_name/insert", post(insert_into_collection)),
.api_route("/:collection_name/insert", post(insert_into_collection))
.api_route("/:collection_name/embeddings", get(get_embeddings))
.api_route("/:collection_name/embeddings", post(query_embeddings))
.api_route("/:collection_name/embeddings", delete(delete_embeddings))
.api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
.api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
)
}

/// List the names of all collections
///
/// Takes read access to the database and responds with the collection
/// names as a JSON array of strings.
async fn get_collections(
    Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
    tracing::trace!("Getting collection names");

    let guard = db.read().await;
    Ok(Json(guard.list()))
}

/// Create a new collection
async fn create_collection(
Path(collection_name): Path<String>,
Expand Down Expand Up @@ -54,6 +76,8 @@ async fn create_collection(
// Request body of the similarity query. Reviewer notes are written as plain
// `//` comments so the existing `///` field descriptions stay untouched.
struct QueryCollectionQuery {
/// Vector to query with
query: Vec<f32>,
/// Metadata to filter with
// When omitted, the handler passes an empty filter, which matches every
// embedding. List items are alternatives (OR); entries inside one map must
// all match (AND).
filter: Option<Vec<HashMap<String, String>>>,
/// Number of results to return
// The handler falls back to 1 when this is omitted.
k: Option<usize>,
}
Expand All @@ -77,7 +101,7 @@ async fn query_collection(
}

let instant = Instant::now();
let results = collection.get_similarity(&req.query, req.k.unwrap_or(1));
let results = collection.get_by_metadata_and_similarity(&req.filter.unwrap_or_default(), &req.query, req.k.unwrap_or(1));
drop(db);

tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed());
Expand Down Expand Up @@ -165,3 +189,113 @@ async fn insert_into_collection(
.with_status(StatusCode::BAD_REQUEST)),
}
}

/// List the identifiers of the embeddings in a collection
///
/// Responds with `404 Not Found` when the collection does not exist,
/// otherwise with a JSON array of embedding identifiers.
async fn get_embeddings(
    Path(collection_name): Path<String>,
    Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
    tracing::trace!("Querying embeddings from collection {collection_name}");

    let guard = db.read().await;
    let collection = match guard.get_collection(&collection_name) {
        Some(collection) => collection,
        None => {
            return Err(HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))
        }
    };

    let ids = collection.list();
    // Release the read access before building the response.
    drop(guard);

    Ok(Json(ids))
}

// Request body shared by the metadata query and the metadata delete
// handlers. Reviewer notes are written as plain `//` comments so the existing
// `///` field descriptions stay untouched.
#[derive(Debug, serde::Deserialize, JsonSchema)]
struct EmbeddingsQuery {
/// Metadata to filter with
// Required. List items are alternatives (OR); entries inside one map must
// all match (AND). An empty list matches every embedding, so on the delete
// handler it empties the collection.
filter: Vec<HashMap<String, String>>,
/// Number of results to return
// The query handler falls back to 1 when this is omitted; the delete handler
// does not read it.
k: Option<usize>,
}

/// Query embeddings in a collection by metadata
///
/// Responds with `404 Not Found` when the collection does not exist.
/// Otherwise returns up to `k` embeddings (1 when `k` is omitted) whose
/// metadata is accepted by the request filter.
async fn query_embeddings(
    Path(collection_name): Path<String>,
    Extension(db): DbExtension,
    Json(req): Json<EmbeddingsQuery>,
) -> Result<Json<Vec<Embedding>>, HTTPError> {
    tracing::trace!("Querying embeddings from collection {collection_name}");

    let guard = db.read().await;
    let collection = match guard.get_collection(&collection_name) {
        Some(collection) => collection,
        None => {
            return Err(HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))
        }
    };

    let limit = req.k.unwrap_or(1);
    let started = Instant::now();
    let matches = collection.get_by_metadata(&req.filter, limit);
    // Release the read access before logging and serializing.
    drop(guard);

    tracing::trace!("Query embeddings from {collection_name} took {:?}", started.elapsed());
    Ok(Json(matches))
}

/// Delete embeddings in a collection by metadata
///
/// Responds with `404 Not Found` when the collection does not exist.
/// Otherwise removes every embedding accepted by the request filter and
/// responds with `204 No Content`. An empty filter removes every embedding
/// in the collection. The `k` field of the request is not used here.
async fn delete_embeddings(
    Path(collection_name): Path<String>,
    Extension(db): DbExtension,
    Json(req): Json<EmbeddingsQuery>,
) -> Result<StatusCode, HTTPError> {
    // The message used to say "Querying", a copy-paste slip from the query handler.
    tracing::trace!("Deleting embeddings from collection {collection_name}");

    let mut db = db.write().await;
    let collection = db
        .get_collection_mut(&collection_name)
        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

    collection.delete_by_metadata(&req.filter);
    // Release the write access as soon as the mutation is done.
    drop(db);

    Ok(StatusCode::NO_CONTENT)
}

/// Get a single embedding from a collection
///
/// Responds with `400 Bad Request` when the identifier is empty and with
/// `404 Not Found` when either the collection or the embedding does not
/// exist. On success the embedding is returned as JSON.
async fn get_embedding(
    Path((collection_name, embedding_id)): Path<(String, String)>,
    Extension(db): DbExtension,
) -> Result<Json<Embedding>, HTTPError> {
    tracing::trace!("Getting {embedding_id} from collection {collection_name}");

    if embedding_id.is_empty() {
        return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST));
    }

    let guard = db.read().await;
    let found = guard
        .get_collection(&collection_name)
        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?
        .get(&embedding_id)
        .ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;

    // The embedding borrows from the guard, so hand back an owned copy.
    Ok(Json(found.clone()))
}

/// Delete a single embedding from a collection
///
/// Responds with `404 Not Found` when either the collection or the
/// embedding does not exist, and with `204 No Content` once the embedding
/// has been removed.
async fn delete_embedding(
    Path((collection_name, embedding_id)): Path<(String, String)>,
    Extension(db): DbExtension,
) -> Result<StatusCode, HTTPError> {
    tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");

    let mut guard = db.write().await;
    let collection = guard
        .get_collection_mut(&collection_name)
        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

    let removed = collection.delete(&embedding_id);
    // Release the write access before building the response.
    drop(guard);

    if removed {
        Ok(StatusCode::NO_CONTENT)
    } else {
        Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))
    }
}