This is a PyTorch implementation of the paper:
"A simple neural network module for relational reasoning," A. Santoro et al., 2017.
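
For reference, the paper defines a Relation Network (RN) over a set of objects $O = \{o_1, \dots, o_n\}$ as a sum of pairwise terms passed through a second function, where $g_\theta$ and $f_\phi$ are MLPs. For question answering on Sort-of-CLEVR, the question embedding $q$ is also passed to $g_\theta$ for every pair. This summarizes the paper's formulation only; it is not a description of the exact layers in `models/`.

$$
\mathrm{RN}(O) = f_{\phi}\Big(\sum_{i,j} g_{\theta}(o_i, o_j)\Big), \qquad
a = f_{\phi}\Big(\sum_{i,j} g_{\theta}(o_i, o_j, q)\Big)
$$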

Requirements:

- Python 3
- PyTorch, torchvision
- termcolor, tqdm

- `train.py` is the main training script:

```bash
# select the model type with --model-type (base, improved, improved2)
$ python train.py --model-type 'base'
$ python train.py --model-type 'improved'
$ python train.py --model-type 'improved2'
```

- `eval.py` evaluates a trained model on the Sort-of-CLEVR dataset:

```bash
# select the model type with --model-type (base, improved, improved2)
$ python eval.py --model-type 'base' --load-model 'trained_model/base_model.pth.tar'
$ python eval.py --model-type 'improved' --load-model 'trained_model/improved_model.pth.tar'
$ python eval.py --model-type 'improved2' --load-model 'trained_model/improved_model2.pth.tar'
```

Accuracy on Sort-of-CLEVR:

| Model | Overall | Non-relational questions | Relational questions |
|---|---|---|---|
| Reproduced RNs (base) | 96.4% | 99.5% | 93.4% |
| RNs + Weighted pairs (improved) | 97.5% | 99.8% | 95.1% |
| RNs + Enhanced features (improved2) | 97.7% | 99.8% | 95.6% |
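
In the paper's Sort-of-CLEVR setup, each image comes with 10 relational and 10 non-relational questions. Assuming the dataset used here keeps that balance (an assumption, not something stated in this README), the Overall column should equal the plain mean of the other two columns up to rounding. The short check below recomputes it from the table values; `RESULTS` and `balanced_overall` are illustrative names, not part of this repository.

```python
# Accuracies (%) copied from the results table: (overall, non-relational, relational).
RESULTS = {
    "base": (96.4, 99.5, 93.4),
    "improved": (97.5, 99.8, 95.1),
    "improved2": (97.7, 99.8, 95.6),
}


def balanced_overall(non_relational: float, relational: float) -> float:
    """Overall accuracy if both question types occur equally often (assumption)."""
    return (non_relational + relational) / 2


for name, (overall, non_rel, rel) in RESULTS.items():
    estimate = balanced_overall(non_rel, rel)
    print(f"{name}: reported {overall:.1f}%, mean of the two columns {estimate:.2f}%")
```

The reported and recomputed values agree within 0.05 percentage points for all three models. That is consistent with a balanced question split plus rounding, though it does not prove it.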

Directory structure:

```
.
├── datsets/
│   ├── sort-of-clevr_test.pickle
│   └── sort-of-clevr_train.pickle
├── util/
│   ├── torch_util.py
│   └── train_test_fn.py
├── models/
│   ├── base_model.py
│   └── improved_model.py
├── trained_models/
│   ├── base_model.pth.tar
│   └── improved_model.pth.tar
├── so_clevr_dataset.py
├── eval.py
├── train.py
└── README.md
```
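
Both `train.py` and `eval.py` take a `--model-type` flag (one of `base`, `improved`, `improved2`), and `eval.py` also takes `--load-model` with a checkpoint path. The following is a hypothetical sketch of how such flags can be declared with Python's standard `argparse` module; it is not taken from this repository's scripts, and details such as whether the flags are required or have defaults are assumptions. `build_parser` and `MODEL_TYPES` are illustrative names.

```python
import argparse

# Model types named in this README; how each name maps to a class in
# models/ is not documented here, so this sketch only validates the name.
MODEL_TYPES = ("base", "improved", "improved2")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags used in this README's commands."""
    parser = argparse.ArgumentParser(
        description="Relation Networks on Sort-of-CLEVR (illustrative sketch)"
    )
    # Assumption: the model type must always be given explicitly.
    parser.add_argument(
        "--model-type",
        choices=MODEL_TYPES,
        required=True,
        help="which RN variant to use",
    )
    # Used for evaluation: path to a saved .pth.tar checkpoint.
    parser.add_argument(
        "--load-model",
        default=None,
        help="checkpoint to load before evaluating",
    )
    return parser
```

For example, `build_parser().parse_args(["--model-type", "improved"]).model_type` returns `'improved'`: `argparse` exposes `--model-type` as the attribute `model_type`, and an unknown value such as `--model-type huge` makes it print an error and exit.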

If you need the trained models (checkpoints) or the Sort-of-CLEVR dataset, feel free to send me an e-mail.
