Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.11"

python:
install:
- requirements: docs/requirements.txt # or just requirements.txt if it's in root

sphinx:
configuration: docs/source/conf.py # or conf.py if it's in root
116 changes: 116 additions & 0 deletions docs/api/BC.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# BoundaryCondition (Base Class)

The `BoundaryCondition` class is the base class for implementing all boundary conditions in Lattice Boltzmann Method (LBM) simulations.
It extends the generic `Operator` class and provides the foundational structure for applying boundary logic in different simulation stages.

## 📌 Purpose

In LBM simulations, boundary conditions (BCs) define how the simulation behaves at domain edges — walls, inlets, outlets, etc.
`BoundaryCondition` provides:

- A uniform interface for implementing BCs
- GPU/TPU-compatible kernels using **JAX** or **Warp**
- Support for auxiliary data (e.g., prescribed velocities)
- Integration with velocity sets, precision policy, and compute backends



## 🧩 Key Parameters

| Argument | Description |
|--------------------|---------------------------------------------------------------------|
| `implementation_step` | When the BC is applied: `COLLISION` or `STREAMING` |
| `velocity_set` | Type of LBM velocity set (optional, uses default if not provided) |
| `precision_policy` | Controls numerical precision (optional) |
| `compute_backend` | Either `JAX` or `WARP` (optional) |
| `indices` | Grid indices where the BC applies |
| `mesh_vertices` | Optional mesh information for mesh-aware BCs |


<!-- ## ⚙️ Features and Flags

| Flag | Description |
|-----------------------|--------------------------------------------------------------------|
| `needs_padding` | True if the BC requires boundary padding in all directions |
| `needs_mesh_distance` | True if the BC needs geometric distance to a mesh |
| `needs_aux_init` | Indicates if the BC uses auxiliary data (e.g., prescribed values) |
| `num_of_aux_data` | How many auxiliary values are needed (if any) |
| `needs_aux_recovery` | If auxiliary data must be recovered post-streaming | -->

<!-- ## ⚡ Backend Implementations

Subclasses are expected to register their backend-specific logic for:

- **JAX** (via `@jit`)
- **Warp** (via `@wp.kernel`)

These implementations are used to apply the boundary logic at simulation runtime.


## 🔄 Auxiliary Data Support

Some BCs (e.g., prescribed velocity or pressure) require initializing extra data at the boundary. The base class includes:

- `update_bc_auxilary_data(...)` – placeholder, can be overridden
- `aux_data_init(...)` – initializes BC-specific auxiliary values (e.g., pre-fill velocity)

These support seamless integration of BCs requiring pre-simulation setup.

## 🔧 Custom Warp Kernels

To define Warp-compatible BCs, use:

```python
def _construct_kernel(self, functional):
```

Where functional(...) implements the per-thread boundary logic, returning updated distribution functions.

## 🧪 Example: DoNothingBC

The `DoNothingBC` subclass demonstrates a minimal example:

```python
class DoNothingBC(BoundaryCondition):
def jax_implementation(...):
return jnp.where(boundary_mask, f_pre, f_post)
```
This BC effectively does nothing to the boundary values — useful for debugging or placeholders. -->


---

## 🚧 **Boundary Condition Subclasses**

1. **DoNothingBC**:
In this boundary condition, the fluid populations are allowed to pass through the boundary without any reflection or modification. Useful for test cases or special boundary handling.
2. **EquilibriumBC**:
In this boundary condition, the fluid populations are assumed to be in at equilibrium. Constructor has optional macroscopic density (`rho`) and velocity (`u`) values
3. **FullwayBounceBackBC**:
In this boundary condition, the velocity of the fluid populations is reflected back to the fluid side of the boundary, resulting in zero fluid velocity at the boundary. Enforces no-slip wall conditions by reversing particle distributions at the boundary during the collision step.
4. **HalfwayBounceBackBC**:
Similar to the `FullwayBounceBackBC`, in this boundary condition, the velocity of the fluid populations is partially reflected back to the fluid side of the boundary, resulting in a non-zero fluid velocity at the boundary. Enforces no-slip conditions by reflecting distribution functions halfway between fluid and boundary nodes, improving accuracy over fullway bounce-back.
5. **ZouHeBC**:
This boundary condition is used to impose a prescribed velocity or pressure profile at the boundary. Supports only normal velocity components (only one non-zero velocity element allowed)
6. **RegularizedBC**:
This boundary condition is used to impose a prescribed velocity or pressure profile at the boundary. This BC is more stable than `ZouHeBC`, but computationally more expensive.
7. **ExtrapolationOutflowBC**:
A type of outflow boundary condition that uses extrapolation to avoid strong wave reflections.
8. **GradsApproximationBC**:
Interpolated bounce-back boundary condition for representing curved boundaries. Requires 3D velocity sets (not implemented in 2D)




## Summary Table of Boundary Conditions

| BC Class | Purpose | Implementation Step | Supports Auxiliary Data | Backend Support |
|------------------------|------------------------------------------------------|---------------------|------------------------|-----------------------|
| `DoNothingBC` | Leaves boundary distributions unchanged (no-op) | STREAMING | No |JAX, Warp |
| `EquilibriumBC` | Prescribe equilibrium populations | STREAMING | No | JAX, Warp |
| `FullwayBounceBackBC` | Classic bounce-back (no-slip) | COLLISION | No | JAX, Warp |
| `HalfwayBounceBackBC` | Halfway bounce-back for no-slip walls | STREAMING | No | JAX, Warp |
| `ZouHeBC` | Classical Zou-He velocity/pressure BC with non-equilibrium bounce-back | STREAMING | Yes |JAX, Warp |
| `RegularizedBC` | Non-equilibrium bounce-back with second moment regularization | STREAMING | No | JAX, Warp |
| `ExtrapolationOutflowBC`| Smooth outflow via extrapolation | STREAMING | Yes | JAX, Warp |
| `GradsApproximationBC` | Approximate missing populations via Grad's method | STREAMING | No | Warp only |
137 changes: 137 additions & 0 deletions docs/api/constants.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Constants and Enums

This page provides a reference for the core enumerations (`Enum`) and configuration objects that govern the behavior of XLB simulations. These objects are used to specify settings like the computational backend, numerical precision, and the physics model to be solved.

---

## ComputeBackend

Defined in `compute_backend.py`

```python
class ComputeBackend(Enum):
JAX = auto()
WARP = auto()
```

**Description:**

An `Enum` specifying the primary computational engine for executing simulation kernels.

- **`JAX`**: Use the [JAX](https://github.com/google/jax) framework for computation, enabling execution on CPUs, GPUs, and TPUs.
- **`WARP`**: Use the [NVIDIA Warp](https://github.com/NVIDIA/warp) framework for high-performance GPU simulation kernels.

---

## GridBackend

Defined in `grid_backend.py`

```python
class GridBackend(Enum):
JAX = auto()
WARP = auto()
OOC = auto()
```

**Description:**

An `Enum` defining the backend for grid creation and data management.

- **`JAX`**, **`WARP`**: The grid data resides in memory on the respective compute device.
- **`OOC`**: Handles simulations where the grid data is too large to fit into memory and must be processed "out-of-core" from disk.

---

## PhysicsType

Defined in `physics_type.py`

```python
class PhysicsType(Enum):
NSE = auto() # Navier-Stokes Equations
ADE = auto() # Advection-Diffusion Equations
```

**Description:**

An `Enum` used to select the set of physical equations to be solved by the stepper.

- **`NSE`**: Simulates fluid dynamics governed by the incompressible Navier-Stokes equations.
- **`ADE`**: Simulates transport phenomena governed by the Advection-Diffusion equation.

---

## Precision

Defined in `precision_policy.py`

```python
class Precision(Enum):
FP64 = auto()
FP32 = auto()
FP16 = auto()
UINT8 = auto()
BOOL = auto()
```

**Description:**

An `Enum` representing fundamental data precision levels. Each member provides properties to get the corresponding data type in the target compute backend:

- **`.wp_dtype`**: The equivalent `warp` data type (e.g., `wp.float32`).
- **`.jax_dtype`**: The equivalent `jax.numpy` data type (e.g., `jnp.float32`).

---

## PrecisionPolicy

Defined in `precision_policy.py`

```python
class PrecisionPolicy(Enum):
FP64FP64 = auto()
FP64FP32 = auto()
FP64FP16 = auto()
FP32FP32 = auto()
FP32FP16 = auto()
```

**Description:**

An `Enum` that defines a policy for balancing numerical accuracy and memory usage. It specifies a precision for computation and a (potentially different) precision for storage.

For example, `FP64FP32` specifies that calculations should be performed in high-precision `float64`, but the results are stored in memory-efficient `float32`.

**Utility Properties & Methods:**
- **`.compute_precision`**: Returns the `Precision` enum for computation.
- **`.store_precision`**: Returns the `Precision` enum for storage.
- **`.cast_to_compute_jax(array)`**: Casts a JAX array to the policy's compute precision.
- **`.cast_to_store_jax(array)`**: Casts a JAX array to the policy's store precision.

---

## DefaultConfig

Defined in `default_config.py`

```python
@dataclass
class DefaultConfig:
velocity_set
default_backend
default_precision_policy
```

A `dataclass` that holds the global configuration for a simulation session.

An instance of this configuration is set globally using the `xlb.init()` function at the beginning of a script. This ensures that all subsequently created XLB components are aware of the chosen backend, velocity set, and precision policy.

```python
# The xlb.init() function sets the global DefaultConfig instance
xlb.init(
velocity_set=D2Q9(...),
default_backend=ComputeBackend.JAX,
default_precision_policy=PrecisionPolicy.FP32FP32
)
```
100 changes: 100 additions & 0 deletions docs/api/distribution.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Distribution

The `distribution` module provides tools for distributing **lattice Boltzmann operators** across multiple devices (e.g., GPUs or TPUs) using [JAX sharding](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
This enables simulations to run in parallel while ensuring correct **halo communication** between device partitions.

---

## Overview

In lattice Boltzmann methods (LBM), each lattice site’s distribution function depends on its neighbors.
When running on multiple devices, the domain is split (sharded) across them, requiring **data exchange at the boundaries** after each step.

The `distribution` module handles:

- **Sharding operators** across devices.
- **Exchanging boundary (halo) data** between devices.
- Supporting stepper operators (like `IncompressibleNavierStokesStepper`) with or without boundary conditions.

---

## Functions

### `distribute_operator`

```python
distribute_operator(operator, grid, velocity_set, num_results=1, ops="permute")
```
Wraps an operator to run in distributed fashion.

## Parameters

- **operator** (`Operator`)
The LBM operator (e.g., collision, streaming).

- **grid**
Grid definition with device mesh info (`grid.global_mesh`, `grid.shape`, `grid.nDevices`).

- **velocity_set**
Velocity set defining the LBM stencil (e.g., D2Q9, D3Q19).

- **num_results** (`int`, default=`1`)
Number of results returned by the operator.

- **ops** (`str`, default=`"permute"`)
Communication scheme. Currently supports `"permute"` for halo exchange.

---

## Details

- Uses **`shard_map`** to parallelize across devices.
- Applies **halo communication** via `jax.lax.ppermute`:
- Sends right-edge values to the left neighbor.
- Sends left-edge values to the right neighbor.
- Returns a **JIT-compiled distributed operator**.

---

### `distribute`

```python
distribute(operator, grid, velocity_set, num_results=1, ops="permute")

```

## Description

Decides how to distribute an operator or stepper.

---

## Parameters

Same as **`distribute_operator`**.

---

## Special Case: `IncompressibleNavierStokesStepper`

- Checks if boundary conditions require **post-streaming updates**:
- If **yes** → only the `.stream` operator is distributed.
- If **no** → the entire stepper is distributed.

---

## Example

```python
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.distribution import distribute

# Create stepper
stepper = IncompressibleNavierStokesStepper(...)

# Distribute across devices
distributed_stepper = distribute(stepper, grid, velocity_set)

# Run simulation
state = distributed_stepper(state)
```
Loading