Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions _example/clustering/clustering.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package main

import (
"fmt"
"log"
"math/rand"
"time"

"github.com/blevesearch/go-faiss"
)

func main() {
rng := rand.New(rand.NewSource(123456))

const (
d = 64 // vector dimension
n = 10_000 // number of training vectors
k = 10 // number of clusters
)

fmt.Printf("Clustering %d vectors of dimension %d into %d clusters\n\n", n, d, k)

train := make([]float32, n*d)
centers := make([][]float32, k)
for i := 0; i < k; i++ {
centers[i] = make([]float32, d)
for j := 0; j < d; j++ {
centers[i][j] = rng.Float32() * 100
}
}

// Generate points around these centers with some noise
pointsPerCluster := n / k
for i := 0; i < n; i++ {
cluster := i / pointsPerCluster
if cluster >= k {
cluster = k - 1
}

// Add Gaussian noise around the cluster center
for j := 0; j < d; j++ {
noise := float32(rng.NormFloat64() * 5)
train[i*d+j] = centers[cluster][j] + noise
}
}

fmt.Println("Running simple k-means clustering...")
start := time.Now()

centroids, qerr, err := faiss.KMeansClustering(d, n, k, train)
if err != nil {
log.Fatalf("k-means: %v", err)
}

simpleTime := time.Since(start)
fmt.Printf("Simple k-means completed in %v\n", simpleTime)
fmt.Printf("Average quantization error: %.2f\n\n", qerr/float32(n))

fmt.Println("Running clustering with custom parameters...")

params := faiss.NewClusteringParameters()
params.Niter = 25
params.Nredo = 3
params.Verbose = true
params.Seed = 1234
params.MinPointsPerCentroid = 39
params.MaxPointsPerCentroid = 256

clustering, err := faiss.NewClusteringWithParams(d, k, params)
if err != nil {
log.Fatalf("new clustering: %v", err)
}
defer clustering.Close()

// Create an index to accelerate clustering
// For larger datasets, consider using a faster index like IndexIVFFlat
accelIdx, err := faiss.NewIndexFlatL2(d)
if err != nil {
log.Fatalf("index: %v", err)
}
defer accelIdx.Close()

start = time.Now()
if err = clustering.Train(train, accelIdx); err != nil {
log.Fatalf("train: %v", err)
}
advancedTime := time.Since(start)

advCentroids := clustering.Centroids()
fmt.Printf("\nAdvanced clustering completed in %v\n\n", advancedTime)

fmt.Println("Comparing clustering quality...")

baseIdx, err := faiss.NewIndexFlatL2(d)
if err != nil {
log.Fatalf("index: %v", err)
}
defer baseIdx.Close()
if err = baseIdx.Add(centroids); err != nil {
log.Fatalf("add centroids: %v", err)
}

advIdx, err := faiss.NewIndexFlatL2(d)
if err != nil {
log.Fatalf("index: %v", err)
}
defer advIdx.Close()
if err = advIdx.Add(advCentroids); err != nil {
log.Fatalf("add centroids: %v", err)
}

// Find nearest centroid for each training point
baseDist, _, _ := baseIdx.Search(train, 1)
advDist, _, _ := advIdx.Search(train, 1)

avgBase := mean(baseDist)
avgAdv := mean(advDist)

fmt.Printf("Average distance to nearest centroid:\n")
fmt.Printf(" Simple k-means: %.2f\n", avgBase)
fmt.Printf(" Advanced method: %.2f\n", avgAdv)
fmt.Printf(" Improvement: %.2f%% better\n", 100*(avgBase-avgAdv)/avgBase)
fmt.Printf("\nTime comparison:\n")
fmt.Printf(" Simple: %v\n", simpleTime)
fmt.Printf(" Advanced: %v (%.1fx slower due to multiple runs)\n",
advancedTime, float64(advancedTime)/float64(simpleTime))
}

func mean(x []float32) float64 {
var s float64
for _, v := range x {
s += float64(v)
}
return s / float64(len(x))
}
170 changes: 170 additions & 0 deletions clustering.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package faiss

/*
#include <faiss/c_api/Clustering_c.h>
#include <faiss/c_api/Index_c.h>
*/
import "C"
import "unsafe"

type ClusteringParameters struct {
Niter int // Number of clustering iterations
Nredo int // Number of times to redo clustering and keep best
Verbose bool // Verbose output
Spherical bool // Do we want normalized centroids?
IntCentroids bool // Round centroids coordinates to integer
UpdateIndex bool // Update index after each iteration?
FrozenCentroids bool // Use the centroids provided as input and do not change them during iterations
MinPointsPerCentroid int // Otherwise you get a warning
MaxPointsPerCentroid int // To limit size of dataset
Seed int // Seed for the random number generator
DecodeBlockSize uint64 // How many vectors at a time to decode
}

// Create a new ClusteringParameters with default values.
func NewClusteringParameters() *ClusteringParameters {
var cparams C.FaissClusteringParameters
C.faiss_ClusteringParameters_init(&cparams)

return &ClusteringParameters{
Niter: int(cparams.niter),
Nredo: int(cparams.nredo),
Verbose: cparams.verbose != 0,
Spherical: cparams.spherical != 0,
IntCentroids: cparams.int_centroids != 0,
UpdateIndex: cparams.update_index != 0,
FrozenCentroids: cparams.frozen_centroids != 0,
MinPointsPerCentroid: int(cparams.min_points_per_centroid),
MaxPointsPerCentroid: int(cparams.max_points_per_centroid),
Seed: int(cparams.seed),
DecodeBlockSize: uint64(cparams.decode_block_size),
}
}

func (p *ClusteringParameters) toCStruct() C.FaissClusteringParameters {
return C.FaissClusteringParameters{
niter: C.int(p.Niter),
nredo: C.int(p.Nredo),
verbose: boolToInt(p.Verbose),
spherical: boolToInt(p.Spherical),
int_centroids: boolToInt(p.IntCentroids),
update_index: boolToInt(p.UpdateIndex),
frozen_centroids: boolToInt(p.FrozenCentroids),
min_points_per_centroid: C.int(p.MinPointsPerCentroid),
max_points_per_centroid: C.int(p.MaxPointsPerCentroid),
seed: C.int(p.Seed),
decode_block_size: C.size_t(p.DecodeBlockSize),
}
}

type Clustering struct {
clustering *C.FaissClustering
d int
k int
}

// Create a new clustering object with default parameters.
func NewClustering(d, k int) (*Clustering, error) {
var clustering *C.FaissClustering
if c := C.faiss_Clustering_new(&clustering, C.int(d), C.int(k)); c != 0 {
return nil, getLastError()
}
return &Clustering{
clustering: clustering,
d: d,
k: k,
}, nil
}

func NewClusteringWithParams(d, k int, params *ClusteringParameters) (*Clustering, error) {
var clustering *C.FaissClustering
cparams := params.toCStruct()
if c := C.faiss_Clustering_new_with_params(&clustering, C.int(d), C.int(k), &cparams); c != 0 {
return nil, getLastError()
}
return &Clustering{
clustering: clustering,
d: d,
k: k,
}, nil
}

// Return the dimension of the vectors.
func (c *Clustering) D() int {
return c.d
}

// Return the number of clusters.
func (c *Clustering) K() int {
return c.k
}

func (c *Clustering) cPtr() *C.FaissClustering {
return c.clustering
}

// Train performs the k-means clustering on the provided vectors.
// The index parameter can be used to accelerate the clustering by providing
// a fast way to perform nearest-neighbor queries. If nil, a default index
// will be used internally.
func (c *Clustering) Train(x []float32, index Index) error {
n := len(x) / c.D()

var idx *C.FaissIndex
if index != nil {
idx = index.cPtr()
}

if code := C.faiss_Clustering_train(
c.clustering,
C.idx_t(n),
(*C.float)(&x[0]),
idx,
); code != 0 {
return getLastError()
}
return nil
}

// Return the cluster centroids after training.
func (c *Clustering) Centroids() []float32 {
var centroids *C.float
var size C.size_t
C.faiss_Clustering_centroids(c.clustering, &centroids, &size)
return (*[1 << 30]float32)(unsafe.Pointer(centroids))[:size:size]
}

// Free the memory used by the clustering object.
func (c *Clustering) Close() {
if c.clustering != nil {
C.faiss_Clustering_free(c.clustering)
c.clustering = nil
}
}

// KMeansClustering is a simplified interface for k-means clustering.
// It performs clustering and returns the centroids and quantization error.
func KMeansClustering(d, n, k int, x []float32) (centroids []float32, qerr float32, err error) {
centroids = make([]float32, k*d)
var cqerr C.float

if c := C.faiss_kmeans_clustering(
C.size_t(d),
C.size_t(n),
C.size_t(k),
(*C.float)(&x[0]),
(*C.float)(&centroids[0]),
&cqerr,
); c != 0 {
return nil, 0, getLastError()
}

return centroids, float32(cqerr), nil
}

func boolToInt(b bool) C.int {
if b {
return 1
}
return 0
}