Skip to content

Commit b7614fb

Browse files
authored
Merge pull request #2889 from mabel-dev/#2887
separate nested loop join
2 parents edd75eb + e1ac7f0 commit b7614fb

File tree

18 files changed

+293
-91
lines changed

18 files changed

+293
-91
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,4 @@ opteryx/third_party/tktech/csimdjson.cpp
195195
opteryx/third_party/ulfjack/ryu.c
196196
opteryx/compiled/joins/joins.pyx
197197
pyiceberg_catalog.db
198+
opteryx/compiled/joins/joins.h

opteryx/__version__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# THIS FILE IS AUTOMATICALLY UPDATED DURING THE BUILD PROCESS
22
# DO NOT EDIT THIS FILE DIRECTLY
33

4-
__build__ = 1722
4+
__build__ = 1724
55
__author__ = "@joocer"
6-
__version__ = "0.26.0-beta.1722"
6+
__version__ = "0.26.0-beta.1724"
77

88
# Store the version here so:
99
# 1) we don't load dependencies by storing it in __init__.py

opteryx/compiled/joins/cross_join.pyx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ cpdef tuple build_filtered_rows_indices_and_column(object column, set valid_valu
213213
Py_ssize_t arr_offset = column.offset
214214
const int32_t* offsets32 = <const int32_t*><uintptr_t>(buffers[1].address)
215215
Py_ssize_t i, j, k = 0, start, end, str_len
216-
Py_ssize_t allocated_size = row_count * 4
216+
Py_ssize_t allocated_size = row_count * 4 if row_count > 0 else 4
217217

218218
numpy.ndarray indices = numpy.empty(allocated_size, dtype=numpy.int64)
219219
int64_t[::1] indices_mv = indices
@@ -258,9 +258,15 @@ cpdef tuple build_filtered_rows_indices_and_column(object column, set valid_valu
258258
if value_bytes in valid_bytes:
259259
if k >= allocated_size:
260260
allocated_size *= 2
261-
indices = numpy.resize(indices, allocated_size)
261+
new_indices = numpy.empty(allocated_size, dtype=numpy.int64)
262+
new_indices[:k] = indices_mv[:k]
263+
indices = new_indices
262264
indices_mv = indices
263-
flat_data = numpy.resize(flat_mv.base, allocated_size)
265+
266+
new_flat = numpy.empty(allocated_size, dtype=object)
267+
new_flat[:k] = flat_data[:k]
268+
flat_data = new_flat
269+
flat_mv = flat_data
264270

265271
flat_mv[k] = value_bytes
266272
indices_mv[k] = i

opteryx/compiled/joins/filter_join.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ from opteryx.third_party.abseil.containers cimport FlatHashSet
1818
cpdef FlatHashSet filter_join_set(table, list columns=None, FlatHashSet seen_hashes=None):
1919
cdef:
2020
Py_ssize_t num_rows = table.num_rows
21-
uint64_t[::1] row_hashes = numpy.zeros(num_rows, dtype=numpy.uint64)
21+
uint64_t[::1] row_hashes = numpy.empty(num_rows, dtype=numpy.uint64)
2222
list columns_of_interest = columns if columns else table.column_names
2323
Py_ssize_t row_idx
2424

@@ -37,7 +37,7 @@ cpdef semi_join(object relation, list join_columns, FlatHashSet seen_hashes):
3737
cdef:
3838
Py_ssize_t num_rows = relation.num_rows
3939
Py_ssize_t row_idx, count = 0
40-
uint64_t[::1] row_hashes = numpy.zeros(num_rows, dtype=numpy.uint64)
40+
uint64_t[::1] row_hashes = numpy.empty(num_rows, dtype=numpy.uint64)
4141
numpy.ndarray[int64_t, ndim=1] index_buffer = numpy.empty(num_rows, dtype=numpy.int64)
4242

4343
compute_row_hashes(relation, join_columns, row_hashes)
@@ -53,7 +53,7 @@ cpdef anti_join(object relation, list join_columns, FlatHashSet seen_hashes):
5353
cdef:
5454
Py_ssize_t num_rows = relation.num_rows
5555
Py_ssize_t row_idx, count = 0
56-
uint64_t[::1] row_hashes = numpy.zeros(num_rows, dtype=numpy.uint64)
56+
uint64_t[::1] row_hashes = numpy.empty(num_rows, dtype=numpy.uint64)
5757
numpy.ndarray[int64_t, ndim=1] index_buffer = numpy.empty(num_rows, dtype=numpy.int64)
5858

5959
compute_row_hashes(relation, join_columns, row_hashes)

opteryx/compiled/joins/inner_join.pyx

Lines changed: 77 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,105 @@ cimport numpy
1111
numpy.import_array()
1212

1313
from libc.stdint cimport int64_t, uint64_t
14+
from libc.stddef cimport size_t
15+
from libcpp.vector cimport vector
1416

15-
from opteryx.third_party.abseil.containers cimport FlatHashMap
16-
from opteryx.compiled.structures.buffers cimport IntBuffer
17+
from time import perf_counter_ns
18+
19+
from opteryx.third_party.abseil.containers cimport (
20+
FlatHashMap,
21+
IdentityHash,
22+
flat_hash_map,
23+
)
24+
from opteryx.compiled.structures.buffers cimport CIntBuffer, IntBuffer
1725
from opteryx.compiled.table_ops.hash_ops cimport compute_row_hashes
1826
from opteryx.compiled.table_ops.null_avoidant_ops cimport non_null_row_indices
1927

28+
cdef extern from "join_kernels.h":
29+
void inner_join_probe(
30+
flat_hash_map[uint64_t, vector[int64_t], IdentityHash]* left_map,
31+
const int64_t* non_null_indices,
32+
size_t non_null_count,
33+
const uint64_t* row_hashes,
34+
size_t row_hash_count,
35+
CIntBuffer* left_out,
36+
CIntBuffer* right_out
37+
) nogil
38+
39+
cdef public long long last_hash_time_ns = 0
40+
cdef public long long last_probe_time_ns = 0
41+
cdef public long long last_materialize_time_ns = 0
42+
cdef public Py_ssize_t last_rows_hashed = 0
43+
cdef public Py_ssize_t last_candidate_rows = 0
44+
cdef public Py_ssize_t last_result_rows = 0
45+
2046

2147
cpdef tuple inner_join(object right_relation, list join_columns, FlatHashMap left_hash_table):
2248
"""
2349
Perform an inner join between a right-hand relation and a pre-built left-side hash table.
2450
This function uses precomputed hashes and avoids null rows for optimal speed.
2551
"""
52+
global last_hash_time_ns, last_probe_time_ns, last_materialize_time_ns
53+
global last_rows_hashed, last_candidate_rows, last_result_rows
2654
cdef IntBuffer left_indexes = IntBuffer()
2755
cdef IntBuffer right_indexes = IntBuffer()
2856
cdef int64_t num_rows = right_relation.num_rows
2957
cdef int64_t[::1] non_null_indices = non_null_row_indices(right_relation, join_columns)
58+
cdef Py_ssize_t candidate_count = non_null_indices.shape[0]
59+
60+
if candidate_count == 0 or num_rows == 0:
61+
last_hash_time_ns = 0
62+
last_probe_time_ns = 0
63+
last_rows_hashed = num_rows
64+
last_candidate_rows = candidate_count
65+
last_result_rows = 0
66+
last_materialize_time_ns = 0
67+
return numpy.empty(0, dtype=numpy.int64), numpy.empty(0, dtype=numpy.int64)
68+
3069
cdef uint64_t[::1] row_hashes = numpy.empty(num_rows, dtype=numpy.uint64)
31-
cdef int64_t i, row_idx
32-
cdef uint64_t hash_val
33-
cdef size_t match_count
34-
cdef int j
70+
cdef long long t_start = perf_counter_ns()
3571

3672
# Precompute hashes for right relation
3773
compute_row_hashes(right_relation, join_columns, row_hashes)
74+
cdef long long t_after_hash = perf_counter_ns()
75+
last_hash_time_ns = t_after_hash - t_start
76+
77+
with nogil:
78+
inner_join_probe(
79+
&left_hash_table._map,
80+
&non_null_indices[0],
81+
<size_t>candidate_count,
82+
&row_hashes[0],
83+
<size_t>num_rows,
84+
left_indexes.c_buffer,
85+
right_indexes.c_buffer,
86+
)
87+
cdef long long t_after_probe = perf_counter_ns()
88+
last_probe_time_ns = t_after_probe - t_after_hash
89+
last_rows_hashed = num_rows
90+
last_candidate_rows = candidate_count
3891

39-
for i in range(non_null_indices.shape[0]):
40-
row_idx = non_null_indices[i]
41-
hash_val = row_hashes[row_idx]
92+
# Return matched row indices from both sides
93+
cdef long long t_before_numpy = perf_counter_ns()
94+
cdef numpy.ndarray[int64_t, ndim=1] left_np = left_indexes.to_numpy()
95+
cdef numpy.ndarray[int64_t, ndim=1] right_np = right_indexes.to_numpy()
96+
cdef long long t_after_numpy = perf_counter_ns()
97+
last_result_rows = left_np.shape[0]
98+
last_materialize_time_ns = t_after_numpy - t_before_numpy
4299

43-
# Probe the left-side hash table
44-
left_matches = left_hash_table.get(hash_val)
45-
match_count = left_matches.size()
46-
if match_count == 0:
47-
continue
100+
return left_np, right_np
48101

49-
for j in range(match_count):
50-
left_indexes.append(left_matches[j])
51-
right_indexes.append(row_idx)
52102

53-
# Return matched row indices from both sides
54-
return left_indexes.to_numpy(), right_indexes.to_numpy()
103+
cpdef tuple get_last_inner_join_metrics():
104+
"""Return instrumentation captured during the most recent inner join call."""
105+
return (
106+
last_hash_time_ns,
107+
last_probe_time_ns,
108+
last_rows_hashed,
109+
last_candidate_rows,
110+
last_result_rows,
111+
last_materialize_time_ns,
112+
)
55113

56114

57115
cpdef FlatHashMap build_side_hash_map(object relation, list join_columns):
@@ -72,42 +130,3 @@ cpdef FlatHashMap build_side_hash_map(object relation, list join_columns):
72130
ht.insert(row_hashes[row_idx], row_idx)
73131

74132
return ht
75-
76-
77-
cpdef tuple nested_loop_join(left_relation, right_relation, list left_columns, list right_columns):
78-
"""
79-
A buffer-aware nested loop join using direct Arrow buffer access and hash computation.
80-
Only intended for small relations (<1000 rows), primarily used for correctness testing or fallbacks.
81-
"""
82-
# determine the rows we're going to try to join on
83-
cdef int64_t[::1] left_non_null_indices = non_null_row_indices(left_relation, left_columns)
84-
cdef int64_t[::1] right_non_null_indices = non_null_row_indices(right_relation, right_columns)
85-
86-
cdef int64_t nl = left_non_null_indices.shape[0]
87-
cdef int64_t nr = right_non_null_indices.shape[0]
88-
cdef IntBuffer left_indexes = IntBuffer()
89-
cdef IntBuffer right_indexes = IntBuffer()
90-
cdef int64_t left_non_null_idx, right_non_null_idx, left_record_idx, right_record_idx
91-
92-
cdef uint64_t[::1] left_hashes = numpy.empty(nl, dtype=numpy.uint64)
93-
cdef uint64_t[::1] right_hashes = numpy.empty(nr, dtype=numpy.uint64)
94-
95-
# remove the rows from the relations
96-
left_relation = left_relation.select(sorted(set(left_columns))).drop_null()
97-
right_relation = right_relation.select(sorted(set(right_columns))).drop_null()
98-
99-
# build hashes for the columns we're joining on
100-
compute_row_hashes(left_relation, left_columns, left_hashes)
101-
compute_row_hashes(right_relation, right_columns, right_hashes)
102-
103-
# Compare each pair of rows (naive quadratic approach)
104-
for left_non_null_idx in range(nl):
105-
for right_non_null_idx in range(nr):
106-
# if we have a match, look up the offset in the original table
107-
if left_hashes[left_non_null_idx] == right_hashes[right_non_null_idx]:
108-
left_record_idx = left_non_null_indices[left_non_null_idx]
109-
right_record_idx = right_non_null_indices[right_non_null_idx]
110-
left_indexes.append(left_record_idx)
111-
right_indexes.append(right_record_idx)
112-
113-
return (left_indexes.to_numpy(), right_indexes.to_numpy())
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# cython: language_level=3
2+
# cython: nonecheck=False
3+
# cython: cdivision=True
4+
# cython: initializedcheck=False
5+
# cython: infer_types=True
6+
# cython: wraparound=False
7+
# cython: boundscheck=False
8+
9+
import numpy
10+
cimport numpy
11+
numpy.import_array()
12+
13+
from libc.stdint cimport int64_t, uint64_t
14+
15+
from opteryx.compiled.structures.buffers cimport IntBuffer
16+
from opteryx.compiled.table_ops.hash_ops cimport compute_row_hashes
17+
from opteryx.compiled.table_ops.null_avoidant_ops cimport non_null_row_indices
18+
19+
20+
cpdef tuple nested_loop_join(left_relation, right_relation, list left_columns, list right_columns):
21+
"""
22+
Perform a buffer-aware nested loop join using Arrow buffer hashing.
23+
24+
This implementation is optimized for small relations where building a hash map would be
25+
more expensive than a quadratic scan.
26+
"""
27+
cdef int64_t[::1] left_non_null_indices = non_null_row_indices(left_relation, left_columns)
28+
cdef int64_t[::1] right_non_null_indices = non_null_row_indices(right_relation, right_columns)
29+
30+
cdef int64_t nl = left_non_null_indices.shape[0]
31+
cdef int64_t nr = right_non_null_indices.shape[0]
32+
33+
if nl == 0 or nr == 0:
34+
return numpy.empty(0, dtype=numpy.int64), numpy.empty(0, dtype=numpy.int64)
35+
36+
cdef IntBuffer left_indexes = IntBuffer()
37+
cdef IntBuffer right_indexes = IntBuffer()
38+
cdef uint64_t[::1] left_hashes = numpy.empty(left_relation.num_rows, dtype=numpy.uint64)
39+
cdef uint64_t[::1] right_hashes = numpy.empty(right_relation.num_rows, dtype=numpy.uint64)
40+
cdef int64_t i, j, left_row, right_row
41+
cdef uint64_t left_hash, right_hash
42+
43+
compute_row_hashes(left_relation, left_columns, left_hashes)
44+
compute_row_hashes(right_relation, right_columns, right_hashes)
45+
46+
if nl <= nr:
47+
for i in range(nl):
48+
left_row = left_non_null_indices[i]
49+
left_hash = left_hashes[left_row]
50+
for j in range(nr):
51+
right_row = right_non_null_indices[j]
52+
if left_hash == right_hashes[right_row]:
53+
left_indexes.append(left_row)
54+
right_indexes.append(right_row)
55+
else:
56+
for j in range(nr):
57+
right_row = right_non_null_indices[j]
58+
right_hash = right_hashes[right_row]
59+
for i in range(nl):
60+
left_row = left_non_null_indices[i]
61+
if right_hash == left_hashes[left_row]:
62+
left_indexes.append(left_row)
63+
right_indexes.append(right_row)
64+
65+
return left_indexes.to_numpy(), right_indexes.to_numpy()

opteryx/compiled/structures/buffers.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cdef extern from "intbuffer.h" namespace "":
2020
void extend(const int64_t* values, size_t count)
2121
const int64_t* data() const
2222
size_t size() const
23+
void append_repeated(int64_t value, size_t count)
2324

2425

2526
cdef class IntBuffer:

opteryx/compiled/structures/buffers.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cdef extern from "intbuffer.h":
2323
void extend(const int64_t* data, size_t count) nogil
2424
const int64_t* data() nogil
2525
size_t size() nogil
26+
void append_repeated(int64_t value, size_t count) nogil
2627

2728
cdef class IntBuffer:
2829

opteryx/operators/inner_join_node.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from opteryx import EOS
2828
from opteryx.compiled.joins import build_side_hash_map
29+
from opteryx.compiled.joins import get_last_inner_join_metrics
2930
from opteryx.compiled.joins import inner_join
3031
from opteryx.compiled.structures.bloom_filter import create_bloom_filter
3132
from opteryx.models import QueryProperties
@@ -71,8 +72,9 @@ def execute(self, morsel: Table, join_leg: str) -> Table:
7172

7273
start = time.monotonic_ns()
7374
self.left_hash = build_side_hash_map(self.left_relation, self.left_columns)
74-
self.statistics.time_inner_join_build_side_hash_map += (
75-
time.monotonic_ns() - start
75+
self.statistics.increase(
76+
"time_inner_join_build_side_hash_map",
77+
time.monotonic_ns() - start,
7678
)
7779

7880
# If the left side is small enough to quickly build a bloom filter, do that.
@@ -83,8 +85,10 @@ def execute(self, morsel: Table, join_leg: str) -> Table:
8385
self.left_filter = create_bloom_filter(
8486
self.left_relation, self.left_columns
8587
)
86-
self.statistics.time_build_bloom_filter += time.monotonic_ns() - start
87-
self.statistics.feature_bloom_filter += 1
88+
self.statistics.increase(
89+
"time_build_bloom_filter", time.monotonic_ns() - start
90+
)
91+
self.statistics.increase("feature_bloom_filter", 1)
8892
else:
8993
if self.left_buffer_columns is None:
9094
self.left_buffer_columns = morsel.schema.names
@@ -106,7 +110,7 @@ def execute(self, morsel: Table, join_leg: str) -> Table:
106110
maybe_in_left = self.left_filter.possibly_contains_many(
107111
morsel, self.right_columns
108112
)
109-
self.statistics.time_bloom_filtering += time.monotonic_ns() - start
113+
self.statistics.increase("time_bloom_filtering", time.monotonic_ns() - start)
110114
morsel = morsel.filter(maybe_in_left)
111115

112116
# If the bloom filter is not effective, disable it.
@@ -115,13 +119,32 @@ def execute(self, morsel: Table, join_leg: str) -> Table:
115119
eliminated_rows = len(maybe_in_left) - morsel.num_rows
116120
if eliminated_rows < 0.05 * len(maybe_in_left):
117121
self.left_filter = None
118-
self.statistics.feature_dynamically_disabled_bloom_filter += 1
122+
self.statistics.increase("feature_dynamically_disabled_bloom_filter", 1)
119123

120-
self.statistics.rows_eliminated_by_bloom_filter += eliminated_rows
124+
self.statistics.increase("rows_eliminated_by_bloom_filter", eliminated_rows)
121125

122126
# do the join
123127
left_indicies, right_indicies = inner_join(
124128
morsel, self.right_columns, self.left_hash
125129
)
126130

127-
yield align_tables(morsel, self.left_relation, right_indicies, left_indicies)
131+
# record detailed timing and row counts for diagnostics
132+
(
133+
hash_time,
134+
probe_time,
135+
rows_hashed,
136+
candidate_rows,
137+
matched_rows,
138+
materialize_time,
139+
) = get_last_inner_join_metrics()
140+
self.statistics.increase("time_inner_join_hash", hash_time)
141+
self.statistics.increase("time_inner_join_probe", probe_time)
142+
self.statistics.increase("rows_inner_join_hashed", rows_hashed)
143+
self.statistics.increase("rows_inner_join_candidates", candidate_rows)
144+
self.statistics.increase("time_inner_join_indices", materialize_time)
145+
self.statistics.increase("rows_inner_join_matched", matched_rows)
146+
start = time.monotonic_ns()
147+
aligned = align_tables(morsel, self.left_relation, right_indicies, left_indicies)
148+
self.statistics.increase("time_inner_join_align", time.monotonic_ns() - start)
149+
150+
yield aligned

0 commit comments

Comments
 (0)