[Bugfix] Wrong minari download first element #3106
Conversation
… bugfix/wrong_minari_download_first_element # Conflicts: # torchrl/data/datasets/minari_data.py
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3106
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks!
Added some quick comments here and there.
A test would be nice! (you can just add a line in an existing test to check that the bug is resolved?)
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
    f"loading dataset from local Minari cache at {h5_path}"
)
h5_data = PersistentTensorDict.from_h5(h5_path)
h5_data = h5_data.to_tensordict()
do we need this? It's a bit expensive, so if we can avoid it, it's better (under the hood it copies the entire dataset into memory)
There is nothing I would love more than getting rid of that line. The method to change from NonTensorData to NonTensorStack is basically:
```python
with set_list_to_stack(True):
    tensordict[key] = data_list
```
Unfortunately, if we don't get the h5_data into memory, we face this error upon rewriting each NonTensorData key:
```
OSError: Can't synchronously write data (no write intent on file)
```
If anyone knows how to avoid this, I would love to get this thing fixed in a better way.
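For context, here is a minimal, self-contained sketch of the conversion described above, assuming a recent tensordict release that exports set_list_to_stack; the keys and mission strings are illustrative, not the actual Minari layout:

```python
import torch
from tensordict import TensorDict, NonTensorStack, set_list_to_stack

# Hypothetical tensordict standing in for the materialized h5 data.
td = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
missions = ["go to the red door", "pick up the key", "open the box"]  # made-up values

with set_list_to_stack(True):
    # With the flag set, assigning a Python list stores one entry per batch
    # element (a NonTensorStack) instead of a single NonTensorData blob.
    td["mission"] = missions

assert isinstance(td.get("mission"), NonTensorStack)
print(td[1]["mission"])  # -> "pick up the key"
```

The same assignment against the lazy, h5-backed PersistentTensorDict is what triggers the OSError above, which is why the data is materialized first.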
torchrl/data/datasets/minari_data.py
Outdated
return tensordict


def extract_nontensor_fields(td: TensorDictBase, recursive: bool = False) -> TensorDict:
ditto
These functions look similar but do different things: one preallocates keys in a tensordict, another deletes NonTensorData keys, and a third transforms NonTensorData into NonTensorStack. They all traverse the tensordict keys, though, so I can try to refactor them a bit more later.
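For readers following the thread, a rough sketch of what the extraction-style traversal could look like; this is a hypothetical helper under assumed key layouts, not the PR's actual implementation:

```python
from tensordict import TensorDict, TensorDictBase, NonTensorData


def extract_nontensor_fields_sketch(td: TensorDictBase) -> TensorDict:
    # Move every NonTensorData leaf into its own tensordict and drop it from
    # the source. Keys are snapshotted with list() because entries are deleted
    # while looping.
    out = TensorDict({}, batch_size=td.batch_size)
    for key in list(td.keys(include_nested=True, leaves_only=True)):
        val = td.get(key)
        if isinstance(val, NonTensorData):
            out[key] = val
            del td[key]
    return out
```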
torchrl/data/datasets/minari_data.py
Outdated
return TensorDict(extracted, batch_size=td.batch_size)


def preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
ditto
torchrl/data/datasets/minari_data.py
Outdated
for key in list(tensordict.keys()):
    val = tensordict.get(key)
use items()?
mmmm, I get this error when using .items():
```
Traceback (most recent call last):
  File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 585, in _extract_nontensor_fields
    for key, val in tensordict.items():
RuntimeError: dictionary changed size during iteration
```
This is because we are deleting keys from the dataset as we iterate over them. It looks better with items(), but sadly it doesn't work.
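A tiny example of the pattern (key names made up): snapshotting the keys with list() is what makes deleting entries during the loop safe.

```python
import torch
from tensordict import TensorDict, NonTensorData

td = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
td["mission"] = "pick up the key"  # stored as NonTensorData

# Iterating td.items() directly while deleting is what raises
# "RuntimeError: dictionary changed size during iteration".
for key in list(td.keys()):  # snapshot, safe to mutate td inside the loop
    if isinstance(td.get(key), NonTensorData):
        del td[key]

assert "mission" not in td.keys()
```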
I have applied most of your suggestions (except for the ditto suggestions, those take longer). I have also added a test. I am also having a problem that I might need help with. This is the code that I am using to test if all these changes work:
```python
from torchrl.data.datasets.minari_data import MinariExperienceReplay

BATCH_SIZE = 1
SAVE_ROOT = None


def download_minari_datasets(dataset_id):
    a = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=BATCH_SIZE,
        root=SAVE_ROOT,
    )
    print(f"✓ Successfully downloaded {dataset_id}")
    print(a[210][('observation', 'mission')])
    print(a[1210][('observation', 'mission')])
    print(a[2210][('observation', 'mission')])
    print(a[3210][('observation', 'mission')])


if __name__ == "__main__":
    download_minari_datasets("minigrid/BabyAI-Pickup/optimal-v0")
```
If you get a single, different mission for each step, then it probably looks alright. The problem is that the second time you run this code, now that the dataset has already been downloaded, the experience replay keeps loading the same wrong missions as before: long, repeated missions at each step. It happens because we go through the _load function. I have absolutely no idea what the reason for this behaviour might be.
Description
PersistentTensorDict.from_h5(h5_path) is the function that loads the h5 data and creates the first TensorDict version of the dataset. I have introduced the necessary code for the _download_and_preproc function to handle the substitution of NonTensorData into NonTensorStack, which is how Minari uses the 'mission' categorical values.
I have introduced a fundamental change in the way MinariExperienceReplay works:
I think this change brings the tensordict into memory instead of keeping it lazy, but it was necessary for the patching to take place. Just let me know if we can handle this in a different way.
Motivation and Context
It solves issue #3105.
With these changes, and using the example introduced in the issue, we can successfully retrieve the correct mission for each step and see that it changes through episodes and steps.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!