
Context block for IndexWriter #507

Open · wants to merge 6 commits into master

13 changes: 13 additions & 0 deletions docs/tutorials.md
@@ -70,6 +70,19 @@ writer.wait_merging_threads()
Note that `wait_merging_threads()` must come at the end, because
the `writer` object will not be usable after this call.

Alternatively, `writer` can be used as a context manager, so the same block of code can be written as:

```python
with index.writer() as writer:
writer.add_document(tantivy.Document(
doc_id=1,
title=["The Old Man and the Sea"],
body=["""He was an old man who fished alone in a skiff in the Gulf Stream and he had gone eighty-four days now without taking a fish."""],
))
```

Both `commit()` and `wait_merging_threads()` are called automatically when the with-block is exited.
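
For example, here is a sketch (assuming the same `index` and schema as above, with an illustrative document) of how the context manager interacts with an exception raised inside the block: the exception still propagates, but documents added before the error are committed on the way out, because `__exit__` calls `commit()` unconditionally.

```python
try:
    with index.writer() as writer:
        writer.add_document(tantivy.Document(
            doc_id=2,
            title=["Partially indexed"],
            body=["This document is committed even though the block fails."],
        ))
        raise RuntimeError("something went wrong")
except RuntimeError:
    # The exception is not suppressed by the context manager; the document
    # added above has nevertheless been committed by __exit__ on the way out.
    pass
```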

## Building and Executing Queries with the Query Parser

With the Query Parser, you can easily build simple queries for your index.
2 changes: 1 addition & 1 deletion noxfile.py
@@ -3,6 +3,6 @@

@nox.session(python=["3.9", "3.10", "3.11", "3.12", "3.13"])
def test(session):
session.install("-rrequirements-dev.txt")
session.install("-r", "requirements-dev.txt")
session.install("-e", ".", "--no-build-isolation")
session.run("pytest", *session.posargs)
10 changes: 10 additions & 0 deletions src/index.rs
@@ -264,6 +264,16 @@ impl IndexWriter {
pub fn wait_merging_threads(&mut self) -> PyResult<()> {
self.take_inner()?.wait_merging_threads().map_err(to_pyerr)
}

/// Enter the context manager, returning the writer itself.
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
slf
}

/// Exit the context manager: commit pending documents and wait for merging
/// threads. Errors are propagated to Python rather than silently dropped, and
/// an exception raised inside the with-block is never suppressed.
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) -> PyResult<()> {
self.commit()?;
self.wait_merging_threads()
}

}

/// Create a new index object.
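
For readers less familiar with the context-manager protocol, the `__enter__`/`__exit__` pair above behaves roughly like the following pure-Python sketch (an illustration only; `writer_context` is a hypothetical helper, not part of the library):

```python
from contextlib import contextmanager

@contextmanager
def writer_context(index):
    # __enter__ simply hands back the writer itself.
    writer = index.writer()
    try:
        yield writer
    finally:
        # __exit__ commits and waits for merging threads whether or not the
        # body raised; an exception from the body still propagates afterwards.
        writer.commit()
        writer.wait_merging_threads()
```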
19 changes: 15 additions & 4 deletions tantivy/tantivy.pyi
@@ -1,6 +1,8 @@
import datetime
from enum import Enum
from types import TracebackType
from typing import Any, Optional, Sequence, TypeVar, Union
from typing_extensions import Self


class Schema:
@@ -400,6 +402,17 @@ class IndexWriter:
def wait_merging_threads(self) -> None:
pass

def __enter__(self: Self) -> Self:
pass

def __exit__(
self: Self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass


class Index:
def __new__(
@@ -476,6 +489,7 @@ class Snippet:
def fragment(self) -> str:
pass


class SnippetGenerator:
@staticmethod
def create(
@@ -519,7 +533,6 @@ class Tokenizer:


class Filter:

@staticmethod
def alphanum_only() -> Filter:
pass
@@ -551,16 +564,14 @@ class Filter:
@staticmethod
def split_compound(constituent_words: list[str]) -> Filter:
pass


class TextAnalyzer:

class TextAnalyzer:
def analyze(self, text: str) -> list[str]:
pass


class TextAnalyzerBuilder:

def __init__(self, tokenizer: Tokenizer):
pass

40 changes: 27 additions & 13 deletions tests/conftest.py
@@ -4,7 +4,7 @@
from tantivy import SchemaBuilder, Index, Document


def schema():
def build_schema():
return (
SchemaBuilder()
.add_text_field("title", stored=True)
@@ -13,7 +13,7 @@ def schema():
)


def schema_numeric_fields():
def build_schema_numeric_fields():
return (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True, fast=True)
@@ -23,7 +23,8 @@
.build()
)

def schema_with_date_field():

def build_schema_with_date_field():
return (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True)
@@ -32,7 +33,8 @@ def schema_with_date_field():
.build()
)

def schema_with_ip_addr_field():

def build_schema_with_ip_addr_field():
return (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True)
@@ -41,10 +43,11 @@ def schema_with_ip_addr_field():
.build()
)


def create_index(dir=None):
# assume all tests will use the same documents for now
# other methods may set up function-local indexes
index = Index(schema(), dir)
index = Index(build_schema(), dir)
writer = index.writer(15_000_000, 1)

# 2 ways of adding documents
@@ -97,7 +100,7 @@ def create_index(dir=None):


def create_index_with_numeric_fields(dir=None):
index = Index(schema_numeric_fields(), dir)
index = Index(build_schema_numeric_fields(), dir)
writer = index.writer(15_000_000, 1)

doc = Document()
@@ -140,15 +143,16 @@ def create_index_with_numeric_fields(dir=None):
index.reload()
return index


def create_index_with_date_field(dir=None):
index = Index(schema_with_date_field(), dir)
index = Index(build_schema_with_date_field(), dir)
writer = index.writer(15_000_000, 1)

doc = Document()
doc.add_integer("id", 1)
doc.add_float("rating", 3.5)
doc.add_date("date", datetime(2021, 1, 1))

writer.add_document(doc)
doc = Document.from_dict(
{
@@ -161,10 +165,11 @@ def create_index_with_date_field(dir=None):
writer.commit()
writer.wait_merging_threads()
index.reload()
return index
return index


def create_index_with_ip_addr_field(dir=None):
schema = schema_with_ip_addr_field()
schema = build_schema_with_ip_addr_field()
index = Index(schema, dir)
writer = index.writer(15_000_000, 1)

@@ -173,14 +178,14 @@ def create_index_with_ip_addr_field(dir=None):
doc.add_float("rating", 3.5)
doc.add_ip_addr("ip_addr", "10.0.0.1")
writer.add_document(doc)

doc = Document.from_dict(
{
"id": 2,
"rating": 4.5,
"ip_addr": "127.0.0.1",
},
schema
schema,
)
writer.add_document(doc)
doc = Document.from_dict(
@@ -189,14 +194,15 @@ def create_index_with_ip_addr_field(dir=None):
"rating": 4.5,
"ip_addr": "::1",
},
schema
schema,
)
writer.add_document(doc)
writer.commit()
writer.wait_merging_threads()
index.reload()
return index


def spanish_schema():
return (
SchemaBuilder()
@@ -262,14 +268,22 @@ def ram_index():
def ram_index_numeric_fields():
return create_index_with_numeric_fields()


@pytest.fixture(scope="class")
def ram_index_with_date_field():
return create_index_with_date_field()


@pytest.fixture(scope="class")
def ram_index_with_ip_addr_field():
return create_index_with_ip_addr_field()


@pytest.fixture(scope="class")
def spanish_index():
return create_spanish_index()


@pytest.fixture(scope="class")
def schema():
return build_schema()
29 changes: 21 additions & 8 deletions tests/tantivy_test.py
@@ -7,7 +7,7 @@
import pytest

import tantivy
from conftest import schema, schema_numeric_fields
from conftest import build_schema, build_schema_numeric_fields
from tantivy import (
Document,
Index,
@@ -29,7 +29,7 @@ def test_simple_search_in_dir(self, dir_index):

def test_simple_search_after_reuse(self, dir_index):
index_dir, _ = dir_index
index = Index(schema(), str(index_dir))
index = Index(build_schema(), str(index_dir))
query = index.parse_query("sea whale", ["title", "body"])

result = index.searcher().search(query, 10)
@@ -555,6 +555,19 @@ def test_delete_all_documents(self, ram_index):

assert len(result.hits) == 0

def test_index_writer_context_block(self, schema):
index = Index(schema)
with index.writer() as writer:
writer.add_document(Document(
doc_id=1,
title=["The Old Man and the Sea"],
body=["""He was an old man who fished alone in a skiff in the Gulf Stream and he had gone eighty-four days now without taking a fish."""],
))

index.reload()
result = index.searcher().search(Query.all_query())
assert len(result.hits) == 1

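# A possible companion test (a sketch, not part of this PR); it assumes the
# __exit__ behaviour added above: the writer commits on exit even when the
# block raises, and the exception itself is not suppressed.
def test_index_writer_context_block_with_exception(self, schema):
    index = Index(schema)
    with pytest.raises(RuntimeError):
        with index.writer() as writer:
            writer.add_document(Document(
                doc_id=1,
                title=["The Old Man and the Sea"],
                body=["He was an old man who fished alone in a skiff in the Gulf Stream."],
            ))
            raise RuntimeError("simulated failure inside the with-block")

    # The document added before the error should still have been committed.
    index.reload()
    result = index.searcher().search(Query.all_query())
    assert len(result.hits) == 1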

class TestUpdateClass(object):
def test_delete_update(self, ram_index):
@@ -588,12 +601,12 @@ def test_opens_from_dir_invalid_schema(self, dir_index):
def test_opens_from_dir(self, dir_index):
index_dir, _ = dir_index

index = Index(schema(), str(index_dir), reuse=True)
index = Index(build_schema(), str(index_dir), reuse=True)
assert index.searcher().num_docs == 3

def test_create_readers(self):
# not sure what is the point of this test.
idx = Index(schema())
idx = Index(build_schema())
idx.config_reader("Manual", 4)
assert idx.searcher().num_docs == 0
# by default this is manual mode
@@ -798,9 +811,9 @@ def test_bytes(bytes_kwarg, bytes_payload):


def test_schema_eq():
schema1 = schema()
schema2 = schema()
schema3 = schema_numeric_fields()
schema1 = build_schema()
schema2 = build_schema()
schema3 = build_schema_numeric_fields()

assert schema1 == schema2
assert schema1 != schema3
@@ -852,7 +865,7 @@ def test_doc_address_pickle():
class TestSnippets(object):
def test_document_snippet(self, dir_index):
index_dir, _ = dir_index
doc_schema = schema()
doc_schema = build_schema()
index = Index(doc_schema, str(index_dir))
query = index.parse_query("sea whale", ["title", "body"])
searcher = index.searcher()