22 changes: 22 additions & 0 deletions skglm/skglm_jax/README.md
@@ -0,0 +1,22 @@
## Installation


1. Create, then activate, a ``conda`` environment
```shell
# create
conda create -n skglm-jax python=3.10

# activate env
conda activate skglm-jax
```

2. Install ``skglm`` in editable mode
```shell
pip install -e .
```

3. Install the JAX dependencies
```shell
# jax
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
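
Once installed, a minimal usage sketch (not part of this diff; the random data and the `alpha` value below are purely illustrative):

```python
import numpy as np

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax
from skglm.skglm_jax.anderson_cd import AndersonCD

# small random Lasso problem (illustrative data)
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 50))
y = rng.standard_normal(100)

# a fraction of alpha_max = max |X^T y| / n_samples
alpha = 0.1 * np.max(np.abs(X.T @ y)) / len(y)

solver = AndersonCD(max_iter=100, tol=1e-6, verbose=1)
w = solver.solve(X, y, QuadraticJax(), L1Jax(alpha))
print(w.shape)  # (50,)
```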
12 changes: 12 additions & 0 deletions skglm/skglm_jax/__init__.py
@@ -0,0 +1,12 @@
# If not set, JAX raises an error related to the CUDA linking API.
# As recommended, set 'XLA_FLAGS' to bypass it.
# Side effect: (perhaps) slower compilation time.
# import os
# os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa

# set this flag to work around a bug with `jax.numpy.linalg.norm`
# ref: https://github.com/google/jax/issues/8916#issuecomment-1101113497
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False" # noqa

import jax
jax.config.update("jax_enable_x64", True)
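
Enabling `jax_enable_x64` makes float64 the default dtype, matching NumPy's precision; a quick check (not part of the module):

```python
import skglm.skglm_jax  # noqa: F401  (importing the package runs the config update above)
import jax.numpy as jnp

print(jnp.zeros(3).dtype)  # float64 rather than JAX's float32 default
```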
141 changes: 141 additions & 0 deletions skglm/skglm_jax/anderson_cd.py
@@ -0,0 +1,141 @@
from functools import partial

import jax
import numpy as np
import jax.numpy as jnp

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax
from skglm.skglm_jax.utils import JaxAA


class AndersonCD:

EPS_TOL = 0.3

def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10,
use_acc=False, verbose=0):
self.max_iter = max_iter
self.max_epochs = max_epochs
self.tol = tol
self.p0 = p0
self.use_acc = use_acc
self.verbose = verbose

def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
X, y = self._transfer_to_device(X, y)

n_samples, n_features = X.shape
lipschitz = datafit.get_features_lipschitz_cst(X, y)

w = jnp.zeros(n_features)
Xw = jnp.zeros(n_samples)
all_features = jnp.full(n_features, fill_value=True, dtype=bool)

for it in range(self.max_iter):

# check convergence
grad = datafit.gradient_ws(X, y, w, Xw, all_features)
scores = penalty.subdiff_dist_ws(w, grad, all_features)
stop_crit = jnp.max(scores)

if self.verbose:
p_obj = datafit.value(X, y, w) + penalty.value(w)

print(
f"Iteration {it}: p_obj_in={p_obj:.8f} "
f"stop_crit_in={stop_crit:.4e}"
)

if stop_crit <= self.tol:
break

# build ws
gsupp_size = penalty.generalized_support(w).sum()
ws_size = min(
max(2 * gsupp_size, self.p0),
n_features
)

ws = jnp.full(n_features, fill_value=False, dtype=bool)
ws_features = jnp.argsort(scores)[-ws_size:]
ws = ws.at[ws_features].set(True)

tol_in = AndersonCD.EPS_TOL * stop_crit

w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in,
datafit, penalty)

w_cpu = np.asarray(w)
return w_cpu

def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in,
datafit, penalty):

if self.use_acc:
accelerator = JaxAA(K=5)

for epoch in range(self.max_epochs):

w, Xw = self._cd_epoch(X, y, w, Xw, ws, lipschitz,
datafit, penalty)

if self.use_acc:
w, Xw = accelerator.extrapolate(w, Xw)

# check convergence
grad_ws = datafit.gradient_ws(X, y, w, Xw, ws)
scores_ws = penalty.subdiff_dist_ws(w, grad_ws, ws)
stop_crit_in = jnp.max(scores_ws)

if self.verbose > 1:
p_obj_in = datafit.value(X, y, w) + penalty.value(w)

print(
f"Epoch {epoch}: p_obj_in={p_obj_in:.8f} "
f"stop_crit_in={stop_crit_in:.4e}"
)

if stop_crit_in <= tol_in:
break

return w, Xw

@partial(jax.jit, static_argnums=(0, -2, -1))
def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty):
for j, in_ws in enumerate(ws):

w, Xw = jax.lax.cond(
in_ws,
lambda X, y, w, Xw, j, lipschitz: self._cd_epoch_j(X, y, w, Xw, j, lipschitz, datafit, penalty), # noqa
lambda X, y, w, Xw, j, lipschitz: (w, Xw),
*(X, y, w, Xw, j, lipschitz)
)

return w, Xw

@partial(jax.jit, static_argnums=(0, -2, -1))
def _cd_epoch_j(self, X, y, w, Xw, j, lipschitz, datafit, penalty):

# Null columns of X would break this function
# as their corresponding Lipschitz constant is 0
# TODO: implement condition using lax
# if lipschitz[j] == 0.:
# continue

step = 1 / lipschitz[j]

grad_j = datafit.gradient_1d(X, y, w, Xw, j)
next_w_j = penalty.prox_1d(w[j] - step * grad_j, step)

delta_w_j = next_w_j - w[j]

w = w.at[j].set(next_w_j)
Xw = Xw + delta_w_j * X[:, j]

return w, Xw

def _transfer_to_device(self, X, y):
# TODO: other checks
# - skip if they are already jax arrays
return jnp.asarray(X), jnp.asarray(y)
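
For reference, each call to `_cd_epoch_j` performs a proximal coordinate update: a gradient step of size `1 / lipschitz[j]` on feature `j`, the penalty's prox, then an incremental update of `Xw`. A plain NumPy sketch of the same step for the Lasso (hypothetical helper, not part of the module):

```python
import numpy as np

def cd_update_j(X, y, w, Xw, j, lipschitz_j, alpha):
    """One proximal coordinate-descent step on feature j (Lasso case)."""
    n_samples = X.shape[0]
    grad_j = X[:, j] @ (Xw - y) / n_samples       # partial gradient
    step = 1.0 / lipschitz_j

    # soft-thresholding prox of alpha * |w_j|
    z = w[j] - step * grad_j
    next_w_j = np.sign(z) * max(abs(z) - step * alpha, 0.0)

    Xw = Xw + (next_w_j - w[j]) * X[:, j]         # keep Xw in sync
    w = w.copy()
    w[j] = next_w_j
    return w, Xw
```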
48 changes: 48 additions & 0 deletions skglm/skglm_jax/datafits.py
@@ -0,0 +1,48 @@
import jax
import jax.numpy as jnp
from jax.numpy.linalg import norm as jnorm

from skglm.skglm_jax.utils import jax_jit_method


class QuadraticJax:
"""1 / (2 n_samples) ||y - Xw||^2"""

def value(self, X, y, w):
n_samples = X.shape[0]
return ((X @ w - y) ** 2).sum() / (2. * n_samples)

def gradient_1d(self, X, y, w, Xw, j):
n_samples = X.shape[0]
return X[:, j] @ (Xw - y) / n_samples

@jax_jit_method
def gradient_ws(self, X, y, w, Xw, ws):
n_features = X.shape[1]
Xw_minus_y = Xw - y

grad_ws = jnp.empty(n_features)
for j, in_ws in enumerate(ws):

grad_j = jax.lax.cond(
in_ws,
lambda X, Xw_minus_y, j: X[:, j] @ Xw_minus_y / len(Xw_minus_y),
lambda X, Xw_minus_y, j: 0.,
*(X, Xw_minus_y, j)
)

grad_ws = grad_ws.at[j].set(grad_j)

return grad_ws

def get_features_lipschitz_cst(self, X, y):
n_samples = X.shape[0]
return jnorm(X, ord=2, axis=0) ** 2 / n_samples

def get_global_lipschitz_cst(self, X, y):
n_samples = X.shape[0]
return jnorm(X, ord=2) ** 2 / n_samples

def gradient(self, X, y, w):
n_samples = X.shape[0]
return X.T @ (X @ w - y) / n_samples
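
The closed-form `gradient` can be cross-checked against automatic differentiation of `value`; a small consistency sketch on random data (not part of the module):

```python
import jax
import jax.numpy as jnp
import numpy as np

from skglm.skglm_jax.datafits import QuadraticJax

rng = np.random.default_rng(0)
X = jnp.asarray(rng.standard_normal((20, 5)))
y = jnp.asarray(rng.standard_normal(20))
w = jnp.asarray(rng.standard_normal(5))

datafit = QuadraticJax()
grad_closed_form = datafit.gradient(X, y, w)            # X.T @ (Xw - y) / n_samples
grad_autodiff = jax.grad(datafit.value, argnums=2)(X, y, w)

print(jnp.allclose(grad_closed_form, grad_autodiff))    # True
```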
79 changes: 79 additions & 0 deletions skglm/skglm_jax/fista.py
@@ -0,0 +1,79 @@
import numpy as np

import jax
import jax.numpy as jnp

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax


class Fista:

def __init__(self, max_iter=200, use_auto_diff=True, verbose=0):
self.max_iter = max_iter
self.use_auto_diff = use_auto_diff
self.verbose = verbose

def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
n_samples, n_features = X.shape
X_gpu, y_gpu = jnp.asarray(X), jnp.asarray(y)

# compute step
lipschitz = datafit.get_global_lipschitz_cst(X_gpu, y_gpu)
if lipschitz == 0.:
return np.zeros(n_features)

step = 1 / lipschitz
all_features = jnp.full(n_features, fill_value=True, dtype=bool)

# get grad func of datafit
if self.use_auto_diff:
auto_grad = jax.jit(jax.grad(datafit.value, argnums=-1))

# init vars on device
w = jnp.zeros(n_features)
old_w = jnp.zeros(n_features)
mid_w = jnp.zeros(n_features)
grad = jnp.zeros(n_features)

t_old, t_new = 1, 1

for it in range(self.max_iter):

# compute grad
if self.use_auto_diff:
grad = auto_grad(X_gpu, y_gpu, mid_w)
else:
grad = datafit.gradient(X_gpu, y_gpu, mid_w)

# forward / backward
val = mid_w - step * grad
w = penalty.prox(val, step)

if self.verbose:
p_obj = datafit.value(X_gpu, y_gpu, w) + penalty.value(w)

if self.use_auto_diff:
grad = auto_grad(X_gpu, y_gpu, w)
else:
grad = datafit.gradient(X_gpu, y_gpu, w)

scores = penalty.subdiff_dist_ws(w, grad, all_features)
stop_crit = jnp.max(scores)

print(
f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={stop_crit:.4e}"
)

# extrapolate
mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

# update FISTA vars
t_old = t_new
t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2))
old_w = jnp.copy(w)

# transfer back to host
w_cpu = np.asarray(w, dtype=np.float64)

return w_cpu
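
A usage sketch of `Fista`, comparing the autodiff gradient with the closed-form one (random data; the `alpha` value is illustrative):

```python
import numpy as np

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax
from skglm.skglm_jax.fista import Fista

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 30))
y = rng.standard_normal(100)
alpha = 0.05 * np.max(np.abs(X.T @ y)) / len(y)

datafit, penalty = QuadraticJax(), L1Jax(alpha)

w_auto = Fista(max_iter=200, use_auto_diff=True).solve(X, y, datafit, penalty)
w_manual = Fista(max_iter=200, use_auto_diff=False).solve(X, y, datafit, penalty)

# both variants run the same updates, so they should agree up to floating point error
print(np.allclose(w_auto, w_manual, atol=1e-6))
```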
54 changes: 54 additions & 0 deletions skglm/skglm_jax/penalties.py
@@ -0,0 +1,54 @@
import jax
import jax.numpy as jnp

from skglm.skglm_jax.utils import jax_jit_method


class L1Jax:
"""alpha ||w||_1"""

def __init__(self, alpha):
self.alpha = alpha

def value(self, w):
return (self.alpha * jnp.abs(w)).sum()

def prox_1d(self, value, stepsize):
shifted_value = jnp.abs(value) - stepsize * self.alpha
return jnp.sign(value) * jnp.maximum(shifted_value, 0.)

def prox(self, value, stepsize):
return self.prox_1d(value, stepsize)

@jax_jit_method
def subdiff_dist_ws(self, w, grad_ws, ws):
n_features = w.shape[0]
dist = jnp.empty(n_features)

for j, in_ws in enumerate(ws):
w_j = w[j]
grad_j = grad_ws[j]

dist_j = jax.lax.cond(
in_ws,
self._compute_subdiff_dist_j,
lambda w_j, grad_j: 0.,
*(w_j, grad_j)
)

dist = dist.at[j].set(dist_j)

return dist

def generalized_support(self, w):
return w != 0.

@jax_jit_method
def _compute_subdiff_dist_j(self, w_j, grad_j):
dist_j = jax.lax.cond(
w_j == 0.,
lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.),
lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha),
*(w_j, grad_j, self.alpha)
)
return dist_j
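
`prox_1d` (and `prox`, which applies it element-wise) is soft-thresholding, while `subdiff_dist_ws` measures how far `-grad` is from the subdifferential of `alpha * ||.||_1` at `w`; a small sketch with illustrative values:

```python
import jax.numpy as jnp

from skglm.skglm_jax.penalties import L1Jax

penalty = L1Jax(alpha=1.0)

# soft-thresholding: entries with |value| <= alpha * stepsize are zeroed
print(penalty.prox(jnp.array([-3.0, -0.5, 0.2, 2.0]), 1.0))
# -> approximately [-2.,  0.,  0.,  1.]

# violation of the L1 optimality conditions, coordinate by coordinate
w = jnp.array([0.0, 1.5])
grad = jnp.array([0.3, -1.0])
ws = jnp.full(2, fill_value=True, dtype=bool)
print(penalty.subdiff_dist_ws(w, grad, ws))
# -> [0., 0.]  (both coordinates satisfy the optimality conditions)
```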