Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,7 +53,7 @@ struct contains_fn {

size_type end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return prog.find(thread_idx, d_str, d_str.begin(), end).has_value();
return prog.find<positional::END_ONLY>(thread_idx, d_str, d_str.begin(), end).has_value();
}
};

Expand Down
27 changes: 15 additions & 12 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ namespace cudf {
namespace strings {
namespace detail {

struct relist;
/**
* @brief Template type used on `find` to specify desired position values in returned match_result
*/
enum class positional : int8_t {
BEGIN_END = 0, /// both begin and end positions are returned
END_ONLY = 1, /// only the end position is returned
};

template <positional P>
struct reljunk;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like I'm missing something about this name :D
Is it rel_junk? If yes, why junk, and if not, what does it stand for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is re_lj_unk where the re is regex , lj is long something, and unk is unknown.
The name is from code that this is based on so I'm mostly keeping parity since I don't have a better one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I thought it might be unknown, but could not come up with any meaning for the lj part. Thanks!


using match_pair = thrust::pair<cudf::size_type, cudf::size_type>;
using match_result = cuda::std::optional<match_pair>;
Expand Down Expand Up @@ -180,13 +189,15 @@ class reprog_device {
/**
* @brief Does a find evaluation using the compiled expression on the given string.
*
* @tparam P Desired positional values. Default includes valid begin and end match positions.
* @param thread_idx The index used for mapping the state memory for this string in global memory.
* @param d_str The string to search.
* @param begin Position to begin the search within `d_str`.
* @param end Character position index to end the search within `d_str`.
* Specify -1 to match any virtual positions past the end of the string.
* @return If match found, returns character positions of the matches.
*/
template <positional P = positional::BEGIN_END>
[[nodiscard]] __device__ inline match_result find(int32_t const thread_idx,
string_view const d_str,
string_view::const_iterator begin,
Expand All @@ -213,16 +224,6 @@ class reprog_device {
cudf::size_type const group_id) const;

private:
struct reljunk {
relist* __restrict__ list1;
relist* __restrict__ list2;
int32_t starttype{};
char32_t startchar{};

__device__ inline reljunk(relist* list1, relist* list2, reinst const inst);
__device__ inline void swaplist();
};

/**
* @brief Returns the regex instruction object for a given id.
*/
Expand All @@ -236,15 +237,17 @@ class reprog_device {
/**
* @brief Executes the regex pattern on the given string.
*/
template <positional P>
[[nodiscard]] __device__ inline match_result regexec(string_view const d_str,
reljunk jnk,
reljunk<P>& jnk,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id = 0) const;

/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
template <positional P = positional::BEGIN_END>
[[nodiscard]] __device__ inline match_result call_regexec(
int32_t const thread_idx,
string_view const d_str,
Expand Down
82 changes: 49 additions & 33 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -71,20 +71,26 @@ struct alignas(8) relist {
size = 0;
}

template <positional P = positional::BEGIN_END>
__device__ __forceinline__ bool activate(int32_t id, int32_t begin, int32_t end)
{
if (readMask(id)) { return false; }
writeMask(id);
inst_ids[size * stride] = static_cast<int16_t>(id);
ranges[size * stride] = int2{begin, end};
if constexpr (P == positional::BEGIN_END) { ranges[size * stride] = int2{begin, end}; }
++size;
return true;
}

template <positional P = positional::BEGIN_END>
[[nodiscard]] __device__ __forceinline__ restate get_state(int16_t idx) const
{
return restate{ranges[idx * stride], inst_ids[idx * stride]};
if constexpr (P == positional::BEGIN_END) {
return restate{ranges[idx * stride], inst_ids[idx * stride]};
}
return restate{{-1, -1}, inst_ids[idx * stride]};
}

[[nodiscard]] __device__ __forceinline__ int16_t get_size() const { return size; }

private:
Expand All @@ -108,23 +114,28 @@ struct alignas(8) relist {
}
};

__device__ __forceinline__ reprog_device::reljunk::reljunk(relist* list1,
relist* list2,
reinst const inst)
: list1(list1), list2(list2)
{
if (inst.type == CHAR || inst.type == BOL) {
starttype = inst.type;
startchar = inst.u1.c;
}
}
template <positional P = positional::BEGIN_END>
struct reljunk {
relist* __restrict__ list1;
relist* __restrict__ list2;
int32_t starttype{};
char32_t startchar{};

__device__ __forceinline__ void reprog_device::reljunk::swaplist()
{
auto tmp = list1;
list1 = list2;
list2 = tmp;
}
__device__ inline reljunk(relist* list1, relist* list2, reinst const inst)
: list1(list1), list2(list2)
{
if (inst.type == CHAR || inst.type == BOL) {
starttype = inst.type;
startchar = inst.u1.c;
}
}
__device__ inline void swaplist()
{
auto tmp = list1;
list1 = list2;
list2 = tmp;
}
};

/**
* @brief Check for supported new-line characters
Expand Down Expand Up @@ -249,8 +260,9 @@ __device__ __forceinline__ static string_view::const_iterator find_char(
* @param group_id Index of the group to match in a multi-group regex pattern.
* @return >0 if match found
*/
template <positional P>
__device__ __forceinline__ match_result reprog_device::regexec(string_view const dstr,
reljunk jnk,
reljunk<P>& jnk,
string_view::const_iterator itr,
cudf::size_type end,
cudf::size_type const group_id) const
Expand Down Expand Up @@ -288,8 +300,9 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const

if (((eos < 0) || (pos < eos)) && match == 0) {
auto ids = _startinst_ids;
while (*ids >= 0)
jnk.list1->activate(*ids++, (group_id == 0 ? pos : -1), -1);
while (*ids >= 0) {
jnk.list1->template activate<P>(*ids++, (group_id == 0 ? pos : -1), -1);
}
}

last_character = itr.byte_offset() >= dstr.size_bytes();
Expand All @@ -303,7 +316,7 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
expanded = false;

for (int16_t i = 0; i < jnk.list1->get_size(); i++) {
auto state = jnk.list1->get_state(i);
auto state = jnk.list1->template get_state<P>(i);
auto range = state.range;
auto const inst = get_inst(state.inst_id);
int32_t id_activate = -1;
Expand All @@ -316,12 +329,12 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
case NCCLASS:
case END: id_activate = state.inst_id; break;
case LBRA:
if (inst.u1.subid == group_id) range.x = pos;
if (inst.u1.subid == group_id) { range.x = pos; }
id_activate = inst.u2.next_id;
expanded = true;
break;
case RBRA:
if (inst.u1.subid == group_id) range.y = pos;
if (inst.u1.subid == group_id) { range.y = pos; }
id_activate = inst.u2.next_id;
expanded = true;
break;
Expand Down Expand Up @@ -363,12 +376,12 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
break;
}
case OR:
jnk.list2->activate(inst.u1.right_id, range.x, range.y);
jnk.list2->template activate<P>(inst.u1.right_id, range.x, range.y);
id_activate = inst.u2.left_id;
expanded = true;
break;
}
if (id_activate >= 0) jnk.list2->activate(id_activate, range.x, range.y);
if (id_activate >= 0) { jnk.list2->template activate<P>(id_activate, range.x, range.y); }
}
jnk.swaplist();

Expand All @@ -378,7 +391,7 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
bool continue_execute = true;
jnk.list2->reset();
for (int16_t i = 0; continue_execute && i < jnk.list1->get_size(); i++) {
auto const state = jnk.list1->get_state(i);
auto const state = jnk.list1->template get_state<P>(i);
auto const range = state.range;
auto const inst = get_inst(state.inst_id);
int32_t id_activate = -1;
Expand Down Expand Up @@ -408,8 +421,9 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
continue_execute = false;
break;
}
if (continue_execute && (id_activate >= 0))
jnk.list2->activate(id_activate, range.x, range.y);
if (continue_execute && (id_activate >= 0)) {
jnk.list2->template activate<P>(id_activate, range.x, range.y);
}
}

++pos;
Expand All @@ -421,12 +435,13 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
return match ? match_result({begin, end}) : cuda::std::nullopt;
}

template <positional P>
__device__ __forceinline__ match_result reprog_device::find(int32_t const thread_idx,
string_view const dstr,
string_view::const_iterator begin,
cudf::size_type end) const
{
return call_regexec(thread_idx, dstr, begin, end);
return call_regexec<P>(thread_idx, dstr, begin, end);
}

__device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx,
Expand All @@ -439,6 +454,7 @@ __device__ __forceinline__ match_result reprog_device::extract(int32_t const thr
return call_regexec(thread_idx, dstr, begin, end, group_id + 1);
}

template <positional P>
__device__ __forceinline__ match_result
reprog_device::call_regexec(int32_t const thread_idx,
string_view const dstr,
Expand All @@ -452,8 +468,8 @@ reprog_device::call_regexec(int32_t const thread_idx,
gp_ptr += relist::alloc_size(_max_insts, _thread_count);
relist list2(static_cast<int16_t>(_max_insts), _thread_count, gp_ptr, thread_idx);

reljunk jnk(&list1, &list2, get_inst(_startinst_id));
return regexec(dstr, jnk, begin, end, group_id);
reljunk<P> jnk(&list1, &list2, get_inst(_startinst_id));
return regexec<P>(dstr, jnk, begin, end, group_id);
}

} // namespace detail
Expand Down