Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,48 +1,56 @@
# MixMatch
This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249).
The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch).
MixMatch-pytorch-customized-dataset

Now only experiments on CIFAR-10 are available.
This is a PyTorch implementation of MixMatch, which allows training with customized dataset.

This repository carefully implemented important details of the official implementation to reproduce the results.
The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch) and the forked Pytorch implementation is [here](https://github.com/YU1ut/MixMatch-pytorch).

## Revision Note

## Requirements
- Python 3.6+
- PyTorch 1.0
- **torchvision 0.2.2 (older versions are not compatible with this code)**
- tensorboardX
- progress
- matplotlib
- numpy
Two revised training functions are updated compared to the original forked [repository](https://github.com/YU1ut/MixMatch-pytorch).


In addition, I adjusted the code structure of the original Pytorch implementation and made necessary notes for better understanding.

train.py is the original Pytorch Implementation of [that](https://github.com/YU1ut/MixMatch-pytorch), which is trained on CIFAR-10 only.

1. train_SSL.py

Revised the dataset part to allow customized dataset for training.

Revised the original MixMatch loss function by considering the potential class imbalance issue in the training data.

2. train_TL.py

This is a simple baseline training process by supervised learning only using labeled data with the same number as that of SSL training.

This allows performance evaluation with SSL training.

## Usage

### Environment

Check code environment "requirements.txt".

### Train
Train the model by 250 labeled data of CIFAR-10 dataset:

```
python train.py --gpu <gpu_id> --n-labeled 250 --out cifar10@250
```
1. Customized dataset preparation.

Train the model by 4000 labeled data of CIFAR-10 dataset:
Put the data under "dataset/".

```
python train.py --gpu <gpu_id> --n-labeled 4000 --out cifar10@4000
```
Put the training/validatioin/test txt under the current location.

### Monitoring training progress
```
tensorboard.sh --port 6006 --logdir cifar10@250
```
Update the path information both in the train_SSL.py and train_TL.py.

2. Parameter settting by users. For example, update the number of labeled data for training.

## Results (Accuracy)
| #Labels | 250 | 500 | 1000 | 2000| 4000 |
|:---|:---:|:---:|:---:|:---:|:---:|
|Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32| 92.97 ± 0.15 |93.76 ± 0.06|
|This code | 88.71 | 88.96 | 90.52 | 92.23 | 93.52 |
3. Train the model in SSL mode:

python train_SSL.py

4. Train the model in TL mode:

python train_TL.py

(Results of this code were evaluated on 1 run. Results of 5 runs with different seeds will be updated later. )

## References
```
Expand All @@ -52,4 +60,6 @@ tensorboard.sh --port 6006 --logdir cifar10@250
journal={arXiv preprint arXiv:1905.02249},
year={2019}
}
```
```


38 changes: 38 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
brotlipy==0.7.0
certifi==2021.10.8
cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
cryptography @ file:///tmp/build/80754af9/cryptography_1639414570729/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
fonttools==4.25.0
idna @ file:///tmp/build/80754af9/idna_1637925883363/work
joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1638569886207/work
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1647441664166/work
mkl-fft==1.3.1
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work
mkl-service==2.4.0
munkres==1.1.4
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634106693478/work
opencv-python==4.5.5.64
packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
Pillow==9.0.1
progress @ file:///tmp/build/80754af9/progress_1614270514501/work
protobuf==3.19.1
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work
PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
requests @ file:///opt/conda/conda-bld/requests_1641824580448/work
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1642601761909/work
scipy @ file:///tmp/build/80754af9/scipy_1641536880743/work
six @ file:///tmp/build/80754af9/six_1644875935023/work
tensorboardX @ file:///tmp/build/80754af9/tensorboardx_1621440489103/work
threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
torch==1.11.0
torchaudio==0.11.0
torchvision==0.12.0
tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work
urllib3 @ file:///opt/conda/conda-bld/urllib3_1643638302206/work
Loading