
[BUG] RuntimeError when passing dialogue data to LLMEnv #2875

@albertbou92


Describe the bug

I see some very cool advances toward LLM RL training in the repo, awesome work! :)

After experimenting a bit with LLMEnv, I got the following error when passing dialogue data to the env:

RuntimeError: modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling td = td.to_tensordict() before resetting the batch size.

Dialogue data (a list of messages, each with a role and a content field) is a very common format for LLMs: it lets the model see inputs from both the system and the user, and it can easily be formatted into a single string with different chat templates. I'm not sure whether the env is meant to support dialogue data in the format I'm passing it, but I think it would be convenient.
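For reference, a chat template just flattens the message list into one prompt string. The sketch below uses a ChatML-style layout chosen purely for illustration, and render_chat is a hypothetical helper name; in practice each model ships its own template (for example through a Hugging Face tokenizer's apply_chat_template).

```python
def render_chat(messages: list[dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Flatten a list of {"role", "content"} messages into one prompt string.

    Illustrative ChatML-style layout only; real models define their own templates.
    """
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn so the model continues generating from here.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of France?"},
]
print(render_chat(messages))
```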

I found that the error goes away if I comment out the line below, but I don't think that is the right fix:
https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/llm.py#L534

To Reproduce

from datasets import Dataset
from torch.utils.data import DataLoader
from torchrl.envs.custom.llm import LLMEnv


def collate_fn(batch: list[dict]) -> dict[str, list]:
    # Turn a list of samples into a dict of lists, keeping the message lists as plain Python objects
    return {k: [el[k] for el in batch] for k in batch[0]}

# Dummy dialogue sample: a system prompt followed by a user question
sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the capital of France?"}
    ]
}


# Repeat the sample 1000 times
data = [sample] * 1000

# Create the dataset
dataset = Dataset.from_list(data)

# Create a PyTorch DataLoader
batch_size = 16
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    str2str=True,
    batch_size=4,
    str_key='messages'
)

# Raises the RuntimeError shown above
obs = env.reset()
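A possible stopgap, until dialogue data is supported directly, is to flatten each dialogue into a single string inside the collate function so the env only ever sees plain strings under str_key. This is an untested, hypothetical workaround: make_string_collate and simple_render are illustrative names, and it assumes (without having verified it) that LLMEnv handles a list of strings under str_key without hitting the lazy-batch-size error.

```python
from typing import Callable

Message = dict[str, str]


def make_string_collate(render: Callable[[list[Message]], str], key: str = "messages"):
    """Hypothetical helper: build a collate_fn that renders each dialogue to one string."""

    def collate_fn(batch: list[dict]) -> dict[str, list]:
        # Same dict-of-lists layout as the collate_fn in the reproduction...
        out = {k: [el[k] for el in batch] for k in batch[0]}
        # ...except the dialogue column becomes a list of plain strings.
        out[key] = [render(msgs) for msgs in out[key]]
        return out

    return collate_fn


# Toy renderer for illustration; swap in a real chat template in practice.
def simple_render(msgs: list[Message]) -> str:
    return "\n".join(f"{m['role']}: {m['content']}" for m in msgs)


collate = make_string_collate(simple_render)
batch = [{"messages": [{"role": "user", "content": "Hi"}]}] * 2
print(collate(batch))  # {'messages': ['user: Hi', 'user: Hi']}
```

The returned function could be passed as collate_fn to the DataLoader in the reproduction, keeping str_key='messages' unchanged; a tokenizer's apply_chat_template(msgs, tokenize=False) would be a natural drop-in for simple_render.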


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug (Something isn't working)
