Skip to content

Commit

Permalink
initial trial to create strongly typed metric restul
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamedeliwa committed Aug 15, 2024
1 parent 3d91ffe commit eb981ae
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 41 deletions.
12 changes: 7 additions & 5 deletions src/distance/cosine.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;
use crate::{models::types::MetricResult, storage::Storage};
#[derive(Debug)]
pub struct CosineDistance;

impl DistanceFunction for CosineDistance {
// Implementation here
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
// TODO: this method is calculating and returning COSINE SIMILARITY, while it's named COSINE DISTANCE
// should not it be renamed???
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand All @@ -19,7 +21,7 @@ impl DistanceFunction for CosineDistance {
) => {
// Implement cosine similarity for UnsignedByte storage
//unimplemented!("Cosine similarity for UnsignedByte not implemented yet")
Ok(0.0)
Ok(MetricResult::CosineSimilarity(0.0))
}
(
Storage::SubByte {
Expand Down Expand Up @@ -84,12 +86,12 @@ fn cosine_similarity_from_dot_product(
dot_product: f32,
mag_x: u32,
mag_y: u32,
) -> Result<f32, DistanceError> {
) -> Result<MetricResult, DistanceError> {
let denominator = (mag_x as f32).sqrt() * (mag_y as f32).sqrt();
if denominator == 0.0 {
Err(DistanceError::CalculationError)
} else {
Ok(dot_product / denominator)
Ok(MetricResult::CosineSimilarity(dot_product / denominator))
}
}

Expand Down
11 changes: 8 additions & 3 deletions src/distance/dotproduct.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use super::{DistanceError, DistanceFunction};
use crate::models::dot_product::dot_product_u8;
use crate::models::types::MetricResult;
use crate::storage::Storage;
use half::f16;

#[derive(Debug)]
pub struct DotProductDistance;

impl DistanceFunction for DotProductDistance {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand All @@ -16,15 +17,19 @@ impl DistanceFunction for DotProductDistance {
Storage::UnsignedByte {
quant_vec: vec_y, ..
},
) => Ok(dot_product_u8(vec_x, vec_y) as f32),
) => Ok(MetricResult::DotProductDistance(
dot_product_u8(vec_x, vec_y) as f32,
)),
(
Storage::HalfPrecisionFP {
quant_vec: vec_x, ..
},
Storage::HalfPrecisionFP {
quant_vec: vec_y, ..
},
) => Ok(dot_product_f16(vec_x, vec_y)),
) => Ok(MetricResult::DotProductDistance(dot_product_f16(
vec_x, vec_y,
))),
(Storage::SubByte { .. }, Storage::SubByte { .. }) => {
Err(DistanceError::CalculationError) // Implement if needed
}
Expand Down
44 changes: 24 additions & 20 deletions src/distance/euclidean.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;
use crate::{models::types::MetricResult, storage::Storage};
use half::f16;

#[derive(Debug)]
pub struct EuclideanDistance;

impl DistanceFunction for EuclideanDistance {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand All @@ -31,24 +31,28 @@ impl DistanceFunction for EuclideanDistance {
}
}
}
pub fn euclidean_distance_u8(x: &[u8], y: &[u8]) -> f32 {
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = (a as i16) - (b as i16);
(diff * diff) as f32
})
.sum::<f32>()
.sqrt()
pub fn euclidean_distance_u8(x: &[u8], y: &[u8]) -> MetricResult {
MetricResult::EuclideanDistance(
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = (a as i16) - (b as i16);
(diff * diff) as f32
})
.sum::<f32>()
.sqrt(),
)
}

pub fn euclidean_distance_f16(x: &[f16], y: &[f16]) -> f32 {
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = f32::from(a) - f32::from(b);
diff * diff
})
.sum::<f32>()
.sqrt()
pub fn euclidean_distance_f16(x: &[f16], y: &[f16]) -> MetricResult {
MetricResult::EuclideanDistance(
x.iter()
.zip(y.iter())
.map(|(&a, &b)| {
let diff = f32::from(a) - f32::from(b);
diff * diff
})
.sum::<f32>()
.sqrt(),
)
}
4 changes: 2 additions & 2 deletions src/distance/hamming.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use super::{DistanceError, DistanceFunction};
use crate::storage::Storage;
use crate::{models::types::MetricResult, storage::Storage};

#[derive(Debug)]
pub struct HammingDistance;

impl DistanceFunction for HammingDistance {
// Implementation here
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError> {
match (x, y) {
(
Storage::UnsignedByte {
Expand Down
4 changes: 2 additions & 2 deletions src/distance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ pub mod dotproduct;
pub mod euclidean;
pub mod hamming;

use crate::storage::Storage;
use crate::{models::types::MetricResult, storage::Storage};

pub trait DistanceFunction: std::fmt::Debug + Send + Sync {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError>;
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError>;
}

#[derive(Debug)]
Expand Down
33 changes: 29 additions & 4 deletions src/models/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,27 @@ pub struct MergedNode {
pub persist_flag: Arc<AtomicBool>,
}

#[derive(Clone)]
pub enum MetricResult {
CosineSimilarity(f32),
CosineDistance(f32),
EuclideanDistance(f32),
HammingDistance(f32),
DotProductDistance(f32),
}

impl MetricResult {
pub fn get_value(&self) -> f32 {
match self {
MetricResult::CosineSimilarity(value) => *value,
MetricResult::CosineDistance(value) => *value,
MetricResult::EuclideanDistance(value) => *value,
MetricResult::HammingDistance(value) => *value,
MetricResult::DotProductDistance(value) => *value,
}
}
}

#[derive(Debug)]
pub enum DistanceMetric {
Cosine,
Expand All @@ -127,7 +148,7 @@ pub enum DistanceMetric {
}

impl DistanceFunction for DistanceMetric {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<f32, DistanceError> {
fn calculate(&self, x: &Storage, y: &Storage) -> Result<MetricResult, DistanceError> {
match self {
Self::Cosine => CosineDistance.calculate(x, y),
Self::Euclidean => EuclideanDistance.calculate(x, y),
Expand Down Expand Up @@ -176,9 +197,13 @@ impl MergedNode {
}
}

pub fn add_ready_neighbor(&self, neighbor: LazyItem<MergedNode>, cosine_similarity: f32) {
pub fn add_ready_neighbor(
&self,
neighbor: LazyItem<MergedNode>,
cosine_similarity: MetricResult,
) {
self.neighbors
.insert(EagerLazyItem(cosine_similarity, neighbor));
.insert(EagerLazyItem(cosine_similarity.get_value(), neighbor));
}

pub fn set_parent(&self, parent: LazyItem<MergedNode>) {
Expand All @@ -191,7 +216,7 @@ impl MergedNode {
arc.update(child);
}

pub fn add_ready_neighbors(&self, neighbors_list: Vec<(LazyItem<MergedNode>, f32)>) {
pub fn add_ready_neighbors(&self, neighbors_list: Vec<(LazyItem<MergedNode>, MetricResult)>) {
for (neighbor, cosine_similarity) in neighbors_list {
self.add_ready_neighbor(neighbor, cosine_similarity);
}
Expand Down
10 changes: 5 additions & 5 deletions src/vector_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ fn insert_node_create_edges(
vec_store: Arc<VectorStore>,
fvec: Arc<Storage>,
hs: VectorId,
nbs: Vec<(LazyItem<MergedNode>, f32)>,
nbs: Vec<(LazyItem<MergedNode>, MetricResult)>,
cur_level: i8,
) -> Result<(), WaCustomError> {
let node_prop = NodeProp {
Expand All @@ -712,7 +712,7 @@ fn insert_node_create_edges(
.map(|nbr2| (nbr2.1, nbr2.0))
.collect();

neighbor_list.push((LazyItem::from_item(nn.clone()), cs));
neighbor_list.push((LazyItem::from_item(nn.clone()), cs.get_value()));

neighbor_list
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Expand Down Expand Up @@ -741,8 +741,8 @@ fn traverse_find_nearest(
skipm: &mut HashSet<VectorId>,
cur_level: i8,
skip_hop: bool,
) -> Result<Vec<(LazyItem<MergedNode>, f32)>, WaCustomError> {
let mut tasks: SmallVec<[Vec<(LazyItem<MergedNode>, f32)>; 24]> = SmallVec::new();
) -> Result<Vec<(LazyItem<MergedNode>, MetricResult)>, WaCustomError> {
let mut tasks: SmallVec<[Vec<(LazyItem<MergedNode>, MetricResult)>; 24]> = SmallVec::new();

let mut node_arc = match vtm.clone() {
LazyItem::Valid {
Expand Down Expand Up @@ -826,7 +826,7 @@ fn traverse_find_nearest(
}

let mut nn: Vec<_> = tasks.into_iter().flatten().collect();
nn.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
nn.sort_by(|a, b| b.1.get_value().partial_cmp(&a.1.get_value()).unwrap());
let mut seen = HashSet::new();
nn.retain(|(lazy_node, _)| {
if let LazyItem::Valid {
Expand Down

0 comments on commit eb981ae

Please sign in to comment.