Program Listing for File hnsw_search_impl.h
↰ Return to documentation for file (include/n2/hnsw_search_impl.h)
// Copyright 2017 Kakao Corp. <http://www.kakaocorp.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "common.h"
#include "distance.h"
#include "hnsw_model.h"
#include "hnsw_search.h"
#include "min_heap.h"
#include "visited_list.h"
namespace n2 {
// HNSW search implementation parameterized on the distance functor type.
//
// Implements the HnswSearch interface on top of a shared, read-only HnswModel.
// An instance owns several scratch buffers (visited_list_, normalized_vec_,
// ensure_k_path_) that the search routines read and write.
// NOTE(review): because of that per-instance scratch state, a single instance
// is presumably not safe to use from several threads at once - confirm.
template<typename DistFuncType>
class HnswSearchImpl : public HnswSearch {
public:
    // Builds a searcher over `model`. `data_dim` and `metric` presumably describe
    // the dimensionality and distance kind of the indexed vectors - the
    // definition lives outside this header section.
    HnswSearchImpl(std::shared_ptr<const HnswModel> model, size_t data_dim, DistanceKind metric);

    // Searches for up to `k` neighbors of `qvec`; `result` receives node ids only.
    void SearchByVector(const std::vector<float>& qvec, size_t k, int ef_search, bool ensure_k,
                        std::vector<int>& result) override;
    // Same as above, but `result` receives pairs (presumably (id, distance)).
    void SearchByVector(const std::vector<float>& qvec, size_t k, int ef_search, bool ensure_k,
                        std::vector<std::pair<int, float>>& result) override;
    // Searches for up to `k` neighbors of the stored item `id`; ids only.
    void SearchById(int id, size_t k, int ef_search, bool ensure_k,
                    std::vector<int>& result) override;
    // Same as above, but `result` receives pairs (presumably (id, distance)).
    void SearchById(int id, size_t k, int ef_search, bool ensure_k,
                    std::vector<std::pair<int, float>>& result) override;

protected:
    // Shared implementation behind both SearchByVector overloads.
    template<typename ResultType>
    void SearchByVector_(const std::vector<float>& qvec, size_t k, int ef_search, bool ensure_k,
                         ResultType& result);

    // Id-only front end for the base-layer search.
    //
    // With `ensure_k` set, the work is delegated to the pair-returning overload
    // (which walks ensure_k_path_) and the first element of every pair found is
    // appended to `result`. Otherwise a single SearchById_ pass is run directly
    // on `result`.
    inline void CallSearchById_(int cur_node_id, float cur_dist, const float* qraw, size_t k, size_t ef_search,
                                bool ensure_k, std::vector<int>& result) {
        if (ensure_k) {
            std::vector<std::pair<int, float>> tmp_result;
            CallSearchById_(cur_node_id, cur_dist, qraw, k, ef_search, ensure_k, tmp_result);
            // One allocation up front instead of growing while copying ids over.
            result.reserve(result.size() + tmp_result.size());
            for (const auto& p : tmp_result) {
                result.push_back(p.first);
            }
        } else {
            SearchById_(cur_node_id, cur_dist, qraw, k, ef_search, false, result);
        }
    }

    // Pair-returning front end for the base-layer search.
    //
    // With `ensure_k` set, the incoming `cur_node_id` / `cur_dist` are NOT used
    // as the starting point: entries are popped from the back of ensure_k_path_
    // and searched from, one after another, until `result` holds at least `k`
    // elements or the path is exhausted. If ensure_k_path_ is empty on entry,
    // no search is performed at all.
    // Without `ensure_k`, a single SearchById_ pass starts at `cur_node_id`.
    inline void CallSearchById_(int cur_node_id, float cur_dist, const float* qraw, size_t k, size_t ef_search,
                                bool ensure_k, std::vector<std::pair<int, float>>& result) {
        if (ensure_k) {
            while (result.size() < k && !ensure_k_path_.empty()) {
                cur_node_id = ensure_k_path_.back().first;
                cur_dist = ensure_k_path_.back().second;
                ensure_k_path_.pop_back();
                SearchById_(cur_node_id, cur_dist, qraw, k, ef_search, ensure_k, result);
            }
        } else {
            SearchById_(cur_node_id, cur_dist, qraw, k, ef_search, ensure_k, result);
        }
    }

    // Picks the search variant: SearchByIdV1_ when `ef_search` is smaller than
    // `k`, SearchByIdV2_ otherwise. All arguments are forwarded unchanged.
    template<typename ResultType>
    inline void SearchById_(int cur_node_id, float cur_dist, const float* qraw, size_t k, size_t ef_search,
                            bool ensure_k, ResultType& result) {
        if (ef_search < k) {
            SearchByIdV1_(cur_node_id, cur_dist, qraw, k, ef_search, ensure_k, result);
        } else {
            SearchByIdV2_(cur_node_id, cur_dist, qraw, k, ef_search, ensure_k, result);
        }
    }

    // Search variant selected by SearchById_ when ef_search < k (defined elsewhere).
    template<typename ResultType>
    void SearchByIdV1_(int cur_node_id, float cur_dist, const float* qraw, size_t k, size_t ef_search,
                       bool ensure_k, ResultType& result);
    // Search variant selected by SearchById_ when ef_search >= k (defined elsewhere).
    template<typename ResultType>
    void SearchByIdV2_(int cur_node_id, float cur_dist, const float* qraw, size_t k, size_t ef_search,
                       bool ensure_k, ResultType& result);

    // Helpers for the ensure_k mode, one overload per result representation
    // (defined elsewhere; the meaning of the bool return is not visible here).
    bool PrepareEnsureKSearch(int cur_node_id, std::vector<int>& result, IdDistancePairMinHeap& visited_nodes);
    bool PrepareEnsureKSearch(int cur_node_id, std::vector<std::pair<int, float>>& result,
                              IdDistancePairMinHeap& visited_nodes);

    // Fill `result` from the two heaps, one overload per result representation
    // (defined elsewhere).
    void MakeSearchResult(size_t k, IdDistancePairMinHeap& candidates, IdDistancePairMinHeap& visited_nodes,
                          std::vector<int>& result);
    void MakeSearchResult(size_t k, IdDistancePairMinHeap& candidates, IdDistancePairMinHeap& visited_nodes,
                          std::vector<std::pair<int, float>>& result);

protected:
    std::shared_ptr<const HnswModel> model_;
    std::unique_ptr<VisitedList> visited_list_;
    size_t data_dim_;
    DistanceKind metric_;
    DistFuncType dist_func_;

    // preallocated buffer
    std::vector<float> normalized_vec_;
    // Pending (node id, distance) start points consumed from the back by the
    // ensure_k branch of CallSearchById_.
    std::vector<std::pair<int, float>> ensure_k_path_;

    // raw pointer of model
    char* model_higher_level_ = nullptr;
    char* model_level0_ = nullptr;
    char* model_level0_node_base_offset_ = nullptr;
    // Zero-initialized for consistency with the pointers above, so that no
    // member is ever read with an indeterminate value.
    uint64_t memory_per_node_level0_ = 0;
    uint64_t memory_per_node_higher_level_ = 0;
};
// Concrete searcher types: HnswSearchImpl instantiated once per distance
// functor (AngularDistance, L2Distance, DotDistance).
using HnswSearchAngular = HnswSearchImpl<AngularDistance>;
using HnswSearchL2 = HnswSearchImpl<L2Distance>;
using HnswSearchDot = HnswSearchImpl<DotDistance>;
} // namespace n2