A JAX implementation of the TGM algorithm. Contains:
- Algorithm implementations of:
  - TGM (encompasses GFN)
  - SAC
  - PPO
- Code and data for the following synthetic and biological sequence design tasks:
  - BitSequence
  - Untranslated region (UTR)
  - Antimicrobial peptides (AMP)
  - Green fluorescent protein (GFP)
- Proxy training
To get started (currently tested for Python 3.10):

```bash
git clone https://github.com/marcojira/tgm.git
# Or your favorite virtual env
python -m venv env && source env/bin/activate && pip install --upgrade pip
pip install -e tgm
python example.py
```

The easiest way to train a sampler is to use the `run` function from `medium_rl.run` and pass it a `Config` object from `src/medium_rl/config.py`.
The `Config` object contains general training options (e.g. sampling policy, minibatch size, evaluation, etc.). It also consists of the following three sub-configurations:
- `EnvConfig`: Configuration for the environment
- `AlgConfig`: Configuration for the training algorithm
- `NetworkConfig`: Configuration for the neural network

To see available options, check `src/medium_rl/config.py`, which also contains base configurations for each. They can then be composed as follows:
```python
from medium_rl.config import (
    AMPConfig,
    TGMConfig,
    BaseTransformerConfig,
    Config,
)

env_cfg = AMPConfig()  # Taking default options
alg_cfg = TGMConfig(alpha=1, omega=1, q=0.75)  # Changing some values
network_cfg = BaseTransformerConfig(dropout=0.05)

cfg = Config(
    env=env_cfg,
    alg=alg_cfg,
    network=network_cfg,
    reward_exp=64,  # Change beta
    lr=1e-4,  # Change lr
)
```

Once the `Config` object is created, running training simply requires:
```python
from medium_rl.run import run

run(cfg)
```

On an L40S GPU, training for 100k samples should be quick:
- BitSequence: <3 minutes
- UTR: <5 minutes
- AMP: <5 minutes
- GFP: <30 minutes
All environments are subclasses of `SequenceEnv`, which describes a generic sequence generation DCG. Similarly to PGX, the core object is a `State` that contains information about the current sequence. `SequenceEnv` then defines `init`, `step` and `get_rewards` functions to initialize the state, step the state, and compute proxy rewards for a sequence.
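As a rough illustration, a rollout under this interface might look like the sketch below. The `AMPSequence` import path, the `State` fields (`done`, `token_seq`), and the `init`/`step` signatures are assumptions for illustration only; check `src/medium_rl/envs/sequence_env.py` for the actual API.

```python
# Hedged sketch of a rollout loop against the SequenceEnv interface.
# Module path, State fields and init/step signatures are assumptions.
import jax
from medium_rl.envs.amp import AMPSequence  # hypothetical module path

env = AMPSequence(min_len=12, max_len=60)
key = jax.random.PRNGKey(0)

state = env.init()  # fresh State with an empty sequence
while not state.done:
    key, subkey = jax.random.split(key)
    # Uniform random policy over the token alphabet, purely for illustration
    action = jax.random.randint(subkey, (), 0, env.num_tokens)
    state = env.step(state, action)

# get_rewards expects a [B, T] batch of token indexes
rewards = env.get_rewards(state.token_seq[None, :])
```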
For each of the biological sequence design tasks, the checkpoint for the proxy reward function is provided in `src/medium_rl/envs/proxies/<env_name>/proxy.pkl` and the validation mean/std in `src/medium_rl/envs/proxies/<env_name>/val_stats.pkl`.
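For instance, a checkpoint could be loaded along these lines; the `amp` directory name and the unpacking of `val_stats.pkl` into a mean/std pair are assumptions to verify against the proxy code:

```python
# Hedged sketch: loading a proxy checkpoint and its validation statistics.
# The directory name and the (mean, std) unpacking are assumptions.
import pickle

with open("src/medium_rl/envs/proxies/amp/proxy.pkl", "rb") as f:
    proxy = pickle.load(f)

with open("src/medium_rl/envs/proxies/amp/val_stats.pkl", "rb") as f:
    val_mean, val_std = pickle.load(f)
```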
**BitSequence.** Synthetic task described in Trajectory balance: Improved credit assignment in GFlowNets.
**UTR.** Sequence design task for the 5' UTR mRNA region that regulates translation of the main coding sequence. Data to train the proxy comes from brandontrabucco/design-bench#11 (comment) and consists of 250,000 sequences and their associated ribosome loads.
**AMP.** Antimicrobial peptide design task. The proxy was trained as a binary classifier to predict whether a sequence is antimicrobial, on a dataset of 9,222 non-AMP sequences and 6,438 AMP sequences from https://github.com/MJ10/clamp-gen-data/tree/master/data/dataset. The logit of the classifier is used as the proxy reward.
**GFP.** Green fluorescent protein design task. Data to train the proxy was sourced from brandontrabucco/design-bench#11 (comment) and consists of 56,086 variations of the original GFP protein and their associated fluorescence.
To create a custom environment, one can extend the `SequenceEnv` class as follows:

```python
from jax import Array

from medium_rl.envs.sequence_env import SequenceEnv


class NEW_ENV_NAMESequence(SequenceEnv):
    # Everything below needs to be specified
    name = "NEW_ENV_NAME"
    num_tokens = len(NEW_ENV_ALPHABET)
    alphabet = NEW_ENV_ALPHABET
    dict = {NEW_ENV_ALPHABET[i]: i for i in range(len(NEW_ENV_ALPHABET))}

    CLS = 0  # CLS or BOS token index
    PAD = 1  # PAD token index
    EOS = 2  # EOS token index

    def __init__(self, min_len: int, max_len: int, **kwargs):
        super().__init__(min_len, max_len)
        self.proxy = NEW_ENV_PROXY()  # Initialize proxy if necessary

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequences of tokens
    ):
        # Needs to implement a proxy reward function that takes in a
        # [B, T] array of token indexes
        # - B: Batch size
        # - T: Sequence length
        # and returns the proxy rewards
        rewards = self.proxy.evaluate(token_seq)
        return rewards
```

Then, an `EnvConfig` can be specified as follows:
```python
class NEW_ENVConfig(EnvConfig):
    name: str = "NEW_ENV"
    min_len: int = 5
    max_len: int = 10
    ...
```

`src/medium_rl/envs/proxies/train_proxy.py` contains code for training proxy reward functions from data. `train_model` expects `x` (an `[N, T]` array of token indexes) and `y` (an `[N,]` array of either floats to regress to or binary classes), as well as a `model_cfg` specifying the hyperparameters of the network. See `train_proxies.py` for example uses.
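As a rough usage sketch (the exact `train_model` signature, its keyword names, and the reuse of `BaseTransformerConfig` below are assumptions; see `train_proxy.py` and `train_proxies.py` for the actual interface):

```python
# Hedged sketch of training a proxy on synthetic regression data.
# The model_cfg keyword name and the returned value are assumptions.
import numpy as np

from medium_rl.config import BaseTransformerConfig
from medium_rl.envs.proxies.train_proxy import train_model

N, T = 10_000, 50
x = np.random.randint(0, 20, size=(N, T))        # [N, T] token indexes
y = np.random.rand(N).astype(np.float32)         # [N,] regression targets

model_cfg = BaseTransformerConfig(dropout=0.05)  # network hyperparameters
proxy = train_model(x, y, model_cfg=model_cfg)
```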
The biological environments are JAX implementations, with moderate modifications, of the environments of Biological Sequence Design with GFlowNets, as well as of the benchmarks of Design-Bench: Benchmarks for Data-Driven Offline Model-Based Optimization. The training process for the proxy reward functions comes from the former and the data used from the latter. The BitSequence environment comes from Trajectory balance: Improved credit assignment in GFlowNets. The design of the `SequenceEnv` environment is inspired by the PGX library.
