Skip to content

Commit d2a6c9a

Browse files
committed
wip
1 parent e68c4d1 commit d2a6c9a

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

lonboard/_layer.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
)
2323

2424
import ipywidgets
25+
import numpy as np
2526
import traitlets
26-
from arro3.core import Table
27+
from arro3.core import Array, ChunkedArray, Table
2728
from arro3.core.types import ArrowStreamExportable
2829

2930
from lonboard._base import BaseExtension, BaseWidget
@@ -38,6 +39,7 @@
3839
from lonboard._serialization import infer_rows_per_chunk
3940
from lonboard._utils import auto_downcast as _auto_downcast
4041
from lonboard._utils import get_geometry_column_index, remove_extension_kwargs
42+
from lonboard.layer_extension import DataFilterExtension
4143
from lonboard.traits import (
4244
ArrowTableTrait,
4345
ColorAccessor,
@@ -409,9 +411,55 @@ def quak(self) -> quak.Widget:
409411
import quak
410412
import sqlglot
411413

414+
if not any(isinstance(ext, DataFilterExtension) for ext in self.extensions):
415+
self.add_extension(DataFilterExtension(category_size=1))
416+
417+
table: Table = self.table
418+
num_rows = table.num_rows
419+
if num_rows <= np.iinfo(np.uint8).max:
420+
row_index = Array.from_numpy(np.arange(num_rows, dtype=np.uint8))
421+
filter_arr = np.ones(num_rows, dtype=np.float32)
422+
elif num_rows <= np.iinfo(np.uint16).max:
423+
row_index = Array.from_numpy(np.arange(num_rows, dtype=np.uint16))
424+
filter_arr = np.ones(num_rows, dtype=np.float32)
425+
elif num_rows <= np.iinfo(np.uint32).max:
426+
row_index = Array.from_numpy(np.arange(num_rows, dtype=np.uint32))
427+
filter_arr = np.ones(num_rows, dtype=np.float32)
428+
else:
429+
row_index = Array.from_numpy(np.arange(num_rows, dtype=np.uint64))
430+
filter_arr = np.ones(num_rows, dtype=np.float32)
431+
432+
table_with_row_index = table.append_column(
433+
"_row_index", ChunkedArray(row_index)
434+
)
435+
quak_widget = quak.Widget(table_with_row_index)
436+
437+
def row_index_callback(change):
438+
global test
439+
test = change
440+
441+
sql = sqlglot.parse_one(quak_widget.sql, dialect="duckdb")
442+
sql.set("expressions", [sqlglot.column("_row_index")])
443+
row_index_table = quak_widget._conn.query(sql.sql(dialect="duckdb")).arrow()
444+
445+
# Reset all to 2. We don't use zero because there might be a bug with 0
446+
filter_arr[:] = 2
447+
448+
# Set the desired _row_index to 1
449+
filter_arr[row_index_table["_row_index"]] = 1
450+
451+
self.get_filter_category = filter_arr # type: ignore
452+
self.filter_categories = [1] # type: ignore
453+
454+
quak_widget.observe(row_index_callback, names="sql")
455+
456+
return quak_widget
412457
pass
413458

414459

460+
test = None
461+
462+
415463
class BitmapLayer(BaseLayer):
416464
"""
417465
The `BitmapLayer` renders a bitmap (e.g. PNG, JPEG, or WebP) at specified

0 commit comments

Comments
 (0)