MixMatch-pytorch-customized-dataset
This is a PyTorch implementation of MixMatch that supports training on a customized dataset.
The official TensorFlow implementation is here, and the PyTorch implementation this repository is forked from is here.
Compared with the original forked repository, this repository adds two revised training scripts.
In addition, I restructured the code of the original PyTorch implementation and added comments to make it easier to understand.
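For readers new to MixMatch, three of its core operations can be sketched in plain Python: label guessing (averaging the model's predictions over K augmentations of an unlabeled sample), temperature sharpening of that guess, and the MixUp variant with λ' = max(λ, 1 − λ). This is an illustrative sketch on plain lists, not code taken from this repository. The function names are hypothetical, and T = 0.5 and α = 0.75 are assumed defaults.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sharpen(p: Sequence[float], T: float = 0.5) -> List[float]:
    """Temperature sharpening: p_i^(1/T) / sum_j p_j^(1/T)."""
    powered = [pi ** (1.0 / T) for pi in p]
    total = sum(powered)
    return [x / total for x in powered]


def guess_label(preds_per_aug: Sequence[Sequence[float]], T: float = 0.5) -> List[float]:
    """Average the class probabilities over K augmentations, then sharpen."""
    k = len(preds_per_aug)
    avg = [sum(col) / k for col in zip(*preds_per_aug)]
    return sharpen(avg, T)


def mixup(
    x1: Sequence[float], y1: Sequence[float],
    x2: Sequence[float], y2: Sequence[float],
    alpha: float = 0.75, lam: Optional[float] = None,
) -> Tuple[List[float], List[float]]:
    """MixMatch-style MixUp; lam' = max(lam, 1 - lam) keeps the result closer to (x1, y1)."""
    if lam is None:
        lam = random.betavariate(alpha, alpha)  # lam ~ Beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y
```

Sharpening with T < 1 pushes the guessed distribution toward one class, which is how MixMatch implicitly minimizes entropy on unlabeled data. Taking max(λ, 1 − λ) ensures a mixed sample stays closer to its first input, so mixed labeled and mixed unlabeled batches keep their roles in the loss.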
- train.py: the original PyTorch implementation, which trains on CIFAR-10 only.
- train_SSL.py: revises the dataset part to allow training on a customized dataset, and revises the original MixMatch loss function to account for potential class imbalance in the training data.
- train_TL.py: a simple baseline trained with supervised learning on labeled data only, using the same number of labeled samples as the SSL training. This gives a reference against which the SSL training can be evaluated.
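For reference, the original MixMatch objective (as defined in the MixMatch paper) combines a cross-entropy term $H$ over the mixed labeled batch $\mathcal{X}'$ with an L2 consistency term over the mixed unlabeled batch $\mathcal{U}'$, where $L$ is the number of classes and $\lambda_{\mathcal{U}}$ weights the unlabeled term. The exact imbalance-aware revision lives in train_SSL.py and is not reproduced here.

```latex
\mathcal{L}_{\mathcal{X}} = \frac{1}{|\mathcal{X}'|} \sum_{(x,p)\in\mathcal{X}'} H\big(p,\ p_{\text{model}}(y \mid x;\theta)\big)

\mathcal{L}_{\mathcal{U}} = \frac{1}{L\,|\mathcal{U}'|} \sum_{(u,q)\in\mathcal{U}'} \big\lVert q - p_{\text{model}}(y \mid u;\theta) \big\rVert_2^2

\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}}\,\mathcal{L}_{\mathcal{U}}
```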
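One common way to make the supervised cross-entropy term imbalance-aware is to weight each class by its inverse frequency in the labeled data. The sketch below illustrates that idea only. It is an assumption, not necessarily what train_SSL.py does, and both function names are hypothetical. Because MixUp produces soft targets, the weights multiply each class's contribution to a soft cross-entropy.

```python
import math
from collections import Counter
from typing import List, Sequence


def inverse_frequency_weights(labels: Sequence[int], num_classes: int) -> List[float]:
    """w_c = N / (C * n_c), so a perfectly balanced dataset gets weight 1.0 for every class."""
    counts = Counter(labels)
    n = len(labels)
    # Classes absent from `labels` get weight 0.0 to avoid division by zero.
    return [n / (num_classes * counts[c]) if counts[c] > 0 else 0.0 for c in range(num_classes)]


def weighted_soft_cross_entropy(
    targets: Sequence[Sequence[float]],
    probs: Sequence[Sequence[float]],
    class_weights: Sequence[float],
) -> float:
    """Batch mean of -sum_c w_c * p_c * log(q_c), where p is a soft (mixed) target."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for p, q in zip(targets, probs):
        total -= sum(w * pc * math.log(qc + eps) for w, pc, qc in zip(class_weights, p, q))
    return total / len(targets)
```

In PyTorch the same effect is typically obtained by passing a per-class `weight` tensor to `torch.nn.CrossEntropyLoss`; recent versions also accept class-probability (soft) targets.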
- Check the code environment in "requirements.txt".
- Prepare the customized dataset. Put the data under "dataset/" and the training/validation/test txt files in the current directory. Update the path information in both train_SSL.py and train_TL.py.
- Set the parameters. For example, update the number of labeled samples used for training.
- Train the model in SSL mode: python train_SSL.py
- Train the model in TL mode: python train_TL.py
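The format of the training/validation/test txt files is defined by the scripts. A common convention is one sample per line: an image path relative to "dataset/", followed by an integer class label. The parser below is a minimal sketch under that assumption. `parse_split_lines` and `load_split` are hypothetical helpers, and the line format should be checked against train_SSL.py before relying on it.

```python
import os
from typing import Iterable, List, Tuple


def parse_split_lines(lines: Iterable[str], root: str = "dataset") -> List[Tuple[str, int]]:
    """Parse '<relative/path> <label>' lines into (full_path, label) pairs, skipping blank lines."""
    samples: List[Tuple[str, int]] = []
    for lineno, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        # rsplit keeps paths that contain spaces intact; the label is the last field.
        parts = line.rsplit(maxsplit=1)
        if len(parts) != 2:
            raise ValueError(f"line {lineno}: expected '<path> <label>', got {line!r}")
        path, label = parts
        samples.append((os.path.join(root, path), int(label)))
    return samples


def load_split(txt_path: str, root: str = "dataset") -> List[Tuple[str, int]]:
    """Read one split file (e.g. a hypothetical train.txt) from disk."""
    with open(txt_path, encoding="utf-8") as f:
        return parse_split_lines(f, root)
```

The resulting (path, label) list can back a `torch.utils.data.Dataset` whose `__getitem__` opens the image at the given path and applies the transforms.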
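The number of labeled samples determines which part of the training data keeps its labels. A fair SSL-vs-TL comparison needs both scripts to use the same labeled subset. The sketch below assumes that the labeled samples are spread as evenly as possible over the classes and drawn with a fixed seed, and that everything else is treated as unlabeled. `split_labeled_unlabeled` is a hypothetical helper, not a function from this repository.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def split_labeled_unlabeled(
    labels: Sequence[int], n_labeled: int, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Pick n_labeled indices spread as evenly as possible over classes; the rest are unlabeled."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)  # fixed seed -> identical labeled subset for SSL and TL runs
    classes = sorted(by_class)
    base, extra = divmod(n_labeled, len(classes))
    labeled: List[int] = []
    for i, c in enumerate(classes):
        idxs = by_class[c][:]
        rng.shuffle(idxs)
        k = base + (1 if i < extra else 0)  # the first `extra` classes take one more sample
        if k > len(idxs):
            raise ValueError(f"class {c} has only {len(idxs)} samples, need {k}")
        labeled.extend(idxs[:k])
    labeled_set = set(labeled)
    unlabeled = [i for i in range(len(labels)) if i not in labeled_set]
    return sorted(labeled), unlabeled
```

With strongly imbalanced data, a rare class may not have enough samples for an even split, and the sketch raises an error in that case. A proportional split is an alternative design choice.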
Citation:
@article{berthelot2019mixmatch,
  title={MixMatch: A Holistic Approach to Semi-Supervised Learning},
  author={Berthelot, David and Carlini, Nicholas and Goodfellow, Ian and Papernot, Nicolas and Oliver, Avital and Raffel, Colin},
  journal={arXiv preprint arXiv:1905.02249},
  year={2019}
}