35 changes: 19 additions & 16 deletions faiss/IndexFlat.cpp
@@ -627,22 +627,25 @@ inline void flat_pano_search_core(
threshold = res.heap_dis[0];
}

size_t num_active =
index.pano
.progressive_filter_batch<CMax<float, int64_t>>(
index.codes.data(),
index.cum_sums.data(),
xi,
query_cum_norms.data(),
batch_no,
index.ntotal,
sel,
nullptr,
use_sel,
active_indices,
exact_distances,
threshold,
local_stats);
size_t num_active = with_metric_type(
index.metric_type, [&]<MetricType M>() {
return index.pano.progressive_filter_batch<
CMax<float, int64_t>,
M>(
index.codes.data(),
index.cum_sums.data(),
xi,
query_cum_norms.data(),
batch_no,
index.ntotal,
sel,
nullptr,
use_sel,
active_indices,
exact_distances,
threshold,
local_stats);
});

for (size_t j = 0; j < num_active; j++) {
res.add_result(
52 changes: 8 additions & 44 deletions faiss/IndexHNSW.cpp
@@ -653,33 +653,8 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
* IndexHNSWFlatPanorama implementation
**************************************************************/

void IndexHNSWFlatPanorama::compute_cum_sums(
const float* x,
float* dst_cum_sums,
int d,
int num_panorama_levels,
int panorama_level_width) {
// Iterate backwards through levels, accumulating sum as we go.
// This avoids computing the suffix sum for each vector, which takes
// extra memory.

float sum = 0.0f;
dst_cum_sums[num_panorama_levels] = 0.0f;
for (int level = num_panorama_levels - 1; level >= 0; level--) {
int start_idx = level * panorama_level_width;
int end_idx = std::min(start_idx + panorama_level_width, d);
for (int j = start_idx; j < end_idx; j++) {
sum += x[j] * x[j];
}
dst_cum_sums[level] = std::sqrt(sum);
}
}

IndexHNSWFlatPanorama::IndexHNSWFlatPanorama()
: IndexHNSWFlat(),
cum_sums(),
panorama_level_width(0),
num_panorama_levels(0) {}
: IndexHNSWFlat(), cum_sums(), pano(0, 1, 1), num_panorama_levels(0) {}

IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(
int d,
@@ -688,8 +663,7 @@ IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(
MetricType metric)
: IndexHNSWFlat(d, M, metric),
cum_sums(),
panorama_level_width(
(d + num_panorama_levels - 1) / num_panorama_levels),
pano(d * sizeof(float), num_panorama_levels, 1),
num_panorama_levels(num_panorama_levels) {
// For now, we only support L2 distance.
// Supporting dot product and cosine distance is a trivial addition
@@ -704,18 +678,8 @@ IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(

void IndexHNSWFlatPanorama::add(idx_t n, const float* x) {
idx_t n0 = ntotal;
cum_sums.resize((ntotal + n) * (num_panorama_levels + 1));

for (size_t idx = 0; idx < n; idx++) {
const float* vector = x + idx * d;
compute_cum_sums(
vector,
&cum_sums[(n0 + idx) * (num_panorama_levels + 1)],
d,
num_panorama_levels,
panorama_level_width);
}

cum_sums.resize((ntotal + n) * (pano.n_levels + 1));
pano.compute_cumulative_sums(cum_sums.data(), n0, n, x);
IndexHNSWFlat::add(n, x);
}

@@ -725,13 +689,13 @@ void IndexHNSWFlatPanorama::reset() {
}

void IndexHNSWFlatPanorama::permute_entries(const idx_t* perm) {
std::vector<float> new_cum_sums(ntotal * (num_panorama_levels + 1));
std::vector<float> new_cum_sums(ntotal * (pano.n_levels + 1));

for (idx_t i = 0; i < ntotal; i++) {
idx_t src = perm[i];
memcpy(&new_cum_sums[i * (num_panorama_levels + 1)],
&cum_sums[src * (num_panorama_levels + 1)],
(num_panorama_levels + 1) * sizeof(float));
memcpy(&new_cum_sums[i * (pano.n_levels + 1)],
&cum_sums[src * (pano.n_levels + 1)],
(pano.n_levels + 1) * sizeof(float));
}

std::swap(cum_sums, new_cum_sums);
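Note on the change above: the per-vector cumulative-sum logic removed from IndexHNSW.cpp now lives behind Panorama::compute_cumulative_sums. For reference, a minimal sketch of the recurrence the removed helper implemented, under the assumption that the Panorama member reproduces the same layout for a batch of vectors (function and parameter names here are illustrative, not the Faiss API):

    #include <algorithm>
    #include <cmath>

    // Illustration only: suffix sums of squared components for one vector,
    // mirroring the removed IndexHNSWFlatPanorama::compute_cum_sums.
    void cum_sums_one_vector(
            const float* x,   // input vector of dimension d
            float* dst,       // output, size n_levels + 1
            int d,
            int n_levels,
            int level_width) {
        float sum = 0.0f;
        dst[n_levels] = 0.0f;
        // Walk the levels back to front so a single running sum yields each
        // level's suffix sum of squares without extra memory.
        for (int level = n_levels - 1; level >= 0; level--) {
            int begin = level * level_width;
            int end = std::min(begin + level_width, d);
            for (int j = begin; j < end; j++) {
                sum += x[j] * x[j];
            }
            dst[level] = std::sqrt(sum); // L2 norm of the tail starting at `begin`
        }
        // dst[0] is the full L2 norm of x; the search code squares it to get ||x||^2.
    }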
14 changes: 3 additions & 11 deletions faiss/IndexHNSW.h
@@ -16,6 +16,7 @@
#include <faiss/IndexPQ.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/impl/HNSW.h>
#include <faiss/impl/Panorama.h>
#include <faiss/utils/utils.h>

namespace faiss {
@@ -164,20 +165,11 @@ struct IndexHNSWFlatPanorama : IndexHNSWFlat {

/// Inline for performance - called frequently in search hot path.
const float* get_cum_sum(idx_t i) const {
return cum_sums.data() + i * (num_panorama_levels + 1);
return cum_sums.data() + i * (pano.n_levels + 1);
}

/// Compute cumulative sums for a vector (used both for database points and
/// queries).
static void compute_cum_sums(
const float* x,
float* dst_cum_sums,
int d,
int num_panorama_levels,
int panorama_level_width);

std::vector<float> cum_sums;
const size_t panorama_level_width;
Panorama pano;
const size_t num_panorama_levels;
};

71 changes: 35 additions & 36 deletions faiss/IndexIVFFlatPanorama.cpp
@@ -12,6 +12,7 @@
#include <cstdio>

#include <faiss/IndexFlat.h>
#include <faiss/MetricType.h>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/IDSelector.h>
@@ -32,10 +33,7 @@ IndexIVFFlatPanorama::IndexIVFFlatPanorama(
MetricType metric,
bool own_invlists)
: IndexIVFFlat(quantizer, d, nlist, metric, false), n_levels(n_levels) {
// For now, we only support L2 distance.
// Supporting dot product and cosine distance is a trivial addition
// left for future work.
FAISS_THROW_IF_NOT(metric == METRIC_L2);
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);

// We construct the inverted lists here so that we can use the
// level-oriented storage. This does not cause a leak as we constructed
@@ -53,6 +51,7 @@ struct IVFFlatScannerPanorama : InvertedListScanner {
VectorDistance vd;
const ArrayInvertedListsPanorama* storage;
using C = typename VectorDistance::C;
static constexpr MetricType metric = VectorDistance::metric;

IVFFlatScannerPanorama(
const VectorDistance& vd,
@@ -109,22 +108,22 @@
for (size_t batch_no = 0; batch_no < n_batches; batch_no++) {
size_t batch_start = batch_no * storage->kBatchSize;

size_t num_active =
storage->pano
.progressive_filter_batch<CMax<float, int64_t>>(
codes,
cum_sums_data,
xi,
cum_sums.data(),
batch_no,
list_size,
sel,
ids,
use_sel,
active_indices,
exact_distances,
simi[0],
local_stats);
size_t num_active = with_metric_type(metric, [&]<MetricType M>() {
return storage->pano.progressive_filter_batch<C, M>(
codes,
cum_sums_data,
xi,
cum_sums.data(),
batch_no,
list_size,
sel,
ids,
use_sel,
active_indices,
exact_distances,
simi[0],
local_stats);
});

// Add batch survivors to heap.
for (size_t i = 0; i < num_active; i++) {
@@ -167,22 +166,22 @@ struct IVFFlatScannerPanorama : InvertedListScanner {
for (size_t batch_no = 0; batch_no < n_batches; batch_no++) {
size_t batch_start = batch_no * storage->kBatchSize;

size_t num_active =
storage->pano
.progressive_filter_batch<CMax<float, int64_t>>(
codes,
cum_sums_data,
xi,
cum_sums.data(),
batch_no,
list_size,
sel,
ids,
use_sel,
active_indices,
exact_distances,
radius,
local_stats);
size_t num_active = with_metric_type(metric, [&]<MetricType M>() {
return storage->pano.progressive_filter_batch<C, M>(
codes,
cum_sums_data,
xi,
cum_sums.data(),
batch_no,
list_size,
sel,
ids,
use_sel,
active_indices,
exact_distances,
radius,
local_stats);
});

// Add batch survivors to range result.
for (size_t i = 0; i < num_active; i++) {
48 changes: 48 additions & 0 deletions faiss/MetricType.h
@@ -12,6 +12,7 @@

#include <cstdint>
#include <cstdio>
#include <cstdlib>

namespace faiss {

@@ -52,6 +53,53 @@ constexpr bool is_similarity_metric(MetricType metric_type) {
(metric_type == METRIC_Jaccard));
}

/// Dispatch to a lambda with MetricType as a compile-time constant.
/// This allows writing generic code that works with different metrics
/// while maintaining compile-time optimization.
///
/// Example usage:
///   auto result = with_metric_type(runtime_metric, [&]<MetricType M>() {
///       return compute_distance<M>(x, y);
///   });
#ifndef SWIG

template <typename LambdaType>
inline auto with_metric_type(MetricType metric, LambdaType&& action) {
switch (metric) {
case METRIC_INNER_PRODUCT:
return action.template operator()<METRIC_INNER_PRODUCT>();
case METRIC_L2:
return action.template operator()<METRIC_L2>();
case METRIC_L1:
return action.template operator()<METRIC_L1>();
case METRIC_Linf:
return action.template operator()<METRIC_Linf>();
case METRIC_Lp:
return action.template operator()<METRIC_Lp>();
case METRIC_Canberra:
return action.template operator()<METRIC_Canberra>();
case METRIC_BrayCurtis:
return action.template operator()<METRIC_BrayCurtis>();
case METRIC_JensenShannon:
return action.template operator()<METRIC_JensenShannon>();
case METRIC_Jaccard:
return action.template operator()<METRIC_Jaccard>();
case METRIC_NaNEuclidean:
return action.template operator()<METRIC_NaNEuclidean>();
case METRIC_GOWER:
return action.template operator()<METRIC_GOWER>();
default: {
fprintf(stderr,
"FATAL ERROR: with_metric_type called with unknown "
"metric %d\n",
static_cast<int>(metric));
abort();
}
}
}
#endif // SWIG

} // namespace faiss

#endif
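For orientation, here is a minimal, self-contained sketch of the call pattern this dispatcher enables at the call sites touched in this PR (distance_kernel is a hypothetical stand-in for the templated Panorama kernels; it is not part of the Faiss API):

    #include <cstddef>
    #include <faiss/MetricType.h>

    // Hypothetical kernel: the metric is a compile-time template parameter.
    template <faiss::MetricType M>
    float distance_kernel(const float* x, const float* y, size_t d) {
        float acc = 0.0f;
        for (size_t i = 0; i < d; i++) {
            if constexpr (M == faiss::METRIC_INNER_PRODUCT) {
                acc += x[i] * y[i];       // similarity: accumulate dot product
            } else {
                float diff = x[i] - y[i]; // other metrics treated as squared L2
                acc += diff * diff;       // (sketch only, not the real kernels)
            }
        }
        return acc;
    }

    float compute(faiss::MetricType metric, const float* x, const float* y, size_t d) {
        // The C++20 templated lambda receives M as a compile-time constant,
        // so each metric gets a fully specialized code path.
        return faiss::with_metric_type(metric, [&]<faiss::MetricType M>() {
            return distance_kernel<M>(x, y, d);
        });
    }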
19 changes: 7 additions & 12 deletions faiss/impl/HNSW.cpp
@@ -802,13 +802,8 @@ int search_from_candidates_panorama(
std::vector<float> exact_distances(M);

const float* query = flat_codes_qdis->q;
std::vector<float> query_cum_sums(panorama_index->num_panorama_levels + 1);
IndexHNSWFlatPanorama::compute_cum_sums(
query,
query_cum_sums.data(),
panorama_index->d,
panorama_index->num_panorama_levels,
panorama_index->panorama_level_width);
std::vector<float> query_cum_sums(panorama_index->pano.n_levels + 1);
panorama_index->pano.compute_query_cum_sums(query, query_cum_sums.data());
float query_norm_sq = query_cum_sums[0] * query_cum_sums[0];

int nstep = 0;
@@ -854,14 +849,14 @@

size_t batch_size = initial_size;
size_t curr_panorama_level = 0;
const size_t num_panorama_levels = panorama_index->num_panorama_levels;
const size_t num_panorama_levels = panorama_index->pano.n_levels;
while (curr_panorama_level < num_panorama_levels && batch_size > 0) {
float query_cum_norm = query_cum_sums[curr_panorama_level + 1];

const size_t panorama_level_width =
panorama_index->panorama_level_width;
size_t start_dim = curr_panorama_level * panorama_level_width;
size_t end_dim = (curr_panorama_level + 1) * panorama_level_width;
size_t start_dim = curr_panorama_level *
panorama_index->pano.level_width_floats;
size_t end_dim = (curr_panorama_level + 1) *
panorama_index->pano.level_width_floats;
end_dim = std::min(end_dim, static_cast<size_t>(panorama_index->d));

size_t i = 0;
Expand Down
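Worked example for the level-window arithmetic in the hunk above (illustrative values, and assuming pano.level_width_floats keeps the old ceil-division formula (d + num_panorama_levels - 1) / num_panorama_levels): with d = 100 and 4 levels the width is 25, so level 2 scans dims [50, 75) and level 3 scans [75, 100); with 3 levels the width is 34 and the final std::min clamp trims the last window from [68, 102) down to [68, 100).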