Add accelerate API support for Word Language Model example #1345
```diff
@@ -37,10 +37,6 @@
                     help='tie the word embedding and softmax weights')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true', default=False,
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
```
```diff
@@ -51,25 +47,20 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
+parser.add_argument('--accel', action='store_true',help='Enables accelerated training')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
 
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")
 
+print("Using device:", device)
+
 ###############################################################################
 # Load data
 ###############################################################################
```
```diff
@@ -243,11 +234,11 @@ def export_onnx(path, batch_size, seq_len):
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    model = torch.load(f, weights_only=False)
```
torch<2.6
If I extract the change and update the requirements to 2.7, it won't work. This change allows the example to run with the simplest possible code change, since leaving it as it was fails to work.
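(For context, a minimal sketch of why the flag is needed here, assuming the example keeps pickling the full model object with `torch.save(model, f)` as it does today; `"model.pt"` simply mirrors the example's default `--save` value:)

```python
import torch

# Hedged sketch: the example saves the entire model object, so under
# torch>=2.6, where torch.load defaults to weights_only=True, loading it
# requires an explicit opt-out. "model.pt" is the example's default --save path.
with open("model.pt", "rb") as f:
    model = torch.load(f, weights_only=False)
```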
In PyTorch 2.6 the default value of `weights_only` was changed to True, and PyTorch 2.7 introduced support for the accelerator API. We can integrate the use of the accelerator API in this PR, and address the update to saving and loading models via `state_dict` in a separate pull request.
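(For reference, a minimal, self-contained sketch of the `state_dict` approach being deferred to that follow-up PR; the stand-in `nn.Linear` module and the file name are placeholders, not the example's actual `RNNModel` or `--save` path:)

```python
import torch
import torch.nn as nn

# Hedged sketch of the deferred state_dict approach: save only the weights,
# then restore them into a freshly constructed instance of the same model.
model = nn.Linear(4, 2)                      # stand-in for the example's model
torch.save(model.state_dict(), "model.pt")   # save weights only

restored = nn.Linear(4, 2)                   # rebuild the architecture first
with open("model.pt", "rb") as f:
    state_dict = torch.load(f, weights_only=True)  # safe default in torch>=2.6
restored.load_state_dict(state_dict)
```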
> PyTorch 2.7 introduced support for the accelerator API. <...> We can integrate the use of the accelerator API in this PR.

From 2.6 actually, see https://docs.pytorch.org/docs/2.6/accelerator.html#module-torch.accelerator. To integrate `torch.accelerator` we must update the torch requirement to >=2.6; otherwise the tests will simply fail. I suspect you did not actually run the modified `run_python_examples.sh`.

> If I extract the change and update the requirements to 2.7, it won't work

I believe you are making the changes in the wrong order. First, update the requirement so the example can use the latest PyTorch and fix any issues that appear. Then, as a second step, introduce the new APIs.
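(To illustrate why the requirement bump matters, a minimal sketch using the same device-selection logic as the diff above; the `hasattr` guard is only there to show that older torch versions lack the module entirely:)

```python
import torch

# Hedged sketch: torch.accelerator exists only in torch>=2.6, so without a
# requirements bump the example's new code path would raise AttributeError
# on older installs. The guard below makes that failure mode explicit.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")
print("Using device:", device)
```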
I did run the modified `run_python_examples.sh`, but maybe I am doing this in the wrong order. So the suggestion is to first update the requirements and fix the issues in a separate PR, then close this one and create a new one for the new API?
First, we need to run the example with the latest PyTorch and fix any issues in a separate PR. Thanks for the feedback @dvrogozh.
> the suggestion is to first update the requirements and fix the issues in a separate PR, then close this one and create [...]

Yes, but you don't need to close this PR. Just mark it as a draft while working on the update-requirements PR.
Here is a PR to update the torch version requirement as I would do it:
What was the error you're getting?
Seems to be an oversight on my part. This was needed when trying a safe approach of only loading the weights, but apparently it is no longer needed. I will remove it to prevent any unwanted changes.
I suggest dropping this from the example command line, and maybe adding a note that the example supports running on acceleration devices, listing which ones were tried (CUDA, MPS, XPU).
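(As a small aside, a sketch of how such a note could be checked locally; the device returned by the accelerator API exposes its backend name via `.type`:)

```python
import torch

# Hedged sketch: the accelerator device reports its backend through .type,
# e.g. "cuda", "mps", or "xpu", which is what a README note about supported
# acceleration devices would list.
if torch.accelerator.is_available():
    print("Accelerator backend:", torch.accelerator.current_accelerator().type)
else:
    print("No accelerator available; the example falls back to CPU.")
```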
Updated README