diff --git a/src/bindings.cpp b/src/bindings.cpp index cf61137..f5c9682 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -5,60 +5,62 @@ namespace py = pybind11; PYBIND11_MODULE(SimpleHNSW, m) { - m.doc() = "SimpleHNSW - A simple HNSW (Hierarchical Navigable Small World) implementation for approximate nearest neighbor search"; - + m.doc() = "SimpleHNSW - A simple HNSW implementation for approximate nearest neighbor search"; + py::class_(m, "SimpleHNSWIndex") - .def(py::init(), - py::arg("L") = 5, - py::arg("mL") = 0.62, - py::arg("efc") = 10, - py::arg("maxConnections") = 16, - R"doc( + .def(py::init(), + py::arg("L") = 5, + py::arg("mL") = 0.62, + py::arg("efc") = 10, + py::arg("maxConnections") = 16, + py::arg("seed") = 0u, + R"doc( Initialize a SimpleHNSW index. - + Args: - L (int): Number of layers in the hierarchical graph (default: 5) + L (int): Number of layers (default: 5) mL (float): Normalization factor for layer assignment (default: 0.62) efc (int): Size of the dynamic candidate list during construction (default: 10) maxConnections (int): Maximum number of connections per node (default: 16) - )doc") + seed (int): RNG seed (0 => non-deterministic) + )doc") .def("insert", &SimpleHNSWIndex::insert, - py::arg("vector"), - R"doc( + py::arg("vector"), + R"doc( Insert a vector into the index. - + Args: vector (list[float]): The vector to insert - )doc") + )doc") .def("search", &SimpleHNSWIndex::search, - py::arg("query"), - py::arg("ef") = 1, - R"doc( + py::arg("query"), + py::arg("ef") = 1, + R"doc( Search for the nearest neighbors of a query vector. - + Args: query (list[float]): The query vector ef (int): Size of the dynamic candidate list during search (default: 1) - + Returns: list[tuple[float, int]]: List of (distance, index) pairs for nearest neighbors - )doc") + )doc") .def("toJSON", &SimpleHNSWIndex::toJSON, - R"doc( + R"doc( Serialize the index to a JSON string. - + Returns: str: JSON representation of the index - )doc") + )doc") .def_static("fromJSON", &SimpleHNSWIndex::fromJSON, - py::arg("json"), - R"doc( + py::arg("json"), + R"doc( Deserialize an index from a JSON string. - + Args: json (str): JSON representation of the index - + Returns: SimpleHNSWIndex: Deserialized index - )doc"); + )doc"); } diff --git a/src/cache.cpp b/src/cache.cpp index 233e0ec..7557dce 100644 --- a/src/cache.cpp +++ b/src/cache.cpp @@ -4,22 +4,16 @@ #include #include "lru_cache.h" +// Provide a thread-safe initialization via function-local static class Cache { private: - static std::unique_ptr>> instance; - Cache() {} public: + // Return a reference to a process-wide cache instance. + // Function-local static ensures thread-safe initialization (C++11+). static LRUCache>& getInstance(size_t max = 10000, std::chrono::milliseconds maxAge = std::chrono::milliseconds(1000 * 60 * 10)) { - if (!instance) { - instance = std::make_unique>>(max, maxAge); - } - return *instance; + static LRUCache> instance(max, maxAge); + return instance; } }; - -// Initialize the static member -std::unique_ptr>> Cache::instance = nullptr; - - diff --git a/src/lru_cache.h b/src/lru_cache.h index 53e3b13..a20c3f5 100644 --- a/src/lru_cache.h +++ b/src/lru_cache.h @@ -3,91 +3,99 @@ #include #include -#include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include -// LRUCache template class template class LRUCache { -private: - using Timestamp = std::chrono::steady_clock::time_point; - - struct CacheItem { - Value value; - Timestamp timestamp; - }; - - std::list> itemList; - std::unordered_map itemMap; - size_t maxSize; - std::chrono::milliseconds maxAge; - - void moveToFront(typename decltype(itemList)::iterator it) { - itemList.splice(itemList.begin(), itemList, it); - } - - void evict() { - while (itemList.size() > maxSize || (maxAge.count() > 0 && !itemList.empty() && - std::chrono::duration_cast(std::chrono::steady_clock::now() - itemList.back().second.timestamp).count() > maxAge.count())) { - itemMap.erase(itemList.back().first); - itemList.pop_back(); - } - } - public: + using Clock = std::chrono::steady_clock; + using Timestamp = Clock::time_point; + LRUCache(size_t maxSize, std::chrono::milliseconds maxAge = std::chrono::milliseconds(0)) - : maxSize(maxSize), maxAge(maxAge) {} + : maxSize_(std::max(1, maxSize)), maxAge_(maxAge) {} + // Put (copy) void put(const Key& key, const Value& value) { - auto now = std::chrono::steady_clock::now(); - auto it = itemMap.find(key); - if (it != itemMap.end()) { - it->second->second.value = value; - it->second->second.timestamp = now; - moveToFront(it->second); - } else { - itemList.push_front({ key, { value, now } }); - itemMap[key] = itemList.begin(); - } - evict(); + put_impl(key, value); } - Value get(const Key& key) { - auto it = itemMap.find(key); - if (it == itemMap.end()) { - throw std::runtime_error("Key not found"); - } - moveToFront(it->second); + // Put (move) + void put(Key&& key, Value&& value) { + put_impl(std::move(key), std::move(value)); + } + + // Get: returns optional to avoid throwing for missing keys + std::optional get(const Key& key) { + std::lock_guard lock(mutex_); + auto it = map_.find(key); + if (it == map_.end()) return std::nullopt; + // move node to front + item_list_.splice(item_list_.begin(), item_list_, it->second); + it->second->second.timestamp = Clock::now(); return it->second->second.value; } bool contains(const Key& key) const { - return itemMap.find(key) != itemMap.end(); + std::lock_guard lock(mutex_); + return map_.find(key) != map_.end(); } size_t size() const { - return itemMap.size(); + std::lock_guard lock(mutex_); + return map_.size(); + } + + void clear() { + std::lock_guard lock(mutex_); + item_list_.clear(); + map_.clear(); } -}; -// Cache singleton class -// class Cache { -// private: -// static std::unique_ptr>> instance; +private: + struct CacheItem { + Value value; + Timestamp timestamp; + }; -// Cache() {} + using ListIt = typename std::list>::iterator; -// public: -// static LRUCache>& getInstance(size_t max = 10000, std::chrono::milliseconds maxAge = std::chrono::milliseconds(1000 * 60 * 10)) { -// if (!instance) { -// instance = std::make_unique>>(max, maxAge); -// } -// return *instance; -// } -// }; + template + void put_impl(K&& key, V&& value) { + std::lock_guard lock(mutex_); + auto now = Clock::now(); + auto it = map_.find(key); + if (it != map_.end()) { + // Update existing + it->second->second.value = std::forward(value); + it->second->second.timestamp = now; + item_list_.splice(item_list_.begin(), item_list_, it->second); + } else { + // Insert new + item_list_.emplace_front(std::forward(key), CacheItem{ std::forward(value), now }); + map_[item_list_.begin()->first] = item_list_.begin(); + evict_if_needed(); + } + } + + void evict_if_needed() { + while (item_list_.size() > maxSize_ || + (maxAge_.count() > 0 && !item_list_.empty() && + std::chrono::duration_cast(Clock::now() - item_list_.back().second.timestamp) > maxAge_)) { + map_.erase(item_list_.back().first); + item_list_.pop_back(); + } + } + + mutable std::mutex mutex_; + std::list> item_list_; + std::unordered_map map_; + size_t maxSize_; + std::chrono::milliseconds maxAge_; +}; #endif // LRU_CACHE_H diff --git a/src/not_implemented_exception.h b/src/not_implemented_exception.h index 5bf4cbc..7075a21 100644 --- a/src/not_implemented_exception.h +++ b/src/not_implemented_exception.h @@ -1,41 +1,18 @@ #ifndef NOT_IMPLEMENTED_EXCEPTION_H #define NOT_IMPLEMENTED_EXCEPTION_H +#include #include -class NotImplementedException : public std::logic_error -{ -private: - - std::string _text; - - NotImplementedException(const char* message, const char* function) - : - std::logic_error("Not Implemented") - { - _text = message; - _text += " : "; - _text += function; - }; - +class NotImplementedException : public std::logic_error { public: + explicit NotImplementedException(const std::string& message = "Not Implemented") + : std::logic_error(message) {} - NotImplementedException() - : - NotImplementedException("Not Implememented", __FUNCTION__) - { - } - - NotImplementedException(const char* message) - : - NotImplementedException(message, __FUNCTION__) - { - } - - virtual const char *what() const throw() - { - return _text.c_str(); + // Use noexcept-qualified what() override for compatibility and clarity + const char* what() const noexcept override { + return std::logic_error::what(); } }; -#endif //NOT_IMPLEMENTED_EXCEPTION_H \ No newline at end of file +#endif // NOT_IMPLEMENTED_EXCEPTION_H diff --git a/src/priority_queue.h b/src/priority_queue.h index 2dbfc36..580062e 100644 --- a/src/priority_queue.h +++ b/src/priority_queue.h @@ -2,27 +2,23 @@ #define PRIORITY_QUEUE_H #include -#include -#include -#include #include -#include #include #include -template +template> class PriorityQueue { private: std::vector elements; - using Compare = std::function; Compare compareFn; struct HeapComparator { Compare compareFn; - bool operator()(const T& lhs, const T& rhs) const { - // Flip operands so compareFn's "smaller" element rises to the top. - return compareFn(rhs, lhs); + // std::make_heap expects a comparator where the top is the "largest" element; + // when using std::less it will create a max-heap. We want compareFn to describe + // the ordering so use compareFn on lhs, rhs directly. + return compareFn(lhs, rhs); } }; @@ -31,9 +27,9 @@ class PriorityQueue { } public: - PriorityQueue(std::vector elements, Compare compareFn) - : elements(std::move(elements)), compareFn(std::move(compareFn)) { - std::make_heap(this->elements.begin(), this->elements.end(), heapComparator()); + PriorityQueue(std::vector initElements = {}, Compare cmp = Compare()) + : elements(std::move(initElements)), compareFn(std::move(cmp)) { + std::make_heap(elements.begin(), elements.end(), heapComparator()); } void push(const T& element) { @@ -53,9 +49,17 @@ class PriorityQueue { return element; } + const T& top() const { + return elements.front(); + } + bool isEmpty() const { return elements.empty(); } + + size_t size() const { + return elements.size(); + } }; -#endif //PRIORITY_QUEUE_H +#endif // PRIORITY_QUEUE_H diff --git a/src/simple_hnsw.h b/src/simple_hnsw.h index 9bd87f1..60a915c 100644 --- a/src/simple_hnsw.h +++ b/src/simple_hnsw.h @@ -6,119 +6,44 @@ #include #include #include -// #include +#include +#include +#include #include #include "not_implemented_exception.h" -using namespace std; using json = nlohmann::json; +// Public types using Vector = std::vector; using Distance = double; -using NodeIndex = unsigned long; +using NodeIndex = std::size_t; +static constexpr NodeIndex INVALID_NODE = std::numeric_limits::max(); struct LayerNode { Vector vector; std::vector connections; - NodeIndex layerBelow; + NodeIndex layerBelow = INVALID_NODE; }; using Layer = std::vector; -double EuclideanDistance(const Vector& a, const Vector& b) { +// Use squared Euclidean distance for comparisons to avoid sqrt costs +inline double squaredEuclideanDistance(const Vector& a, const Vector& b) { if (a.size() != b.size()) { throw std::invalid_argument("Vectors must have the same length"); } - double sum = 0.0; for (size_t i = 0; i < a.size(); ++i) { double diff = a[i] - b[i]; sum += diff * diff; } - return std::sqrt(sum); + return sum; } -int getInsertLayer(int L, double mL) { - return std::min(static_cast(-std::floor(std::log(std::rand() / static_cast(RAND_MAX)) * mL)), L - 1); -} - -std::vector> _searchLayer( - const Layer& graph, NodeIndex entry, const Vector& query, int ef) { - if (entry < 0 || entry >= graph.size()) { - throw std::invalid_argument("Invalid entry index"); - } - - if (ef <= 0) { - return {}; - } - - const LayerNode& graphEntry = graph[entry]; - Distance entryDist = EuclideanDistance(graphEntry.vector, query); - std::vector visited(graph.size(), false); - visited[entry] = true; - - auto maxHeapComp = [](const std::pair& lhs, const std::pair& rhs) { - return lhs.first < rhs.first; - }; - - std::vector> best; - best.reserve(static_cast(ef)); - - auto emplaceBest = [&](const std::pair& candidate) { - if (best.size() < static_cast(ef)) { - best.push_back(candidate); - std::push_heap(best.begin(), best.end(), maxHeapComp); - } else if (candidate.first < best.front().first) { - std::pop_heap(best.begin(), best.end(), maxHeapComp); - best.back() = candidate; - std::push_heap(best.begin(), best.end(), maxHeapComp); - } - }; - - emplaceBest({ entryDist, entry }); - - auto minHeapComp = [](const std::pair& lhs, const std::pair& rhs) { - return lhs.first > rhs.first; - }; - std::priority_queue< - std::pair, - std::vector>, - decltype(minHeapComp)> candidates(minHeapComp); - - candidates.emplace(entryDist, entry); - - while (!candidates.empty()) { - auto current = candidates.top(); - candidates.pop(); - - if (!best.empty() && current.first > best.front().first) { - continue; - } - - const LayerNode& graphCurrent = graph[current.second]; - for (NodeIndex neighbor : graphCurrent.connections) { - if (neighbor >= graph.size()) { - continue; - } - if (visited[neighbor]) { - continue; - } - visited[neighbor] = true; - - const LayerNode& graphNeighbor = graph[neighbor]; - Distance dist = EuclideanDistance(graphNeighbor.vector, query); - - if (best.size() < static_cast(ef) || dist < best.front().first) { - candidates.emplace(dist, neighbor); - emplaceBest({ dist, neighbor }); - } - } - } - - std::sort(best.begin(), best.end(), [](const auto& lhs, const auto& rhs) { - return lhs.first < rhs.first; - }); - return best; +// Return actual Euclidean distance when caller needs it +inline double euclideanDistance(const Vector& a, const Vector& b) { + return std::sqrt(squaredEuclideanDistance(a, b)); } class SimpleHNSWIndex { @@ -129,12 +54,15 @@ class SimpleHNSWIndex { int maxConnections; std::vector index; + // RNG for deterministic/controllable layer selection + std::mt19937 rng; + static bool containsConnection(const LayerNode& node, NodeIndex target) { return std::find(node.connections.begin(), node.connections.end(), target) != node.connections.end(); } void pruneNodeConnections(Layer& layer, NodeIndex nodeIndex) { - if (nodeIndex >= layer.size()) { + if (nodeIndex == INVALID_NODE || nodeIndex >= layer.size()) { return; } @@ -143,14 +71,14 @@ class SimpleHNSWIndex { return; } - std::vector> scored; + std::vector> scored; scored.reserve(node.connections.size()); for (NodeIndex connection : node.connections) { if (connection == nodeIndex || connection >= layer.size()) { continue; } const LayerNode& neighbor = layer[connection]; - scored.emplace_back(EuclideanDistance(node.vector, neighbor.vector), connection); + scored.emplace_back(squaredEuclideanDistance(node.vector, neighbor.vector), connection); } if (scored.empty()) { @@ -160,13 +88,8 @@ class SimpleHNSWIndex { size_t target = static_cast(maxConnections); if (scored.size() > target) { - using ScoreDifferenceType = std::vector>::difference_type; - auto nth = scored.begin() + static_cast(target); - std::nth_element( - scored.begin(), - nth, - scored.end(), - [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + auto nth = scored.begin() + static_cast(target); + std::nth_element(scored.begin(), nth, scored.end(), [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); scored.resize(target); } @@ -183,17 +106,100 @@ class SimpleHNSWIndex { } } + // Internal search in a single layer. Returns vector of (squaredDistance, NodeIndex) sorted ascending. + std::vector> _searchLayer(const Layer& graph, NodeIndex entry, const Vector& query, int ef) const { + if (graph.empty()) return {}; + if (entry == INVALID_NODE || entry >= graph.size()) { + throw std::invalid_argument("Invalid entry index"); + } + if (ef <= 0) return {}; + + // compute entry distance + const LayerNode& graphEntry = graph[entry]; + double entryDist = squaredEuclideanDistance(graphEntry.vector, query); + + std::vector visited(graph.size(), false); + visited[entry] = true; + + auto maxHeapComp = [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first < rhs.first; // max-heap by distance + }; + + std::vector> best; + best.reserve(static_cast(ef)); + + auto emplaceBest = [&](const std::pair& candidate) { + if (best.size() < static_cast(ef)) { + best.push_back(candidate); + std::push_heap(best.begin(), best.end(), maxHeapComp); + } else if (candidate.first < best.front().first) { + std::pop_heap(best.begin(), best.end(), maxHeapComp); + best.back() = candidate; + std::push_heap(best.begin(), best.end(), maxHeapComp); + } + }; + + emplaceBest({ entryDist, entry }); + + auto minHeapComp = [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first > rhs.first; // min-heap by distance + }; + std::priority_queue, std::vector>, decltype(minHeapComp)> candidates(minHeapComp); + + candidates.emplace(entryDist, entry); + + while (!candidates.empty()) { + auto current = candidates.top(); + candidates.pop(); + + if (!best.empty() && current.first > best.front().first) { + continue; + } + + const LayerNode& graphCurrent = graph[current.second]; + for (NodeIndex neighbor : graphCurrent.connections) { + if (neighbor >= graph.size()) { + continue; + } + if (visited[neighbor]) continue; + visited[neighbor] = true; + + const LayerNode& graphNeighbor = graph[neighbor]; + double dist = squaredEuclideanDistance(graphNeighbor.vector, query); + + if (best.size() < static_cast(ef) || dist < best.front().first) { + candidates.emplace(dist, neighbor); + emplaceBest({ dist, neighbor }); + } + } + } + + std::sort(best.begin(), best.end(), [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + return best; + } + public: - SimpleHNSWIndex(int L = 5, double mL = 0.62, int efc = 10, int maxConnectionsPerLayer = 16) - : L(L), - mL(mL), - efc(efc), - maxConnections(std::max(1, maxConnectionsPerLayer)), - index(L) {} - - void setIndex(const std::vector& index) { - this->index = index; - for (auto& layer : this->index) { + SimpleHNSWIndex(int L_ = 5, double mL_ = 0.62, int efc_ = 10, int maxConnectionsPerLayer = 16, uint32_t seed = std::random_device{}()) + : L(L_), mL(mL_), efc(efc_), maxConnections(std::max(1, maxConnectionsPerLayer)), index(), rng(seed) { + if (L <= 0) throw std::invalid_argument("L must be positive"); + index.resize(static_cast(L)); + } + + // Deterministic or seeded layer assignment using uniform real and negative log: + int getInsertLayer() { + std::uniform_real_distribution dist(0.0, 1.0); + double u = dist(rng); + // ensure u in (0,1], avoid log(0) + if (u <= 0.0) u = std::numeric_limits::min(); + int layer = static_cast(std::floor(-std::log(u) * mL)); + if (layer < 0) layer = 0; + if (layer >= L) layer = L - 1; + return layer; + } + + void setIndex(const std::vector& newIndex) { + index = newIndex; + for (auto& layer : index) { for (size_t i = 0; i < layer.size(); ++i) { pruneNodeConnections(layer, static_cast(i)); } @@ -201,127 +207,167 @@ class SimpleHNSWIndex { } void insert(const Vector& vec) { - int l = getInsertLayer(L, mL); - int startV = 0; - + int l = getInsertLayer(); + NodeIndex startV = 0; + // If layer 0 is empty, insert into all intermediate empty layers appropriately for (int n = 0; n < L; ++n) { - Layer& graph = index[n]; + Layer& graph = index[static_cast(n)]; if (graph.empty()) { - graph.push_back({ vec, {}, n < L - 1 ? static_cast(index[n + 1].size()) : static_cast(-1) }); + LayerNode ln; + ln.vector = vec; + ln.layerBelow = (n < L - 1 ? static_cast(index[static_cast(n + 1)].size()) : INVALID_NODE); + graph.push_back(std::move(ln)); continue; } if (n < l) { - auto searchLayerResult = _searchLayer(graph, startV, vec, 1); - startV = searchLayerResult[0].second; + auto res = _searchLayer(graph, startV, vec, 1); + if (res.empty()) { + startV = 0; + } else { + startV = res[0].second; + } } else { - LayerNode node = { vec, {}, n < L - 1 ? static_cast(index[n + 1].size()) : static_cast(-1) }; + LayerNode node; + node.vector = vec; + node.layerBelow = (n < L - 1 ? static_cast(index[static_cast(n + 1)].size()) : INVALID_NODE); auto nns = _searchLayer(graph, startV, vec, efc); std::vector selectedNeighbors; selectedNeighbors.reserve(std::min(static_cast(maxConnections), nns.size())); for (const auto& nn : nns) { - if (selectedNeighbors.size() >= static_cast(maxConnections)) { - break; - } + if (selectedNeighbors.size() >= static_cast(maxConnections)) break; selectedNeighbors.push_back(nn.second); } node.connections = selectedNeighbors; NodeIndex newIndex = static_cast(graph.size()); - graph.push_back(node); + graph.push_back(std::move(node)); + pruneNodeConnections(graph, newIndex); for (NodeIndex neighborIndex : selectedNeighbors) { - if (neighborIndex >= graph.size()) { - continue; - } + if (neighborIndex >= graph.size()) continue; LayerNode& neighborNode = graph[neighborIndex]; if (!containsConnection(neighborNode, newIndex)) { neighborNode.connections.push_back(newIndex); } pruneNodeConnections(graph, neighborIndex); - if (!containsConnection(neighborNode, newIndex)) { - auto& newConnections = graph[newIndex].connections; - newConnections.erase(std::remove(newConnections.begin(), newConnections.end(), neighborIndex), newConnections.end()); - } + + // ensure symmetric-ish relationship: remove neighbor from new node if exceeded + auto& newConnections = graph[newIndex].connections; + newConnections.erase(std::remove(newConnections.begin(), newConnections.end(), neighborIndex), newConnections.end()); } pruneNodeConnections(graph, newIndex); - startV = graph[startV].layerBelow; + + // prepare startV for next layer down + if (startV < graph.size()) { + startV = graph[startV].layerBelow; + } else { + startV = INVALID_NODE; + } } } } - std::vector> search(const Vector& query, int ef = 1) { - if (index[0].empty()) { + // Public search: returns actual Euclidean distances + node indices, limited by ef parameter + std::vector> search(const Vector& query, int ef = 1) const { + if (index.empty() || index[0].empty()) { return {}; } - int bestV = 0; - for (const auto& graph : index) { + NodeIndex bestV = 0; + for (const Layer& graph : index) { auto searchLayer = _searchLayer(graph, bestV, query, ef); + if (searchLayer.empty()) continue; bestV = searchLayer[0].second; - if (graph[bestV].layerBelow == -1) { - return _searchLayer(graph, bestV, query, ef); + if (graph[bestV].layerBelow == INVALID_NODE) { + // convert squared distances to actual distances + std::vector> out; + out.reserve(searchLayer.size()); + for (const auto& p : _searchLayer(graph, bestV, query, ef)) { + out.emplace_back(std::sqrt(p.first), p.second); + } + return out; } bestV = graph[bestV].layerBelow; + if (bestV == INVALID_NODE) break; } return {}; } + // JSON serialization with versioning and basic validation std::string toJSON() const { - json jsonData; - jsonData["L"] = L; - jsonData["mL"] = mL; - jsonData["efc"] = efc; - jsonData["maxConnections"] = maxConnections; - + json j; + j["version"] = 1; + j["L"] = L; + j["mL"] = mL; + j["efc"] = efc; + j["maxConnections"] = maxConnections; + + j["index"] = json::array(); for (const auto& layer : index) { - json layerData; + json layerData = json::array(); for (const auto& node : layer) { json nodeData; nodeData["vector"] = node.vector; nodeData["connections"] = node.connections; - nodeData["layerBelow"] = node.layerBelow; + // store layerBelow as -1 if INVALID_NODE for compatibility + nodeData["layerBelow"] = (node.layerBelow == INVALID_NODE ? -1 : static_cast(node.layerBelow)); layerData.push_back(nodeData); } - jsonData["index"].push_back(layerData); + j["index"].push_back(layerData); } - - return jsonData.dump(); + return j.dump(); } - static SimpleHNSWIndex fromJSON(const std::string& json) { - auto jsonData = json::parse(json); + static SimpleHNSWIndex fromJSON(const std::string& str) { + json j; + try { + j = json::parse(str); + } catch (const std::exception& e) { + throw std::invalid_argument(std::string("Invalid JSON: ") + e.what()); + } + + if (!j.contains("L") || !j.contains("mL") || !j.contains("index")) { + throw std::invalid_argument("Missing required fields in JSON"); + } + + int L = j["L"].get(); + double mL = j["mL"].get(); + int efc = j.value("efc", 10); + int maxConnections = j.value("maxConnections", 16); - int L = jsonData["L"]; - double mL = jsonData["mL"]; - int efc = jsonData["efc"]; - int maxConnections = jsonData.value("maxConnections", 16); - std::vector> index(L); + if (L <= 0) throw std::invalid_argument("Invalid L in JSON"); + + std::vector indexVec(static_cast(L)); + if (!j["index"].is_array() || j["index"].size() != static_cast(L)) { + throw std::invalid_argument("Index layer count mismatch"); + } for (int i = 0; i < L; ++i) { - for (const auto& nodeData : jsonData["index"][i]) { + for (const auto& nodeData : j["index"][i]) { LayerNode node; - node.vector = nodeData["vector"].get>(); - node.connections = nodeData["connections"].get>(); - node.layerBelow = nodeData["layerBelow"]; - index[i].push_back(node); + node.vector = nodeData.at("vector").get>(); + node.connections = nodeData.at("connections").get>(); + long long layerBelowRaw = nodeData.value("layerBelow", -1); + node.layerBelow = (layerBelowRaw < 0 ? INVALID_NODE : static_cast(layerBelowRaw)); + indexVec[static_cast(i)].push_back(std::move(node)); } } - SimpleHNSWIndex hnsw(L, mL, efc, maxConnections); - hnsw.setIndex(index); - return hnsw; + SimpleHNSWIndex h(L, mL, efc, maxConnections); + h.setIndex(indexVec); + return h; } std::vector toBinary() const { throw NotImplementedException("Binary serialization is not implemented yet."); } - static SimpleHNSWIndex fromBinary(const std::vector& binary) { + static SimpleHNSWIndex fromBinary(const std::vector& /*binary*/) { throw NotImplementedException("Binary deserialization is not implemented yet."); } }; -#endif //SIMPLE_HNSW_H +#endif // SIMPLE_HNSW_H diff --git a/src/util.cpp b/src/util.cpp index 0b7ca1d..ca5c1a9 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -3,15 +3,16 @@ #include #include #include +#include // Function to calculate cosine similarity -double cosineSimilarity(const std::vector& vecA, const std::vector& vecB, int precision = 6) { - // Check if both vectors have the same length +// Returns double in range [-1, 1]. If either vector has zero magnitude, returns 0.0 +// If rounding is required, caller can round; lightweight rounding option provided. +double cosineSimilarity(const std::vector& vecA, const std::vector& vecB, int precision = -1) { if (vecA.size() != vecB.size()) { throw std::invalid_argument("Vectors must have the same length"); } - // Compute dot product and magnitudes double dotProduct = 0.0; double magnitudeA = 0.0; double magnitudeB = 0.0; @@ -25,17 +26,21 @@ double cosineSimilarity(const std::vector& vecA, const std::vector 1.0) cosineSim = 1.0; + if (cosineSim < -1.0) cosineSim = -1.0; + if (precision < 0) { + return cosineSim; + } + + double power = std::pow(10.0, precision); + return std::round(cosineSim * power) / power; +} diff --git a/src/wasm_bindings.cpp b/src/wasm_bindings.cpp index ebd2991..b724ba2 100644 --- a/src/wasm_bindings.cpp +++ b/src/wasm_bindings.cpp @@ -4,35 +4,35 @@ using namespace emscripten; -// Helper functions to convert between JavaScript and C++ types -std::vector convertJSArrayToVector(const val& jsArray) { +// Helper functions +static std::vector convertJSArrayToVector(const val& jsArray) { std::vector vec; unsigned int length = jsArray["length"].as(); + vec.reserve(length); for (unsigned int i = 0; i < length; ++i) { vec.push_back(jsArray[i].as()); } return vec; } -val convertResultsToJS(const std::vector>& results) { +static val convertResultsToJS(const std::vector>& results) { val jsArray = val::array(); for (size_t i = 0; i < results.size(); ++i) { val item = val::object(); item.set("distance", results[i].first); - item.set("nodeIndex", results[i].second); + item.set("nodeIndex", static_cast(results[i].second)); jsArray.set(i, item); } return jsArray; } -// Wrapper class for SimpleHNSWIndex with JavaScript-friendly interface class SimpleHNSWIndexWrapper { private: SimpleHNSWIndex index; public: - SimpleHNSWIndexWrapper(int L = 5, double mL = 0.62, int efc = 10, int maxConnections = 16) - : index(L, mL, efc, maxConnections) {} + SimpleHNSWIndexWrapper(int L = 5, double mL = 0.62, int efc = 10, int maxConnections = 16, unsigned int seed = 0u) + : index(L, mL, efc, maxConnections, seed) {} void insert(const val& jsVector) { Vector vec = convertJSArrayToVector(jsVector); @@ -49,21 +49,18 @@ class SimpleHNSWIndexWrapper { return index.toJSON(); } - static SimpleHNSWIndexWrapper fromJSON(const std::string& json) { - SimpleHNSWIndex loadedIndex = SimpleHNSWIndex::fromJSON(json); + static SimpleHNSWIndexWrapper fromJSON(const std::string& jsonStr) { + SimpleHNSWIndex loadedIndex = SimpleHNSWIndex::fromJSON(jsonStr); SimpleHNSWIndexWrapper wrapper; - wrapper.index = loadedIndex; + wrapper = SimpleHNSWIndexWrapper(); + wrapper.index = std::move(loadedIndex); return wrapper; } }; EMSCRIPTEN_BINDINGS(simple_hnsw) { class_("SimpleHNSWIndex") - .constructor<>() - .constructor() - .constructor() - .constructor() - .constructor() + .constructor() .function("insert", &SimpleHNSWIndexWrapper::insert) .function("search", &SimpleHNSWIndexWrapper::search) .function("toJSON", &SimpleHNSWIndexWrapper::toJSON)