Skip to content

Commit 2fba7fb

Browse files
Merge pull request #632 from michaelbautin/develop_get_internal_id
Add a function to get internal node id by user label
2 parents f8ae6c9 + 1b49ec6 commit 2fba7fb

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

hnswlib/hnswalg.h

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22

33
#include "visited_list_pool.h"
44
#include "hnswlib.h"
5-
#include <atomic>
6-
#include <random>
7-
#include <stdlib.h>
5+
86
#include <assert.h>
9-
#include <unordered_set>
7+
#include <stdlib.h>
8+
9+
#include <atomic>
10+
#include <limits>
1011
#include <list>
1112
#include <memory>
13+
#include <mutex>
14+
#include <random>
15+
#include <unordered_set>
1216

1317
namespace hnswlib {
1418
typedef unsigned int tableint;
19+
constexpr tableint kInvalidInternalId = std::numeric_limits<tableint>::max();
1520
typedef unsigned int linklistsizeint;
1621

1722
template<typename dist_t>
@@ -195,6 +200,17 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
195200
}
196201

197202

203+
tableint getInternalIdByLabel(labeltype label) const {
204+
std::lock_guard<std::mutex> lock_table(label_lookup_lock);
205+
auto label_lookup_result = label_lookup_.find(label);
206+
if (label_lookup_result == label_lookup_.end() ||
207+
isMarkedDeleted(label_lookup_result->second)) {
208+
return kInvalidInternalId;
209+
}
210+
return label_lookup_result->second;
211+
}
212+
213+
198214
inline void setExternalLabel(tableint internal_id, labeltype label) const {
199215
memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
200216
}
@@ -870,13 +886,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
870886
// lock all operations with element by label
871887
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
872888

873-
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
874-
auto search = label_lookup_.find(label);
875-
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
889+
tableint internalId = getInternalIdByLabel(label);
890+
if (internalId == kInvalidInternalId) {
876891
return Status("Label not found");
877892
}
878-
tableint internalId = search->second;
879-
lock_table.unlock();
880893

881894
char* data_ptrv = getDataByInternalId(internalId);
882895
size_t dim = *((size_t *) dist_func_param_);
@@ -1190,7 +1203,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11901203
}
11911204

11921205

1193-
// This internal function adds a point at a specific level. If level is
1206+
// This internal function adds a point at a specific level.
11941207
StatusOr<tableint> addPointWithLevel(const void *data_point, labeltype label, int level) {
11951208
tableint cur_c = 0;
11961209
{

0 commit comments

Comments
 (0)