Improve the e2tomoseg_convnet.py and make training faster #570
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I found out that the
e2tomoseg_convnet.pycould be improved by usingmodel.fit. Training will be six times faster.For example, when I ran the following command:
time e2tomoseg_convnet_test.py --trainset=particles/GCB_001_bin6_SIRT_preproc__good_2_trainset.hdf --nettag=convnet_iter100 --learnrate=0.0001 --niter=100 --ncopy=1 --batch=16 --nkernel=40,40,1 --ksize=15,15,15 --poolsz=2,1,1 --trainout --training --device=gpuI could get the output as follows:
When using the
model.fit, the training will be much faster, and I can get almost the same results (like the loss, the network...)I can even set the
callbackparameter inmodel.fit()to set up some learning rate scheduler to improve the training process. So I think it is better to usemodel.fitinstead of the cycle ofmodel.train_on_batch.