fix: random sampling in ForgetRetainDataset #145
Conversation
src/data/unlearn.py (outdated)
```python
g = torch.Generator()
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
seed = int(torch.empty((), dtype=torch.int64).random_().item() + rank)
g.manual_seed(seed)
```
It would be better to use the seed from the experiment config here, rather than
`int(torch.empty((), dtype=torch.int64).random_().item())`,
to avoid introducing randomness that is not controlled by the experiment seed.
Can you try to make the experiment's `cfg.seed` available to this dataset class and then use `seed = exp_seed + rank` here?
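For reference, a minimal sketch of that suggestion, assuming `cfg.seed` can be plumbed into the constructor (the constructor arguments and the returned dict format are illustrative assumptions, not the repo's exact API):

```python
import torch
from torch.utils.data import Dataset


class ForgetRetainDataset(Dataset):
    """Sketch: pair each forget example with a randomly sampled retain example."""

    def __init__(self, forget, retain, seed=0):
        self.forget = forget
        self.retain = retain
        # Derive a per-rank seed from the experiment-level seed so that each
        # distributed process samples a different retain sequence while the
        # whole run stays reproducible under a fixed cfg.seed.
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        self.generator = torch.Generator()
        self.generator.manual_seed(seed + rank)

    def __len__(self):
        return len(self.forget)

    def __getitem__(self, idx):
        # Sample a retain index from this dataset's seeded generator rather
        # than from global, uncontrolled randomness.
        retain_idx = torch.randint(
            0, len(self.retain), (1,), generator=self.generator
        ).item()
        return {"forget": self.forget[idx], "retain": self.retain[retain_idx]}
```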
Thank you for the PR! Please see the comment above.
Thanks for the feedback! I've updated the PR accordingly. Please let me know if there are any further adjustments required.
Please fix the lint errors!
It is not ideal to set the seed at the exact example level: this would mean we select the same sequence of retain example indices even when using a different dataset.
Since the point is that each rank must get a different seed, IMO it is better to get the rank in the global seed function: https://github.com/locuslab/open-unlearning/blob/main/src/trainer/utils.py#L8
Let me know if you see any issues.
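For concreteness, a minimal sketch of a rank-aware version of that global seed function, assuming it resembles a standard `seed_everything` helper (the function name and the exact set of libraries seeded are assumptions, not a quote of `src/trainer/utils.py`):

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    # Offset the experiment seed by the distributed rank so that every
    # process gets a distinct but deterministic seed, without seeding at
    # the per-example level inside the dataset.
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    seed = seed + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```

With this approach, `ForgetRetainDataset` can use ordinary `torch` randomness and still behave reproducibly and differently per rank.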
What does this PR do?
Fixes #139