@@ -11,47 +11,105 @@ cimport numpy
1111numpy.import_array()
1212
1313from libc.stdint cimport int64_t, uint64_t
14+ from libc.stddef cimport size_t
15+ from libcpp.vector cimport vector
1416
15- from opteryx.third_party.abseil.containers cimport FlatHashMap
16- from opteryx.compiled.structures.buffers cimport IntBuffer
17+ from time import perf_counter_ns
18+
19+ from opteryx.third_party.abseil.containers cimport (
20+ FlatHashMap,
21+ IdentityHash,
22+ flat_hash_map,
23+ )
24+ from opteryx.compiled.structures.buffers cimport CIntBuffer, IntBuffer
1725from opteryx.compiled.table_ops.hash_ops cimport compute_row_hashes
1826from opteryx.compiled.table_ops.null_avoidant_ops cimport non_null_row_indices
1927
28+ cdef extern from " join_kernels.h" :
29+ void inner_join_probe(
30+ flat_hash_map[uint64_t, vector[int64_t], IdentityHash]* left_map,
31+ const int64_t* non_null_indices,
32+ size_t non_null_count,
33+ const uint64_t* row_hashes,
34+ size_t row_hash_count,
35+ CIntBuffer* left_out,
36+ CIntBuffer* right_out
37+ ) nogil
38+
39+ cdef public long long last_hash_time_ns = 0
40+ cdef public long long last_probe_time_ns = 0
41+ cdef public long long last_materialize_time_ns = 0
42+ cdef public Py_ssize_t last_rows_hashed = 0
43+ cdef public Py_ssize_t last_candidate_rows = 0
44+ cdef public Py_ssize_t last_result_rows = 0
45+
2046
cpdef tuple inner_join(object right_relation, list join_columns, FlatHashMap left_hash_table):
    """
    Probe a pre-built left-side hash table with the rows of ``right_relation``.

    Only the rows reported by ``non_null_row_indices`` for ``join_columns`` are
    used as probe candidates. Hashes for the right relation are written by
    ``compute_row_hashes`` into a buffer of ``right_relation.num_rows`` entries,
    and the probe itself is delegated to the external ``inner_join_probe``
    kernel, which is called with the GIL released.

    Parameters:
        right_relation: relation exposing ``num_rows``; passed through to
            ``non_null_row_indices`` and ``compute_row_hashes``.
        join_columns: list of join columns, passed through to the same helpers.
        left_hash_table: FlatHashMap built from the left side; the address of
            its ``_map`` member is handed to the kernel.

    Returns:
        A 2-tuple ``(left, right)`` of int64 numpy arrays taken from the two
        output buffers. Both are empty when there are no candidate rows or the
        relation has no rows.

    Side effects:
        Overwrites the module-level ``last_*`` instrumentation counters on
        every call.
    """
    global last_hash_time_ns, last_probe_time_ns, last_materialize_time_ns
    global last_rows_hashed, last_candidate_rows, last_result_rows

    cdef IntBuffer left_out = IntBuffer()
    cdef IntBuffer right_out = IntBuffer()
    cdef int64_t total_rows = right_relation.num_rows
    cdef int64_t[::1] candidates = non_null_row_indices(right_relation, join_columns)
    cdef Py_ssize_t n_candidates = candidates.shape[0]
    cdef uint64_t[::1] hashes
    cdef long long hash_begin, hash_end, probe_end
    cdef long long materialize_begin, materialize_end
    cdef numpy.ndarray[int64_t, ndim=1] left_rows
    cdef numpy.ndarray[int64_t, ndim=1] right_rows

    # Nothing to probe. This guard also keeps the `&candidates[0]` and
    # `&hashes[0]` expressions below from indexing an empty buffer.
    if n_candidates == 0 or total_rows == 0:
        last_rows_hashed = total_rows
        last_candidate_rows = n_candidates
        last_result_rows = 0
        last_hash_time_ns = 0
        last_probe_time_ns = 0
        last_materialize_time_ns = 0
        return numpy.empty(0, dtype=numpy.int64), numpy.empty(0, dtype=numpy.int64)

    hashes = numpy.empty(total_rows, dtype=numpy.uint64)

    # Phase 1: hash every row of the right relation (timed).
    hash_begin = perf_counter_ns()
    compute_row_hashes(right_relation, join_columns, hashes)
    hash_end = perf_counter_ns()
    last_hash_time_ns = hash_end - hash_begin

    # Phase 2: probe in native code without holding the GIL.
    # NOTE(review): the kernel is expected to append matching left/right row
    # indices to the two output buffers -- confirm against join_kernels.h.
    with nogil:
        inner_join_probe(
            &left_hash_table._map,
            &candidates[0],
            <size_t>n_candidates,
            &hashes[0],
            <size_t>total_rows,
            left_out.c_buffer,
            right_out.c_buffer,
        )
    probe_end = perf_counter_ns()
    last_probe_time_ns = probe_end - hash_end
    last_rows_hashed = total_rows
    last_candidate_rows = n_candidates

    # Phase 3: turn the output buffers into numpy arrays (timed separately).
    materialize_begin = perf_counter_ns()
    left_rows = left_out.to_numpy()
    right_rows = right_out.to_numpy()
    materialize_end = perf_counter_ns()
    last_result_rows = left_rows.shape[0]
    last_materialize_time_ns = materialize_end - materialize_begin

    return left_rows, right_rows


cpdef tuple get_last_inner_join_metrics():
    """
    Return the instrumentation recorded by the most recent ``inner_join`` call.

    Returns:
        A 6-tuple, in this order:
        ``(hash_time_ns, probe_time_ns, rows_hashed, candidate_rows,
        result_rows, materialize_time_ns)``.
        Note that the materialize timer comes last, after the row counters.
    """
    cdef tuple snapshot = (
        last_hash_time_ns,
        last_probe_time_ns,
        last_rows_hashed,
        last_candidate_rows,
        last_result_rows,
        last_materialize_time_ns,
    )
    return snapshot
55113
56114
57115cpdef FlatHashMap build_side_hash_map(object relation, list join_columns):
@@ -72,42 +130,3 @@ cpdef FlatHashMap build_side_hash_map(object relation, list join_columns):
72130 ht.insert(row_hashes[row_idx], row_idx)
73131
74132 return ht
75-
76-
cpdef tuple nested_loop_join(left_relation, right_relation, list left_columns, list right_columns):
    """
    Quadratic join that pairs rows whose join-column hashes are equal.

    Intended for small relations (<1000 rows), mainly for correctness testing
    or as a fallback. Rows are matched on hash equality only; the column
    values themselves are not compared here.

    Parameters:
        left_relation, right_relation: relations to join.
        left_columns, right_columns: join columns for each side.

    Returns:
        A 2-tuple ``(left, right)`` of arrays produced by ``IntBuffer.to_numpy``
        holding the matched row positions, expressed as the positions reported
        by ``non_null_row_indices`` for the relations as passed in.
    """
    # Positions (in the relations as passed in) of the rows taking part in the join.
    cdef int64_t[::1] left_positions = non_null_row_indices(left_relation, left_columns)
    cdef int64_t[::1] right_positions = non_null_row_indices(right_relation, right_columns)

    cdef int64_t left_count = left_positions.shape[0]
    cdef int64_t right_count = right_positions.shape[0]
    cdef IntBuffer left_out = IntBuffer()
    cdef IntBuffer right_out = IntBuffer()
    cdef int64_t li, ri, left_row
    cdef uint64_t left_hash

    cdef uint64_t[::1] left_hashes = numpy.empty(left_count, dtype=numpy.uint64)
    cdef uint64_t[::1] right_hashes = numpy.empty(right_count, dtype=numpy.uint64)

    # Keep only the join columns and drop rows containing nulls.
    # NOTE(review): assumes drop_null() leaves exactly the rows, in the same
    # order, that non_null_row_indices reported above -- confirm.
    left_relation = left_relation.select(sorted(set(left_columns))).drop_null()
    right_relation = right_relation.select(sorted(set(right_columns))).drop_null()

    # Hash the join columns of the filtered relations.
    compute_row_hashes(left_relation, left_columns, left_hashes)
    compute_row_hashes(right_relation, right_columns, right_hashes)

    # Compare every left row with every right row.
    for li in range(left_count):
        left_hash = left_hashes[li]
        left_row = left_positions[li]
        for ri in range(right_count):
            if left_hash != right_hashes[ri]:
                continue
            # Emit positions in the relations as passed in, not the filtered ones.
            left_out.append(left_row)
            right_out.append(right_positions[ri])

    return (left_out.to_numpy(), right_out.to_numpy())
0 commit comments