diff --git a/napari/_vispy/canvas.py b/napari/_vispy/canvas.py index 14fa277e5a0..4493a3e5a07 100644 --- a/napari/_vispy/canvas.py +++ b/napari/_vispy/canvas.py @@ -166,6 +166,12 @@ def __init__( self.viewer.camera.events.zoom.connect(self._on_cursor) self.viewer.layers.events.reordered.connect(self._reorder_layers) self.viewer.layers.events.removed.connect(self._remove_layer) + self.viewer.multi_channel_gridcanvas.events.enabled.connect( + self._on_grid_change + ) + self.viewer.multi_channel_gridcanvas.events.stride.connect( + self._on_grid_change + ) self.destroyed.connect(self._disconnect_theme) @property @@ -334,7 +340,14 @@ def _map_canvas2world( of the viewer. """ nd = self.viewer.dims.ndisplay - transform = self.view.scene.transform + # TODO: figure out how to extent this to all grid boxes, main thing is to see which viewbox the mouse is hovering over + # and to get the transform solely of that viewbox. This would allow compatibility with current code returning one + # position world. + if self.viewer.multi_channel_gridcanvas.enabled: + transform = self.grid_views[0].scene.transform + else: + transform = self.view.scene.transform + # cartesian to homogeneous coordinates mapped_position = transform.imap(list(position)) if nd == 3: @@ -681,3 +694,56 @@ def screenshot(self) -> QImage: def enable_dims_play(self, *args) -> None: """Enable playing of animation. False if awaiting a draw event""" self.viewer.dims._play_ready = True + + def _on_grid_change(self): + """Change grid view""" + if self.viewer.multi_channel_gridcanvas.enabled: + grid_shape, n_gridboxes = ( + self.viewer.multi_channel_gridcanvas.actual_shape( + len(self.layer_to_visual) + ) + ) + + self.grid = self.central_widget.add_grid() + camera = self.camera._view.camera + self.grid_views = [ + self.grid.add_view( + row=y, col=x, camera=camera if y == 0 and x == 0 else None + ) + for y in range(grid_shape[0]) + for x in range(grid_shape[1]) + if x * y < n_gridboxes + ] + self.camera._view = self.grid_views[0] + self.central_widget.remove_widget(self.view) + # del self.view + self.grid_cameras = [ + VispyCamera( + self.grid_views[i], self.viewer.camera, self.viewer.dims + ) + for i in range(len(self.grid_views[1:])) + ] + + for ind, layer in enumerate(self.layer_to_visual.values()): + if ind != 0: + self.grid_views[ind].camera = self.grid_cameras[ + ind - 1 + ]._view.camera + self.grid_views[ind].camera.link(self.grid_views[0].camera) + layer.node.parent = self.grid_views[ind].scene + else: + for layer in self.layer_to_visual.values(): + layer.node.parent = self.view.scene + self.central_widget.remove_widget(self.grid) + self.central_widget.add_widget(self.view) + self.camera._view = self.view + + # TODO properly disconnect events of grid and delete all viewboxes + del self.grid + + for camera in self.grid_cameras: + camera.disconnect() + del camera + + # TODO respect 3d camera if enabled + self.camera._view.camera = self.camera._2D_camera diff --git a/napari/components/grid.py b/napari/components/grid.py index cd8d6f79c25..8ac74a99d88 100644 --- a/napari/components/grid.py +++ b/napari/components/grid.py @@ -121,3 +121,98 @@ def position(self, index: int, nlayers: int) -> tuple[int, int]: i_row = adj_i // n_column i_column = adj_i % n_column return (i_row, i_column) + + +# TODO: mixin class to combine with GridCanvas +class MultiChannelGridCanvas(EventedModel): + """Multichannel gridcanvas. + Grid mode with multiple cameras + Attributes + ---------- + enabled : bool + If grid is enabled or not. + stride : int + Number of layers to place in each grid square before moving on to + the next square. The default ordering is to place the most visible + layer in the top left corner of the grid. A negative stride will + cause the order in which the layers are placed in the grid to be + reversed. + shape : 2-tuple of int + Number of rows and columns in the grid. A value of -1 for either or + both of will be used the row and column numbers will trigger an + auto calculation of the necessary grid shape to appropriately fill + all the layers at the appropriate stride. + """ + + # fields + # See https://github.com/pydantic/pydantic/issues/156 for why + # these need a type: ignore comment + stride: GridStride = 1 # type: ignore[valid-type] + shape: tuple[GridHeight, GridWidth] = (-1, -1) # type: ignore[valid-type] + enabled: bool = False + + def actual_shape(self, nlayers: int = 1) -> tuple[tuple[int, int], int]: + """Return the actual shape of the grid. + This will return the shape parameter, unless one of the row + or column numbers is -1 in which case it will compute the + optimal shape of the grid given the number of layers and + current stride. + If the grid is not enabled, this will return (1, 1). + Parameters + ---------- + nlayers : int + Number of layers that need to be placed in the grid. + Returns + ------- + shape : 2-tuple of int + Number of rows and columns in the grid. + """ + if not self.enabled: + return (1, 1), 0 + + if nlayers == 0: + return (1, 1), 0 + + n_row, n_column = self.shape + n_grid_squares = np.ceil(nlayers / abs(self.stride)).astype(int) + + if n_row == -1 and n_column == -1: + n_column = np.ceil(np.sqrt(n_grid_squares)).astype(int) + n_row = np.ceil(n_grid_squares / n_column).astype(int) + elif n_row == -1: + n_row = np.ceil(n_grid_squares / n_column).astype(int) + elif n_column == -1: + n_column = np.ceil(n_grid_squares / n_row).astype(int) + + n_row = max(1, n_row) + n_column = max(1, n_column) + + return (n_row, n_column), n_grid_squares + + def position(self, index: int, nlayers: int) -> tuple[int, int]: + """Return the position of a given linear index in grid. + If the grid is not enabled, this will return (0, 0). + Parameters + ---------- + index : int + Position of current layer in layer list. + nlayers : int + Number of layers that need to be placed in the grid. + Returns + ------- + position : 2-tuple of int + Row and column position of current index in the grid. + """ + if not self.enabled: + return (0, 0) + + shape, n_gridboxes = self.actual_shape(nlayers) + + # Adjust for forward or reverse ordering + adj_i = nlayers - index - 1 if self.stride < 0 else index + + adj_i = adj_i // abs(self.stride) + adj_i = adj_i % (shape[0] * shape[1]) + i_row = adj_i // shape[1] + i_column = adj_i % shape[1] + return (i_row, i_column) diff --git a/napari/components/viewer_model.py b/napari/components/viewer_model.py index 12e6b1c47d3..a7bba8067ac 100644 --- a/napari/components/viewer_model.py +++ b/napari/components/viewer_model.py @@ -30,7 +30,7 @@ from napari.components.camera import Camera from napari.components.cursor import Cursor, CursorStyle from napari.components.dims import Dims -from napari.components.grid import GridCanvas +from napari.components.grid import GridCanvas, MultiChannelGridCanvas from napari.components.layerlist import LayerList from napari.components.overlays import ( AxesOverlay, @@ -183,6 +183,9 @@ class ViewerModel(KeymapProvider, MousemapProvider, EventedModel): cursor: Cursor = Field(default_factory=Cursor, allow_mutation=False) dims: Dims = Field(default_factory=Dims, allow_mutation=False) grid: GridCanvas = Field(default_factory=GridCanvas, allow_mutation=False) + multi_channel_gridcanvas: MultiChannelGridCanvas = Field( + default_factory=MultiChannelGridCanvas, allow_mutation=False + ) layers: LayerList = Field( default_factory=LayerList, allow_mutation=False ) # Need to create custom JSON encoder for layer! @@ -257,6 +260,11 @@ def __init__( settings.application.events.grid_spacing.connect( self._update_viewer_grid ) + self.multi_channel_gridcanvas.stride = settings.application.grid_stride + self.multi_channel_gridcanvas.shape = ( + settings.application.grid_height, + settings.application.grid_width, + ) settings.experimental.events.async_.connect(self._update_async) # Add extra events - ideally these will be removed too!