Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion napari/_vispy/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def __init__(
self.viewer.camera.events.zoom.connect(self._on_cursor)
self.viewer.layers.events.reordered.connect(self._reorder_layers)
self.viewer.layers.events.removed.connect(self._remove_layer)
self.viewer.multi_channel_gridcanvas.events.enabled.connect(
self._on_grid_change
)
self.viewer.multi_channel_gridcanvas.events.stride.connect(
self._on_grid_change
)
self.destroyed.connect(self._disconnect_theme)

@property
Expand Down Expand Up @@ -334,7 +340,14 @@ def _map_canvas2world(
of the viewer.
"""
nd = self.viewer.dims.ndisplay
transform = self.view.scene.transform
# TODO: figure out how to extent this to all grid boxes, main thing is to see which viewbox the mouse is hovering over
# and to get the transform solely of that viewbox. This would allow compatibility with current code returning one
# position world.
if self.viewer.multi_channel_gridcanvas.enabled:
transform = self.grid_views[0].scene.transform
else:
transform = self.view.scene.transform

# cartesian to homogeneous coordinates
mapped_position = transform.imap(list(position))
if nd == 3:
Expand Down Expand Up @@ -681,3 +694,56 @@ def screenshot(self) -> QImage:
def enable_dims_play(self, *args) -> None:
"""Enable playing of animation. False if awaiting a draw event"""
self.viewer.dims._play_ready = True

def _on_grid_change(self):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the _on_grid_change method by extracting the enabling/disabling logic into separate helper functions.

Refactor the method by offloading the enabling/disabling logic to helper functions. This will keep the functionality intact while reducing the inline branching and list comprehensions. For example:

def _enable_grid_mode(self):
    grid_shape, n_gridboxes = self.viewer.multi_channel_gridcanvas.actual_shape(
        len(self.layer_to_visual)
    )
    self.grid = self.central_widget.add_grid()
    camera = self.camera._view.camera
    self.grid_views = [
        self.grid.add_view(
            row=y,
            col=x,
            camera=camera if y == 0 and x == 0 else None
        )
        for y in range(grid_shape[0])
        for x in range(grid_shape[1])
        if x * y < n_gridboxes
    ]
    self.camera._view = self.grid_views[0]
    self.central_widget.remove_widget(self.view)
    self.grid_cameras = [
        VispyCamera(self.grid_views[i], self.viewer.camera, self.viewer.dims)
        for i in range(len(self.grid_views[1:]))
    ]
    for ind, layer in enumerate(self.layer_to_visual.values()):
        if ind != 0:
            self.grid_views[ind].camera = self.grid_cameras[ind - 1]._view.camera
            self.grid_views[ind].camera.link(self.grid_views[0].camera)
        layer.node.parent = self.grid_views[ind].scene

def _disable_grid_mode(self):
    for layer in self.layer_to_visual.values():
        layer.node.parent = self.view.scene
    self.central_widget.remove_widget(self.grid)
    self.central_widget.add_widget(self.view)
    self.camera._view = self.view
    # TODO: properly disconnect grid events and delete all viewboxes
    del self.grid
    for camera in self.grid_cameras:
        camera.disconnect()
        del camera
    # TODO: respect 3D camera if enabled
    self.camera._view.camera = self.camera._2D_camera

def _on_grid_change(self):
    """Change grid view"""
    if self.viewer.multi_channel_gridcanvas.enabled:
        self._enable_grid_mode()
    else:
        self._disable_grid_mode()

Actionable Steps:

  1. Create two focused helper functions: _enable_grid_mode and _disable_grid_mode.
  2. Move the corresponding code from _on_grid_change into these helper functions.
  3. In _on_grid_change, simply check the condition and call the appropriate helper.

This refactoring decouples responsibilities, making the control flow easier to follow and maintain.

"""Change grid view"""
if self.viewer.multi_channel_gridcanvas.enabled:
grid_shape, n_gridboxes = (
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Extract code out into method (extract-method)

self.viewer.multi_channel_gridcanvas.actual_shape(
len(self.layer_to_visual)
)
)

self.grid = self.central_widget.add_grid()
camera = self.camera._view.camera
self.grid_views = [
self.grid.add_view(
row=y, col=x, camera=camera if y == 0 and x == 0 else None
)
for y in range(grid_shape[0])
for x in range(grid_shape[1])
if x * y < n_gridboxes
Comment on lines +713 to +715
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Review grid view index calculation in the list comprehension.

Use (y * grid_shape[1] + x) instead of x * y to compute the grid index and preserve row-major ordering.

Suggested change
for y in range(grid_shape[0])
for x in range(grid_shape[1])
if x * y < n_gridboxes
for y in range(grid_shape[0])
for x in range(grid_shape[1])
if y * grid_shape[1] + x < n_gridboxes

]
self.camera._view = self.grid_views[0]
self.central_widget.remove_widget(self.view)
# del self.view
self.grid_cameras = [
VispyCamera(
self.grid_views[i], self.viewer.camera, self.viewer.dims
)
for i in range(len(self.grid_views[1:]))
]

for ind, layer in enumerate(self.layer_to_visual.values()):
if ind != 0:
self.grid_views[ind].camera = self.grid_cameras[
ind - 1
]._view.camera
self.grid_views[ind].camera.link(self.grid_views[0].camera)
layer.node.parent = self.grid_views[ind].scene
else:
for layer in self.layer_to_visual.values():
layer.node.parent = self.view.scene
self.central_widget.remove_widget(self.grid)
self.central_widget.add_widget(self.view)
self.camera._view = self.view

# TODO properly disconnect events of grid and delete all viewboxes
del self.grid

for camera in self.grid_cameras:
camera.disconnect()
del camera

# TODO respect 3d camera if enabled
self.camera._view.camera = self.camera._2D_camera
95 changes: 95 additions & 0 deletions napari/components/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,98 @@ def position(self, index: int, nlayers: int) -> tuple[int, int]:
i_row = adj_i // n_column
i_column = adj_i % n_column
return (i_row, i_column)


# TODO: mixin class to combine with GridCanvas
class MultiChannelGridCanvas(EventedModel):
"""Multichannel gridcanvas.
Grid mode with multiple cameras
Attributes
----------
enabled : bool
If grid is enabled or not.
stride : int
Number of layers to place in each grid square before moving on to
the next square. The default ordering is to place the most visible
layer in the top left corner of the grid. A negative stride will
cause the order in which the layers are placed in the grid to be
reversed.
shape : 2-tuple of int
Number of rows and columns in the grid. A value of -1 for either or
both of will be used the row and column numbers will trigger an
auto calculation of the necessary grid shape to appropriately fill
all the layers at the appropriate stride.
"""

# fields
# See https://github.com/pydantic/pydantic/issues/156 for why
# these need a type: ignore comment
stride: GridStride = 1 # type: ignore[valid-type]
shape: tuple[GridHeight, GridWidth] = (-1, -1) # type: ignore[valid-type]
enabled: bool = False

def actual_shape(self, nlayers: int = 1) -> tuple[tuple[int, int], int]:
"""Return the actual shape of the grid.
This will return the shape parameter, unless one of the row
or column numbers is -1 in which case it will compute the
optimal shape of the grid given the number of layers and
current stride.
If the grid is not enabled, this will return (1, 1).
Parameters
----------
nlayers : int
Number of layers that need to be placed in the grid.
Returns
-------
shape : 2-tuple of int
Number of rows and columns in the grid.
"""
if not self.enabled:
return (1, 1), 0

if nlayers == 0:
return (1, 1), 0

n_row, n_column = self.shape
n_grid_squares = np.ceil(nlayers / abs(self.stride)).astype(int)

if n_row == -1 and n_column == -1:
n_column = np.ceil(np.sqrt(n_grid_squares)).astype(int)
n_row = np.ceil(n_grid_squares / n_column).astype(int)
elif n_row == -1:
n_row = np.ceil(n_grid_squares / n_column).astype(int)
elif n_column == -1:
n_column = np.ceil(n_grid_squares / n_row).astype(int)

n_row = max(1, n_row)
n_column = max(1, n_column)

return (n_row, n_column), n_grid_squares

def position(self, index: int, nlayers: int) -> tuple[int, int]:
"""Return the position of a given linear index in grid.
If the grid is not enabled, this will return (0, 0).
Parameters
----------
index : int
Position of current layer in layer list.
nlayers : int
Number of layers that need to be placed in the grid.
Returns
-------
position : 2-tuple of int
Row and column position of current index in the grid.
"""
if not self.enabled:
return (0, 0)

shape, n_gridboxes = self.actual_shape(nlayers)

# Adjust for forward or reverse ordering
adj_i = nlayers - index - 1 if self.stride < 0 else index

adj_i = adj_i // abs(self.stride)
adj_i = adj_i % (shape[0] * shape[1])
i_row = adj_i // shape[1]
i_column = adj_i % shape[1]
return (i_row, i_column)
10 changes: 9 additions & 1 deletion napari/components/viewer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from napari.components.camera import Camera
from napari.components.cursor import Cursor, CursorStyle
from napari.components.dims import Dims
from napari.components.grid import GridCanvas
from napari.components.grid import GridCanvas, MultiChannelGridCanvas
from napari.components.layerlist import LayerList
from napari.components.overlays import (
AxesOverlay,
Expand Down Expand Up @@ -183,6 +183,9 @@ class ViewerModel(KeymapProvider, MousemapProvider, EventedModel):
cursor: Cursor = Field(default_factory=Cursor, allow_mutation=False)
dims: Dims = Field(default_factory=Dims, allow_mutation=False)
grid: GridCanvas = Field(default_factory=GridCanvas, allow_mutation=False)
multi_channel_gridcanvas: MultiChannelGridCanvas = Field(
default_factory=MultiChannelGridCanvas, allow_mutation=False
)
layers: LayerList = Field(
default_factory=LayerList, allow_mutation=False
) # Need to create custom JSON encoder for layer!
Expand Down Expand Up @@ -257,6 +260,11 @@ def __init__(
settings.application.events.grid_spacing.connect(
self._update_viewer_grid
)
self.multi_channel_gridcanvas.stride = settings.application.grid_stride
self.multi_channel_gridcanvas.shape = (
settings.application.grid_height,
settings.application.grid_width,
)
settings.experimental.events.async_.connect(self._update_async)

# Add extra events - ideally these will be removed too!
Expand Down
Loading