|
2 | 2 |
|
3 | 3 | #include "visited_list_pool.h"
|
4 | 4 | #include "hnswlib.h"
|
5 |
| -#include <atomic> |
6 |
| -#include <random> |
7 |
| -#include <stdlib.h> |
| 5 | + |
8 | 6 | #include <assert.h>
|
9 |
| -#include <unordered_set> |
| 7 | +#include <stdlib.h> |
| 8 | + |
| 9 | +#include <atomic> |
| 10 | +#include <limits> |
10 | 11 | #include <list>
|
11 | 12 | #include <memory>
|
| 13 | +#include <mutex> |
| 14 | +#include <random> |
| 15 | +#include <unordered_set> |
12 | 16 |
|
13 | 17 | namespace hnswlib {
|
14 | 18 | typedef unsigned int tableint;
|
| 19 | +constexpr tableint kInvalidInternalId = std::numeric_limits<tableint>::max(); |
15 | 20 | typedef unsigned int linklistsizeint;
|
16 | 21 |
|
17 | 22 | template<typename dist_t>
|
@@ -195,6 +200,17 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
195 | 200 | }
|
196 | 201 |
|
197 | 202 |
|
| 203 | + tableint getInternalIdByLabel(labeltype label) const { |
| 204 | + std::lock_guard<std::mutex> lock_table(label_lookup_lock); |
| 205 | + auto label_lookup_result = label_lookup_.find(label); |
| 206 | + if (label_lookup_result == label_lookup_.end() || |
| 207 | + isMarkedDeleted(label_lookup_result->second)) { |
| 208 | + return kInvalidInternalId; |
| 209 | + } |
| 210 | + return label_lookup_result->second; |
| 211 | + } |
| 212 | + |
| 213 | + |
198 | 214 | inline void setExternalLabel(tableint internal_id, labeltype label) const {
|
199 | 215 | memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
|
200 | 216 | }
|
@@ -870,13 +886,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
870 | 886 | // lock all operations with element by label
|
871 | 887 | std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
|
872 | 888 |
|
873 |
| - std::unique_lock <std::mutex> lock_table(label_lookup_lock); |
874 |
| - auto search = label_lookup_.find(label); |
875 |
| - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { |
| 889 | + tableint internalId = getInternalIdByLabel(label); |
| 890 | + if (internalId == kInvalidInternalId) { |
876 | 891 | return Status("Label not found");
|
877 | 892 | }
|
878 |
| - tableint internalId = search->second; |
879 |
| - lock_table.unlock(); |
880 | 893 |
|
881 | 894 | char* data_ptrv = getDataByInternalId(internalId);
|
882 | 895 | size_t dim = *((size_t *) dist_func_param_);
|
@@ -1190,7 +1203,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
1190 | 1203 | }
|
1191 | 1204 |
|
1192 | 1205 |
|
1193 |
| - // This internal function adds a point at a specific level. If level is |
| 1206 | + // This internal function adds a point at a specific level. |
1194 | 1207 | StatusOr<tableint> addPointWithLevel(const void *data_point, labeltype label, int level) {
|
1195 | 1208 | tableint cur_c = 0;
|
1196 | 1209 | {
|
|
0 commit comments