Program Listing for File hnsw.h

Return to documentation for file (include/n2/hnsw.h)

// Copyright 2017 Kakao Corp. <http://www.kakaocorp.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <omp.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "hnsw_build.h"
#include "hnsw_model.h"
#include "hnsw_search.h"

namespace n2 {

/// Facade over the HNSW index: owns a builder, the (shared, immutable) model
/// and the searcher objects, and forwards build / persistence / query calls
/// to them.
///
/// The single-query Search* methods dereference `searcher_` without a null
/// check. NOTE(review): `searcher_` is presumably created by
/// InitSearcherAndSearcherPool_() once a model exists (after Fit/Build or
/// LoadModel) -- confirm in the implementation file before relying on it.
///
/// The Batch* methods may append to `searcher_pool_`, so they mutate the
/// object even though they are logically read-only queries.
class Hnsw {
public:
    Hnsw();

    /// @param dim    dimensionality of the vectors to be indexed.
    /// @param metric name of the distance metric (defaults to "angular").
    Hnsw(int dim, std::string metric="angular");
    Hnsw(const Hnsw& other);
    Hnsw(Hnsw&& other) noexcept;
    ~Hnsw();

    Hnsw& operator=(const Hnsw& other);
    Hnsw& operator=(Hnsw&& other) noexcept;

    // ---------------------------------------------------------------- Build

    /// Adds one data vector to the set that will be indexed.
    void AddData(const std::vector<float>& data);

    /// Applies build options given as (name, value) string pairs.
    void SetConfigs(const std::vector<std::pair<std::string, std::string>>& configs);

    /// Builds the index with the given options.
    /// NOTE(review): the -1 defaults presumably mean "keep the currently
    /// configured value" -- the handling lives in the implementation file.
    void Build(int m=-1, int max_m0=-1, int ef_construction=-1,
               int n_threads=-1, float mult=-1,
               NeighborSelectingPolicy neighbor_selecting=NeighborSelectingPolicy::HEURISTIC,
               GraphPostProcessing graph_merging=GraphPostProcessing::SKIP,
               bool ensure_k=false);

    /// Builds the index using the options that are currently configured.
    void Fit();

    // ---------------------------------------------------------------- Model

    /// Writes the model to `fname`.
    /// NOTE(review): the bool result presumably signals success -- confirm.
    bool SaveModel(const std::string& fname) const;

    /// Reads a model from `fname`.
    /// @param use_mmap presumably selects memory-mapped loading -- confirm.
    bool LoadModel(const std::string& fname, const bool use_mmap=true);

    /// Releases the currently held model.
    void UnloadModel();

    // --------------------------------------------------------------- Search
    // Each query comes in two flavours: one filling `result` with ids only,
    // one filling it with (id, distance) pairs. All of them forward to the
    // single-thread searcher together with the stored `ensure_k_` flag.

    /// Nearest-neighbour query for an explicit vector; ids only.
    inline void SearchByVector(const std::vector<float>& qvec, size_t k, size_t ef_search,
                               std::vector<int>& result) {
        searcher_->SearchByVector(qvec, k, ef_search, ensure_k_, result);
    }

    /// Nearest-neighbour query for an explicit vector; (id, distance) pairs.
    inline void SearchByVector(const std::vector<float>& qvec, size_t k,
                               size_t ef_search,
                               std::vector<std::pair<int, float>>& result) {
        searcher_->SearchByVector(qvec, k, ef_search, ensure_k_, result);
    }

    /// Nearest-neighbour query for an already indexed item; ids only.
    inline void SearchById(int id, size_t k, size_t ef_search, std::vector<int>& result) {
        searcher_->SearchById(id, k, ef_search, ensure_k_, result);
    }

    /// Nearest-neighbour query for an already indexed item; (id, distance) pairs.
    inline void SearchById(int id, size_t k, size_t ef_search,
                           std::vector<std::pair<int, float>>& result) {
        searcher_->SearchById(id, k, ef_search, ensure_k_, result);
    }

    /// Runs one vector query per element of `qvecs` on up to `n_threads`
    /// OpenMP threads; `results[i]` receives the answer for `qvecs[i]`.
    /// An `n_threads` of 0 is treated as 1.
    inline void BatchSearchByVectors(const std::vector<std::vector<float>>& qvecs, size_t k,
                                     size_t ef_search, size_t n_threads, std::vector<std::vector<int>>& results) {
        BatchSearchByVectors_(qvecs, k, ef_search, n_threads, results);
    }

    /// Same as above, with (id, distance) pairs as results.
    inline void BatchSearchByVectors(const std::vector<std::vector<float>>& qvecs, size_t k,
                                     size_t ef_search, size_t n_threads,
                                     std::vector<std::vector<std::pair<int, float>>>& results) {
        BatchSearchByVectors_(qvecs, k, ef_search, n_threads, results);
    }

    /// Runs one id query per element of `ids` on up to `n_threads` OpenMP
    /// threads; `results[i]` receives the answer for `ids[i]`.
    /// An `n_threads` of 0 is treated as 1.
    /// NOTE(review): `ids` is taken by value (one copy per call); the
    /// signature is kept as-is so existing callers are unaffected.
    inline void BatchSearchByIds(const std::vector<int> ids, size_t k, size_t ef_search, size_t n_threads,
                                 std::vector<std::vector<int>>& results) {
        BatchSearchByIds_(ids, k, ef_search, n_threads, results);
    }

    /// Same as above, with (id, distance) pairs as results.
    inline void BatchSearchByIds(const std::vector<int> ids, size_t k, size_t ef_search, size_t n_threads,
                                 std::vector<std::vector<std::pair<int, float>>>& results) {
        BatchSearchByIds_(ids, k, ef_search, n_threads, results);
    }

    // ---------------------------------------------------------- Build (misc)

    /// Prints the degree distribution of the built graph.
    void PrintDegreeDist() const;

    /// Prints the current build configuration.
    void PrintConfigs() const;

private:
    void InitSearcherAndSearcherPool_();

    /// Makes sure `searcher_pool_` holds one searcher per worker thread and
    /// returns the worker count to hand to OpenMP.
    ///
    /// A request for 0 threads is promoted to 1: the `num_threads` clause
    /// requires a positive value, and with 0 the pool would not be grown, so
    /// indexing it by thread number inside the parallel region could run past
    /// its end.
    size_t PrepareSearcherPool_(size_t n_threads) {
        const size_t n_workers = n_threads > 0 ? n_threads : 1;
        while (searcher_pool_.size() < n_workers) {
            searcher_pool_.push_back(HnswSearch::GenerateSearcher(model_, data_dim_, metric_));
        }
        return n_workers;
    }

    /// Shared implementation of both BatchSearchByVectors overloads;
    /// `ResultType` is a vector with one result container per query.
    template<typename ResultType>
    void BatchSearchByVectors_(const std::vector<std::vector<float>>& qvecs, size_t k,
                               size_t ef_search, size_t n_threads, ResultType& results) {
        results.resize(qvecs.size());
        const size_t n_workers = PrepareSearcherPool_(n_threads);

        // The team has at most n_workers threads, so every thread number is a
        // valid pool index and each thread works with its own searcher.
        #pragma omp parallel num_threads(n_workers)
        {
            #pragma omp for schedule(runtime)
            for (size_t i = 0; i < qvecs.size(); ++i) {
                auto& worker = searcher_pool_[omp_get_thread_num()];
                worker->SearchByVector(qvecs[i], k, ef_search, ensure_k_, results[i]);
            }
        }
    }

    /// Shared implementation of both BatchSearchByIds overloads;
    /// `ResultType` is a vector with one result container per query.
    /// `ids` is taken by reference so the list is not copied a second time.
    template<typename ResultType>
    void BatchSearchByIds_(const std::vector<int>& ids, size_t k, size_t ef_search, size_t n_threads,
                           ResultType& results) {
        results.resize(ids.size());
        const size_t n_workers = PrepareSearcherPool_(n_threads);

        // The team has at most n_workers threads, so every thread number is a
        // valid pool index and each thread works with its own searcher.
        #pragma omp parallel num_threads(n_workers)
        {
            #pragma omp for schedule(runtime)
            for (size_t i = 0; i < ids.size(); ++i) {
                auto& worker = searcher_pool_[omp_get_thread_num()];
                worker->SearchById(ids[i], k, ef_search, ensure_k_, results[i]);
            }
        }
    }

private:
    std::unique_ptr<HnswBuild> builder_;
    std::shared_ptr<const HnswModel> model_;
    std::shared_ptr<HnswSearch> searcher_;                      // for single-thread search
    std::vector<std::shared_ptr<HnswSearch>> searcher_pool_;    // for multi-threads batch search

    size_t data_dim_;          // passed to HnswSearch::GenerateSearcher for new pool entries
    DistanceKind metric_;      // passed to HnswSearch::GenerateSearcher for new pool entries
    bool ensure_k_ = false;    // forwarded unchanged to every search call
};

} // namespace n2