// Program listing for file hnsw_model.h
// (documentation for file include/n2/hnsw_model.h)
// Copyright 2017 Kakao Corp. <http://www.kakaocorp.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "common.h"
#include "hnsw_node.h"
#include "mmap.h"
namespace n2 {
// Read-only view over a serialized HNSW index held in one contiguous byte
// buffer (`model_`). Instances are obtained through the static factories
// below, which hand out `std::shared_ptr<const HnswModel>`; the class is
// non-copyable.
//
// Buffer layout, as far as the accessors in this header show:
//   * `model_level0_` points at an array of fixed-size level-0 records, one
//     per node, each `memory_per_node_level0_` bytes wide. A record starts
//     with one `int` (used as an index into the higher-level block), and the
//     level-0 friend area begins `sizeof(int)` bytes into the record.
//   * `model_level0_node_base_offset_` is the base used to locate the float
//     data belonging to a node's record.
//   * `model_higher_level_` points at an array of records, each
//     `memory_per_node_higher_level_` bytes wide.
class HnswModel {
public:
    // Builds a model from in-memory graph nodes. The definition lives outside
    // this header.
    static std::shared_ptr<const HnswModel> GenerateModel(const std::vector<HnswNode*> nodes, int enterpoint_id,
                                                          int max_m, int max_m0, DistanceKind metric, int max_level,
                                                          size_t data_dim);

    // Loads a model from `fname`; `use_mmap` (default true) is forwarded to
    // the private file constructor. The definition lives outside this header.
    static std::shared_ptr<const HnswModel> LoadModelFromFile(const std::string& fname, const bool use_mmap=true);

    ~HnswModel();

    // Writes the model to `fname` and reports the outcome as a bool.
    bool SaveModelToFile(const std::string& fname) const;

    // Non-copyable.
    HnswModel(const HnswModel&) = delete;
    HnswModel& operator=(const HnswModel&) = delete;

    int GetNumNodes() const { return num_nodes_; }
    int GetEnterpointId() const { return enterpoint_id_; }
    int GetMaxLevel() const { return max_level_; }
    // Note: `data_dim_` is stored as size_t and narrowed to int here.
    int GetDataDim() const { return data_dim_; }
    DistanceKind GetMetric() const { return metric_; }

    // Returns a pointer to the float data of node `node_id`.
    // No bounds checking is performed on `node_id`.
    const float* GetData(int node_id) const {
        const char* data_begin = model_level0_node_base_offset_ + node_id * memory_per_node_level0_;
        return reinterpret_cast<const float*>(data_begin);
    }

    // Returns a pointer into the higher-level block for node `node_id` at
    // `level`. The node's level-0 record starts with an int that is used as
    // the node's base index into the higher-level block; the record for
    // `level` sits `level - 1` records after that base.
    // Presumably the pointed-to ints are a friend count followed by friend
    // ids (per the "WithSize" name) -- confirm against the writer.
    // No bounds checking is performed on `node_id` or `level`.
    const int* GetHigherLevelFriendsWithSize(int node_id, int level) const {
        const char* level0_record = model_level0_ + node_id * memory_per_node_level0_;
        const int offset = *reinterpret_cast<const int*>(level0_record);
        const char* friends_begin = model_higher_level_ + (offset+level-1) * memory_per_node_higher_level_;
        return reinterpret_cast<const int*>(friends_begin);
    }

    // Returns a pointer to the level-0 friend area of node `node_id`, which
    // starts right after the leading int of the node's level-0 record.
    // Presumably a friend count followed by friend ids -- confirm against the
    // writer. No bounds checking is performed on `node_id`.
    const int* GetLevel0FriendsWithSize(int node_id) const {
        const char* friends_begin = model_level0_ + node_id * memory_per_node_level0_ + sizeof(int);
        return reinterpret_cast<const int*>(friends_begin);
    }

private:
    HnswModel(const std::vector<HnswNode*> nodes, int enterpoint_id, int max_m, int max_m0, DistanceKind metric,
              int max_level, size_t data_dim);
    HnswModel(const std::string& fname, const bool use_mmap);

    size_t GetConfigSize();
    void SaveConfigToModel();
    void LoadConfigFromModel();

    // Copies the bytes of `val` to `ptr` and returns the position just past
    // them. std::memcpy is used instead of a cast-and-assign so the write is
    // well defined even when `ptr` is not suitably aligned for T.
    template <typename T>
    char* SetValueAndIncPtr(char* ptr, const T& val) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "SetValueAndIncPtr requires a trivially copyable type");
        std::memcpy(ptr, &val, sizeof(T));
        return ptr + sizeof(T);
    }

    // Copies sizeof(T) bytes from `ptr` into `val` and returns the position
    // just past them. std::memcpy is used instead of a cast-and-read so the
    // read is well defined even when `ptr` is not suitably aligned for T.
    template <typename T>
    char* GetValueAndIncPtr(char* ptr, T& val) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "GetValueAndIncPtr requires a trivially copyable type");
        std::memcpy(&val, ptr, sizeof(T));
        return ptr + sizeof(T);
    }

// NOTE(review): these members carry the private-style trailing underscore but
// are declared public; kept as-is because code outside this header may access
// them directly -- confirm before tightening.
public:
    int enterpoint_id_ = 0;
    int num_nodes_ = 0;
    int max_level_ = 0;
    size_t data_dim_ = 0;
    DistanceKind metric_{};

    char* model_ = nullptr;                         // start of the model buffer
    uint64_t model_byte_size_ = 0;
    char* model_higher_level_ = nullptr;            // base of the higher-level records
    char* model_level0_ = nullptr;                  // base of the level-0 records
    char* model_level0_node_base_offset_ = nullptr; // base used by GetData()

    uint64_t memory_per_data_ = 0;
    uint64_t memory_per_link_level0_ = 0;
    uint64_t memory_per_node_level0_ = 0;           // stride of one level-0 record, in bytes
    uint64_t memory_per_node_higher_level_ = 0;     // stride of one higher-level record, in bytes

    Mmap* model_mmap_ = nullptr;
};
} // namespace n2