Skip to content

Commit 8dd43f1

Browse files
committed
feat(ops): SplitBySeparators — regex-only splitter (Rust+Py); extract shared split helpers; enum UPPERCASE; defaults via Python; simplify range logic
1 parent a3a0c3a commit 8dd43f1

File tree

6 files changed

+403
-2
lines changed

6 files changed

+403
-2
lines changed

python/cocoindex/functions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import dataclasses
44
import functools
5-
from typing import Annotated, Any, Literal
5+
from typing import Any, Literal
66

77
import numpy as np
88
from numpy.typing import NDArray
99

1010
from . import llm, op
11-
from .typing import TypeAttr, Vector
11+
from .typing import Vector
1212

1313

1414
class ParseJson(op.FunctionSpec):
@@ -40,6 +40,24 @@ class SplitRecursively(op.FunctionSpec):
4040
custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list)
4141

4242

43+
class SplitBySeparators(op.FunctionSpec):
    """
    Split a text into chunks at regex separator matches, and nowhere else.

    The output has the same schema as SplitRecursively so the two can be
    swapped for one another: a KTable whose rows carry
    location (Range), text (Str), start and end.

    Attributes:
        separators_regex: Regex patterns marking chunk boundaries,
            e.g. [r"\\n\\n+"]. Defaults to an empty list.
        keep_separator: "NONE", "LEFT" or "RIGHT"; controls whether the
            matched separator is dropped or kept with a neighbouring chunk.
            Defaults to "NONE".
        include_empty: Whether empty chunks are emitted. Defaults to False.
        trim: Whether surrounding whitespace is stripped from each chunk.
            Defaults to True.
    """

    separators_regex: list[str] = dataclasses.field(default_factory=list)
    keep_separator: Literal["NONE", "LEFT", "RIGHT"] = "NONE"
    include_empty: bool = False
    trim: bool = True
59+
60+
4361
class EmbedText(op.FunctionSpec):
4462
"""Embed a text into a vector space."""
4563

src/ops/functions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub mod embed_text;
22
pub mod extract_by_llm;
33
pub mod parse_json;
4+
pub mod split_by_separators;
45
pub mod split_recursively;
56

67
#[cfg(test)]
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
use anyhow::{Context, Result};
2+
use regex::Regex;
3+
use std::sync::Arc;
4+
5+
use crate::base::field_attrs;
6+
use crate::ops::registry::ExecutorFactoryRegistry;
7+
use crate::ops::shared::split::{Position, set_output_positions};
8+
use crate::{fields_value, ops::sdk::*};
9+
10+
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
11+
#[serde(rename_all = "UPPERCASE")]
12+
enum KeepSep {
13+
NONE,
14+
LEFT,
15+
RIGHT,
16+
}
17+
18+
#[derive(Serialize, Deserialize)]
19+
struct Spec {
20+
// Python SDK provides defaults/values.
21+
separators_regex: Vec<String>,
22+
keep_separator: KeepSep,
23+
include_empty: bool,
24+
trim: bool,
25+
}
26+
27+
struct Args {
28+
text: ResolvedOpArg,
29+
}
30+
31+
struct Executor {
32+
spec: Spec,
33+
regex: Option<Regex>,
34+
args: Args,
35+
}
36+
37+
impl Executor {
38+
fn new(args: Args, spec: Spec) -> Result<Self> {
39+
let regex = if spec.separators_regex.is_empty() {
40+
None
41+
} else {
42+
// OR-join all separators, multiline
43+
let pattern = format!(
44+
"(?m){}",
45+
spec.separators_regex
46+
.iter()
47+
.map(|s| format!("(?:{s})"))
48+
.collect::<Vec<_>>()
49+
.join("|")
50+
);
51+
Some(Regex::new(&pattern).context("failed to compile separators_regex")?)
52+
};
53+
Ok(Self { args, spec, regex })
54+
}
55+
}
56+
57+
struct ChunkOutput<'s> {
58+
start_pos: Position,
59+
end_pos: Position,
60+
text: &'s str,
61+
}
62+
63+
#[async_trait]
64+
impl SimpleFunctionExecutor for Executor {
65+
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
66+
let full_text = self.args.text.value(&input)?.as_str()?;
67+
let bytes = full_text.as_bytes();
68+
69+
// add_range applies trim/include_empty and records the text slice
70+
let mut chunks: Vec<ChunkOutput<'_>> = Vec::new();
71+
let mut add_range = |mut s: usize, mut e: usize| {
72+
if self.spec.trim {
73+
while s < e && bytes[s].is_ascii_whitespace() {
74+
s += 1;
75+
}
76+
while e > s && bytes[e - 1].is_ascii_whitespace() {
77+
e -= 1;
78+
}
79+
}
80+
if self.spec.include_empty || e > s {
81+
chunks.push(ChunkOutput {
82+
start_pos: Position::new(s),
83+
end_pos: Position::new(e),
84+
text: &full_text[s..e],
85+
});
86+
}
87+
};
88+
89+
if let Some(re) = &self.regex {
90+
let mut start = 0usize;
91+
for m in re.find_iter(full_text) {
92+
let end = match self.spec.keep_separator {
93+
KeepSep::LEFT => m.end(),
94+
KeepSep::NONE | KeepSep::RIGHT => m.start(),
95+
};
96+
add_range(start, end);
97+
start = match self.spec.keep_separator {
98+
KeepSep::RIGHT => m.start(),
99+
KeepSep::NONE | KeepSep::LEFT => m.end(),
100+
};
101+
}
102+
add_range(start, full_text.len());
103+
} else {
104+
// No separators: emit whole text
105+
add_range(0, full_text.len());
106+
}
107+
108+
set_output_positions(
109+
full_text,
110+
chunks.iter_mut().flat_map(|c| {
111+
std::iter::once(&mut c.start_pos).chain(std::iter::once(&mut c.end_pos))
112+
}),
113+
);
114+
115+
let table = chunks
116+
.into_iter()
117+
.map(|c| {
118+
let s = c.start_pos.output.unwrap();
119+
let e = c.end_pos.output.unwrap();
120+
(
121+
KeyValue::from_single_part(RangeValue::new(s.char_offset, e.char_offset)),
122+
fields_value!(Arc::<str>::from(c.text), s.into_output(), e.into_output())
123+
.into(),
124+
)
125+
})
126+
.collect();
127+
128+
Ok(Value::KTable(table))
129+
}
130+
}
131+
132+
/// Factory registered under the name `"SplitBySeparators"`; it resolves the
/// function's schema and builds its executor.
struct Factory;
133+
134+
#[async_trait]
135+
impl SimpleFunctionFactoryBase for Factory {
136+
type Spec = Spec;
137+
type ResolvedArgs = Args;
138+
139+
fn name(&self) -> &str {
140+
"SplitBySeparators"
141+
}
142+
143+
async fn resolve_schema<'a>(
144+
&'a self,
145+
_spec: &'a Spec,
146+
args_resolver: &mut OpArgsResolver<'a>,
147+
_context: &FlowInstanceContext,
148+
) -> Result<(Args, EnrichedValueType)> {
149+
// one required arg: text: Str
150+
let args = Args {
151+
text: args_resolver
152+
.next_arg("text")?
153+
.expect_type(&ValueType::Basic(BasicValueType::Str))?
154+
.required()?,
155+
};
156+
157+
// start/end structs exactly like SplitRecursively
158+
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
159+
fields: Arc::new(vec![
160+
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
161+
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
162+
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
163+
]),
164+
description: None,
165+
});
166+
167+
let mut struct_schema = StructSchema::default();
168+
let mut sb = StructSchemaBuilder::new(&mut struct_schema);
169+
sb.add_field(FieldSchema::new(
170+
"location",
171+
make_output_type(BasicValueType::Range),
172+
));
173+
sb.add_field(FieldSchema::new(
174+
"text",
175+
make_output_type(BasicValueType::Str),
176+
));
177+
sb.add_field(FieldSchema::new(
178+
"start",
179+
schema::EnrichedValueType {
180+
typ: pos_struct.clone(),
181+
nullable: false,
182+
attrs: Default::default(),
183+
},
184+
));
185+
sb.add_field(FieldSchema::new(
186+
"end",
187+
schema::EnrichedValueType {
188+
typ: pos_struct,
189+
nullable: false,
190+
attrs: Default::default(),
191+
},
192+
));
193+
let output_schema = make_output_type(TableSchema::new(
194+
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
195+
struct_schema,
196+
))
197+
.with_attr(
198+
field_attrs::CHUNK_BASE_TEXT,
199+
serde_json::to_value(args_resolver.get_analyze_value(&args.text))?,
200+
);
201+
Ok((args, output_schema))
202+
}
203+
204+
async fn build_executor(
205+
self: Arc<Self>,
206+
spec: Spec,
207+
args: Args,
208+
_context: Arc<FlowInstanceContext>,
209+
) -> Result<impl SimpleFunctionExecutor> {
210+
Executor::new(args, spec)
211+
}
212+
}
213+
214+
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
215+
Factory.register(registry)
216+
}
217+
218+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::test_flow_function;

    /// Paragraph splitting on runs of blank lines, separators dropped.
    #[tokio::test]
    async fn test_split_by_separators_paragraphs() {
        let spec = Spec {
            separators_regex: vec![r"\n\n+".to_string()],
            keep_separator: KeepSep::NONE,
            include_empty: false,
            trim: true,
        };
        let factory = Arc::new(Factory);
        let text = "Para1\n\nPara2\n\n\nPara3";

        let input_arg_schemas = &[(
            Some("text"),
            make_output_type(BasicValueType::Str).with_nullable(true),
        )];

        let result = test_flow_function(
            &factory,
            &spec,
            input_arg_schemas,
            vec![text.to_string().into()],
        )
        .await
        .unwrap();

        match result {
            Value::KTable(table) => {
                // Expected ranges after trimming whitespace:
                let expected = vec![
                    (RangeValue::new(0, 5), "Para1"),
                    (RangeValue::new(7, 12), "Para2"),
                    (RangeValue::new(15, 20), "Para3"),
                ];
                // Guard against extra, unexpected chunks as well.
                assert_eq!(table.len(), expected.len());
                for (range, expected_text) in expected {
                    let key = KeyValue::from_single_part(range);
                    let row = table
                        .get(&key)
                        .expect("expected chunk range missing from table");
                    let chunk_text = row.0.fields[0].as_str().unwrap();
                    assert_eq!(**chunk_text, *expected_text);
                }
            }
            other => panic!("Expected KTable, got {other:?}"),
        }
    }

    /// With `KeepSep::RIGHT` every separator starts the chunk that follows it,
    /// so the trailing "." forms a chunk of its own.
    #[tokio::test]
    async fn test_split_by_separators_keep_right() {
        let spec = Spec {
            separators_regex: vec![r"\.".to_string()],
            keep_separator: KeepSep::RIGHT,
            include_empty: false,
            trim: true,
        };
        let factory = Arc::new(Factory);
        let text = "A. B. C.";

        let input_arg_schemas = &[(
            Some("text"),
            make_output_type(BasicValueType::Str).with_nullable(true),
        )];

        let result = test_flow_function(
            &factory,
            &spec,
            input_arg_schemas,
            vec![text.to_string().into()],
        )
        .await
        .unwrap();

        match result {
            Value::KTable(table) => {
                // Pin the exact chunks instead of only a lower bound on the
                // row count, which would also pass on wrong output.
                let expected = vec![
                    (RangeValue::new(0, 1), "A"),
                    (RangeValue::new(1, 4), ". B"),
                    (RangeValue::new(4, 7), ". C"),
                    (RangeValue::new(7, 8), "."),
                ];
                assert_eq!(table.len(), expected.len());
                for (range, expected_text) in expected {
                    let key = KeyValue::from_single_part(range);
                    let row = table
                        .get(&key)
                        .expect("expected chunk range missing from table");
                    let chunk_text = row.0.fields[0].as_str().unwrap();
                    assert_eq!(**chunk_text, *expected_text);
                }
            }
            other => panic!("Expected KTable, got {other:?}"),
        }
    }
}

src/ops/registration.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
1818
functions::split_recursively::register(registry)?;
1919
functions::extract_by_llm::Factory.register(registry)?;
2020
functions::embed_text::register(registry)?;
21+
functions::split_by_separators::register(registry)?;
2122

2223
targets::postgres::Factory::default().register(registry)?;
2324
targets::qdrant::register(registry)?;

src/ops/shared/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
pub mod postgres;
2+
pub mod split;

0 commit comments

Comments
 (0)