Description
Describe the bug
I see very cool advancements in the direction of LLM RL training in the repo, awesome work! :)
After experimenting with LLMEnv, I got the following error when passing dialogue data to the env:
RuntimeError: modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling td = td.to_tensordict() before resetting the batch size.
Dialogue data is a common format for LLMs: it lets the model see inputs from both the system and the user, and a chat template can flatten it into a single string. I am not sure whether the env is meant to support dialogue data in the format I am passing, but it would be convenient.
The error goes away if I comment out the line linked below, but I do not think that is the right fix.
https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/llm.py#L534
To Reproduce
from datasets import Dataset
from torch.utils.data import DataLoader
from torchrl.envs.custom.llm import LLMEnv


def collate_fn(batch: list[dict]) -> dict[str, list]:
    return {k: [el[k] for el in batch] for k in batch[0]}


# Dummy data
sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the capital of France?"}
    ]
}

# Repeat the sample 1000 times
data = [sample] * 1000

# Create the dataset
dataset = Dataset.from_list(data)

# Create a PyTorch DataLoader
batch_size = 16
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    str2str=True,
    batch_size=4,
    str_key='messages'
)
obs = env.reset()
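As a possible stopgap, each dialogue could be flattened into one string inside the collate function, so the env receives a flat list of strings rather than a nested list of dicts. The sketch below is only an illustration: `render_dialogue`, `collate_flat`, and the `<|role|>` tag format are hypothetical names and a made-up template, and I have not verified that LLMEnv accepts the result.

```python
def render_dialogue(messages: list[dict]) -> str:
    """Join a list of {"role", "content"} dicts into a single string.

    The "<|role|>" tags are a made-up template used for illustration only.
    """
    return "\n".join(f"<|{m['role']}|>\n{m['content']}" for m in messages)


def collate_flat(batch: list[dict]) -> dict[str, list[str]]:
    # Produce one flat string per sample instead of a nested list of dicts.
    return {"messages": [render_dialogue(el["messages"]) for el in batch]}


sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the capital of France?"},
    ]
}

out = collate_flat([sample] * 2)
print(out["messages"][0])
```

Passing `collate_fn=collate_flat` to the DataLoader would then hand the env plain strings under the `messages` key. With a real model, `tokenizer.apply_chat_template(messages, tokenize=False)` from transformers could take the place of `render_dialogue`, so that the string matches the template the model was trained with.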
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)