Skip to content

Commit

Permalink
merge strict types
Browse files Browse the repository at this point in the history
  • Loading branch information
tinkn committed Aug 26, 2024
2 parents d6e91dd + aa11a53 commit 564d743
Show file tree
Hide file tree
Showing 25 changed files with 573 additions and 233 deletions.
2 changes: 1 addition & 1 deletion src/api/vectordb/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub(crate) async fn fetch(web::Json(body): web::Json<FetchNeighbors>) -> HttpRes
let response_data = RPCResponseBody::RespFetchNeighbors {
neighbors: neig
.iter()
.map(|(vid, x)| (VectorIdValue::from(vid.clone()), *x))
.map(|(vid, x)| (VectorIdValue::from(vid.clone()), x.clone()))
.collect(),
vector: Vector {
id: nvid,
Expand Down
18 changes: 9 additions & 9 deletions src/api_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::cell::RefCell;
use std::fs::OpenOptions;
use std::io::Write;
use std::rc::Rc;
use std::sync::{atomic::AtomicBool, Arc};
use std::sync::Arc;

pub async fn init_vector_store(
name: String,
Expand Down Expand Up @@ -80,10 +80,10 @@ pub async fn init_vector_store(
let prop = Arc::new(NodeProp {
id: vec_hash.clone(),
value: vector_list.clone(),
location: Some((0, 0)),
location: Some((FileOffset(0), BytesToRead(0))),
});
let mut current_node = ArcShift::new(MergedNode {
hnsw_level: l as u8,
hnsw_level: HNSWLevel(l as u8),
prop: ArcShift::new(PropState::Ready(prop.clone())),
neighbors: EagerLazyItemSet::new(),
parent: LazyItemRef::new_invalid(),
Expand All @@ -92,8 +92,8 @@ pub async fn init_vector_store(
});

// TODO: Initialize with appropriate version ID
let lazy_node = LazyItem::from_arcshift(0, current_node.clone());
let nn = LazyItemRef::from_arcshift(0, current_node.clone());
let lazy_node = LazyItem::from_arcshift(VersionId(0), current_node.clone());
let nn = LazyItemRef::from_arcshift(VersionId(0), current_node.clone());

if let Some(prev_node) = prev.item.get().get_data() {
current_node
Expand All @@ -105,7 +105,7 @@ pub async fn init_vector_store(

if l == 0 {
root = nn.clone();
let prop_location = write_prop_to_file(&prop, &prop_file);
let _prop_location = write_prop_to_file(&prop, &prop_file);
current_node.get().set_prop_ready(prop);
}
nodes.push(nn.clone());
Expand Down Expand Up @@ -242,7 +242,7 @@ pub fn run_upload(
pub async fn ann_vector_query(
vec_store: Arc<VectorStore>,
query: Vec<f32>,
) -> Result<Option<Vec<(VectorId, f32)>>, WaCustomError> {
) -> Result<Option<Vec<(VectorId, MetricResult)>>, WaCustomError> {
let vector_store = vec_store.clone();
let vec_hash = VectorId::Str("query".to_string());
let root = &vector_store.root_vec;
Expand All @@ -268,7 +268,7 @@ pub async fn ann_vector_query(
pub async fn fetch_vector_neighbors(
vec_store: Arc<VectorStore>,
vector_id: VectorId,
) -> Vec<Option<(VectorId, Vec<(VectorId, f32)>)>> {
) -> Vec<Option<(VectorId, Vec<(VectorId, MetricResult)>)>> {
let results = vector_fetch(vec_store.clone(), vector_id);
return results.expect("Failed fetching vector neighbors");
}
Expand All @@ -278,7 +278,7 @@ fn calculate_statistics(_: &[i32]) -> Option<Statistics> {
None
}

/// Placeholder for a k-nearest-neighbour search over raw `f32` vectors.
///
/// Currently unimplemented: ignores both inputs and always returns an
/// empty result set. TODO: implement the actual KNN scan and document
/// the meaning of the `(i8, i8, String, f64)` result tuple.
fn vector_knn(_vs: &Vec<f32>, _vecs: &Vec<f32>) -> Vec<(i8, i8, String, f64)> {
    vec![]
}
38 changes: 26 additions & 12 deletions src/distance/cosine.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
use serde::{Deserialize, Serialize};

use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;
#[derive(Debug)]
pub struct CosineDistance;

#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)]
pub struct CosineDistance(pub f32);

impl DistanceFunction for CosineDistance {
// Implementation here
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
type Item = Self;
fn calculate(&self, _x: &Storage, _y: &Storage) -> Result<Self::Item, DistanceError> {
// placeholder method to be implemented
Err(DistanceError::CalculationError)
}
}

/// Newtype wrapper holding a cosine-similarity score as an `f32`.
#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)]
pub struct CosineSimilarity(pub f32);

impl DistanceFunction for CosineSimilarity {
type Item = Self;
fn calculate(&self, x: &Storage, y: &Storage) -> Result<Self::Item, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
mag: mag_x,
quant_vec: vec_x,
mag: _mag_x,
quant_vec: _vec_x,
},
Storage::UnsignedByte {
mag: mag_y,
quant_vec: vec_y,
mag: _mag_y,
quant_vec: _vec_y,
},
) => {
// Implement cosine similarity for UnsignedByte storage
//unimplemented!("Cosine similarity for UnsignedByte not implemented yet")
Ok(0.0)
Ok(CosineSimilarity(0.0))
}
(
Storage::SubByte {
Expand Down Expand Up @@ -84,12 +98,12 @@ fn cosine_similarity_from_dot_product(
dot_product: f32,
mag_x: u32,
mag_y: u32,
) -> Result<f32, DistanceError> {
) -> Result<CosineSimilarity, DistanceError> {
let denominator = (mag_x as f32).sqrt() * (mag_y as f32).sqrt();
if denominator == 0.0 {
Err(DistanceError::CalculationError)
} else {
Ok(dot_product / denominator)
Ok(CosineSimilarity(dot_product / denominator))
}
}

Expand All @@ -101,7 +115,7 @@ fn dot_product_quaternary(x_vec: &[Vec<u8>], y_vec: &[Vec<u8>], resolution: u8)
.zip(&x_vec[1])
.zip(y_vec[0].iter().zip(&y_vec[1]))
.enumerate()
.map(|(i, ((&x_lsb, &x_msb), (&y_lsb, &y_msb)))| {
.map(|(_i, ((&x_lsb, &x_msb), (&y_lsb, &y_msb)))| {
let lsbs = (x_lsb & y_lsb).count_ones();
let mid1 = x_lsb & y_msb;
let mid2 = y_lsb & x_msb;
Expand Down
12 changes: 7 additions & 5 deletions src/distance/dotproduct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use super::{DistanceError, DistanceFunction};
use crate::models::dot_product::dot_product_u8;
use crate::storage::Storage;
use half::f16;
use serde::{Deserialize, Serialize};

/// Dot-product based distance; the wrapped `f32` carries the computed
/// dot product of two quantized vectors.
#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)]
pub struct DotProductDistance(pub f32);

impl DistanceFunction for DotProductDistance {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
type Item = Self;
fn calculate(&self, x: &Storage, y: &Storage) -> Result<Self::Item, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand All @@ -16,15 +18,15 @@ impl DistanceFunction for DotProductDistance {
Storage::UnsignedByte {
quant_vec: vec_y, ..
},
) => Ok(dot_product_u8(vec_x, vec_y) as f32),
) => Ok(DotProductDistance(dot_product_u8(vec_x, vec_y) as f32)),
(
Storage::HalfPrecisionFP {
quant_vec: vec_x, ..
},
Storage::HalfPrecisionFP {
quant_vec: vec_y, ..
},
) => Ok(dot_product_f16(vec_x, vec_y)),
) => Ok(DotProductDistance(dot_product_f16(vec_x, vec_y))),
(Storage::SubByte { .. }, Storage::SubByte { .. }) => {
Err(DistanceError::CalculationError) // Implement if needed
}
Expand Down
48 changes: 27 additions & 21 deletions src/distance/euclidean.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;
use half::f16;
use serde::{Deserialize, Serialize};

/// Euclidean (L2) distance between two quantized vectors, wrapped in a
/// newtype so the metric kind is carried in the type.
#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)]
pub struct EuclideanDistance(pub f32);

impl DistanceFunction for EuclideanDistance {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
type Item = Self;
fn calculate(&self, x: &Storage, y: &Storage) -> Result<Self::Item, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand All @@ -31,24 +33,28 @@ impl DistanceFunction for EuclideanDistance {
}
}
}
pub fn euclidean_distance_u8(x: &[u8], y: &[u8]) -> f32 {
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = (a as i16) - (b as i16);
(diff * diff) as f32
})
.sum::<f32>()
.sqrt()
pub fn euclidean_distance_u8(x: &[u8], y: &[u8]) -> EuclideanDistance {
EuclideanDistance(
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = (a as i16) - (b as i16);
(diff * diff) as f32
})
.sum::<f32>()
.sqrt(),
)
}

pub fn euclidean_distance_f16(x: &[f16], y: &[f16]) -> f32 {
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = f32::from(a) - f32::from(b);
diff * diff
})
.sum::<f32>()
.sqrt()
pub fn euclidean_distance_f16(x: &[f16], y: &[f16]) -> EuclideanDistance {
EuclideanDistance(
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = f32::from(a) - f32::from(b);
diff * diff
})
.sum::<f32>()
.sqrt(),
)
}
10 changes: 7 additions & 3 deletions src/distance/hamming.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use serde::{Deserialize, Serialize};

use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;

/// Hamming distance between two quantized vectors, stored as an `f32`
/// newtype so the metric kind is carried in the type.
#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)]
pub struct HammingDistance(pub f32);

impl DistanceFunction for HammingDistance {
type Item = Self;

// Implementation here
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<Self::Item, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand Down
3 changes: 2 additions & 1 deletion src/distance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ pub mod hamming;
use crate::storage::Storage;

pub trait DistanceFunction: std::fmt::Debug + Send + Sync {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError>;
type Item;
fn calculate(&self, x: &Storage, y: &Storage) -> Result<Self::Item, DistanceError>;
}

#[derive(Debug)]
Expand Down
12 changes: 6 additions & 6 deletions src/models/cache_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<R: Read + Seek> NodeRegistry<R> {
let version_id = if let FileIndex::Valid { version, .. } = &file_index {
*version
} else {
0
VersionId(0)
};

if max_loads == 0 || !skipm.insert(combined_index) {
Expand Down Expand Up @@ -137,7 +137,7 @@ impl<R: Read + Seek> NodeRegistry<R> {

/// Packs a `FileIndex` into a single `u64` cache key: the 32-bit file
/// offset goes in the high 32 bits and the 16-bit version id in the low
/// bits. `Invalid` maps to the reserved sentinel `u64::MAX` (no valid
/// combination can collide with it, since bits 16..32 are always zero
/// for a valid index).
pub fn combine_index(file_index: &FileIndex) -> u64 {
    match file_index {
        FileIndex::Valid { offset, version } => ((offset.0 as u64) << 32) | (version.0 as u64),
        FileIndex::Invalid => u64::MAX, // Use max u64 value for Invalid
    }
}
Expand All @@ -147,8 +147,8 @@ impl<R: Read + Seek> NodeRegistry<R> {
FileIndex::Invalid
} else {
FileIndex::Valid {
offset: (combined >> 32) as u32,
version: combined as u16,
offset: FileOffset((combined >> 32) as u32),
version: VersionId(combined as u16),
}
}
}
Expand All @@ -163,8 +163,8 @@ pub fn load_cache() {
.expect("failed to open");

let file_index = FileIndex::Valid {
offset: 0,
version: 0,
offset: FileOffset(0),
version: VersionId(0),
}; // Assuming initial version is 0
let cache = Arc::new(NodeRegistry::new(1000, file));
match read_node_from_file(file_index.clone(), cache) {
Expand Down
Loading

0 comments on commit 564d743

Please sign in to comment.