-
Notifications
You must be signed in to change notification settings - Fork 101
🎨 Refactor ModelABC to Help Use Default Torch Models
#867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev-define-engines-abc
Are you sure you want to change the base?
🎨 Refactor ModelABC to Help Use Default Torch Models
#867
Conversation
Signed-off-by: Shan E Ahmed Raza <[email protected]>
Signed-off-by: Shan E Ahmed Raza <[email protected]>
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## dev-define-engines-abc #867 +/- ##
==========================================================
- Coverage 91.19% 91.11% -0.09%
==========================================================
Files 73 73
Lines 9374 9379 +5
Branches 1230 1230
==========================================================
- Hits 8549 8546 -3
- Misses 792 800 +8
Partials 33 33 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| with torch.inference_mode(): | ||
| output = model(img_patches_device) | ||
| # Output should be a single tensor or scalar | ||
| return {"probabilities": output.cpu().numpy()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the current develop branch, neither CNNModel, nor CNNBackbone returned dictionaries as output of their infer_batch() methods. Also, CNNModel currently returns an array, while CNNBackbone returns a list with the array. It might be fine, just wanted to highlight this.
CNNModel
| return output.cpu().numpy() |
CNNBackbone
| return [output.cpu().numpy()] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. We are aware of this. Our preference is to use torch nn models but to generalise for multi modal output we may need dictionaries. This PR is to check if we can move to generic torch models or we will need a sub class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.
…rch-nn-model # Conflicts: # tests/models/test_arch_vanilla.py # tiatoolbox/models/architecture/vanilla.py
…rch-nn-model # Conflicts: # tiatoolbox/models/engine/engine_abc.py
ModelABCto Help Use Default Torch Modelsinfer_batchfromModelABC