15 changes: 6 additions & 9 deletions config.json
@@ -1,14 +1,13 @@
{
"name": "Mnist_LeNet",
"n_gpu": 1,

"n_gpu": 2,
"arch": {
"type": "MnistModel",
"args": {}
},
"data_loader": {
"type": "MnistDataLoader",
"args":{
"args": {
"data_dir": "data/",
"batch_size": 128,
"shuffle": true,
@@ -18,15 +17,16 @@
},
"optimizer": {
"type": "Adam",
"args":{
"args": {
"lr": 0.001,
"weight_decay": 0,
"amsgrad": true
}
},
"loss": "nll_loss",
"metrics": [
"accuracy", "top_k_acc"
"accuracy",
"top_k_acc"
],
"lr_scheduler": {
"type": "StepLR",
@@ -37,14 +37,11 @@
},
"trainer": {
"epochs": 100,

"save_dir": "saved/",
"save_period": 1,
"verbosity": 2,

"monitor": "min val_loss",
"early_stop": 10,

"tensorboard": true
}
}
}
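
For context, this file is consumed by ConfigParser (reworked below in parse_config.py): each top-level entry names a type plus its args, and init_obj instantiates it. A minimal illustrative sketch of that flow, assuming the template's usual module layout (model/model.py holding MnistModel) and that it runs from the repository root:

# Illustrative only, not part of this change: build objects straight from config.json.
import json
import torch

import data_loader.data_loaders as module_data
import model.model as module_arch            # assumed location of MnistModel
from parse_config import ConfigParser

config = ConfigParser(json.load(open('config.json')))   # also creates the saved/ run dirs

data_loader = config.init_obj('data_loader', module_data)   # -> MnistDataLoader(data_dir='data/', batch_size=128, ...)
model = config.init_obj('arch', module_arch)                 # -> MnistModel()
optimizer = config.init_obj('optimizer', torch.optim, model.parameters())  # -> Adam(lr=0.001, amsgrad=True, ...)
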
160 changes: 156 additions & 4 deletions data_loader/data_loaders.py
@@ -1,16 +1,168 @@
from torchvision import datasets, transforms
import os
import sys
from skimage import io, transform
import numpy as np
import torch
import pandas as pd

sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from base import BaseDataLoader


# 1.####################################MnistDataLoader########################################

class MnistDataLoader(BaseDataLoader):
"""
MNIST data loading demo using BaseDataLoader
"""
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):

def __init__(self,
data_dir,
batch_size,
shuffle=True,
validation_split=0.0,
num_workers=1,
training=True):
trsfm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
transforms.Normalize((0.1307, ), (0.3081, ))
])
self.data_dir = data_dir
self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
self.dataset = datasets.MNIST(self.data_dir,
train=training,
download=True,
transform=trsfm)
super().__init__(self.dataset, batch_size, shuffle, validation_split,
num_workers)
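
A quick usage sketch for the loader above (illustrative only, not part of the file); data_dir='data/' matches config.json and MNIST is downloaded on first use:

# Illustrative only: pull a single batch from the refactored MnistDataLoader.
loader = MnistDataLoader(data_dir='data/', batch_size=128, shuffle=True,
                         validation_split=0.0, num_workers=2, training=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)   # torch.Size([128, 1, 28, 28]) torch.Size([128])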


# 2.####################################FaceLandmarksDataset########################################

class Rescale(object):
"""Rescale the image in a sample to a given size.

Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""

def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size

def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']

h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size

new_h, new_w = int(new_h), int(new_w)

img = transform.resize(image, (new_h, new_w))

# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
"""Crop randomly the image in a sample.

Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""

def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size

def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']

h, w = image.shape[:2]
new_h, new_w = self.output_size

top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)

image = image[top:top + new_h, left:left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""

def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']

# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W
image = image.transpose((2, 0, 1))
return {
'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)
}
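
To make the three transforms concrete, a small illustrative sketch pushing a synthetic sample through Rescale, RandomCrop, and ToTensor (the sizes are arbitrary, chosen only to show the shape bookkeeping):

# Illustrative only: shapes through the transform pipeline on random data.
demo_sample = {
    'image': np.random.rand(120, 90, 3),                 # H x W x C numpy image
    'landmarks': np.random.rand(68, 2) * [90.0, 120.0]   # (x, y) points
}
demo_pipeline = transforms.Compose([Rescale(100), RandomCrop(64), ToTensor()])
out = demo_pipeline(demo_sample)
# Rescale(100): shorter edge -> 100, image becomes (133, 100, 3), landmarks rescaled
# RandomCrop(64): image becomes (64, 64, 3), landmarks shifted by the crop offset
# ToTensor: image -> (3, 64, 64) tensor, landmarks -> (68, 2) tensor
print(out['image'].shape, out['landmarks'].shape)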


class FaceLandmarksDataset(BaseDataLoader):
"""Face Landmarks dataset."""

def __init__(self, csv_file, root_dir, transfm=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transfm

def __len__(self):
return len(self.landmarks_frame)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx,
0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}

if self.transform:
sample = self.transform(sample)

return sample


if __name__ == "__main__":
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
face_data = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transfm=composed)
for i in range(len(face_data)):
sample = face_data[i]
print("sample['image'].shape: {}".format(sample['image'].shape))
print("sample['landmarks'].shape: {}".format(sample['landmarks'].shape))
26 changes: 20 additions & 6 deletions parse_config.py
@@ -9,6 +9,7 @@


class ConfigParser:

def __init__(self, config, resume=None, modification=None, run_id=None):
"""
class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
@@ -26,7 +27,7 @@ class to parse configuration json file. Handles hyperparameters for training, in
save_dir = Path(self.config['trainer']['save_dir'])

exper_name = self.config['name']
if run_id is None: # use timestamp as default run-id
if run_id is None: # use timestamp as default run-id
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
self._save_dir = save_dir / 'models' / exper_name / run_id
self._log_dir = save_dir / 'log' / exper_name / run_id
@@ -67,14 +68,17 @@ def from_args(cls, args, options=''):
assert args.config is not None, msg_no_cfg
resume = None
cfg_fname = Path(args.config)

config = read_json(cfg_fname)
if args.config and resume:
# update new config for fine-tuning
config.update(read_json(args.config))

# parse custom cli options into dictionary
modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
modification = {
opt.target: getattr(args, _get_opt_name(opt.flags))
for opt in options
}
return cls(config, resume, modification)

def init_obj(self, name, module, *args, **kwargs):
@@ -88,7 +92,9 @@ def init_obj(self, name, module, *args, **kwargs):
"""
module_name = self[name]['type']
module_args = dict(self[name]['args'])
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
assert all([
k not in module_args for k in kwargs
]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return getattr(module, module_name)(*args, **module_args)

@@ -103,16 +109,20 @@ def init_ftn(self, name, module, *args, **kwargs):
"""
module_name = self[name]['type']
module_args = dict(self[name]['args'])
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
assert all([
k not in module_args for k in kwargs
]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return partial(getattr(module, module_name), *args, **module_args)

def __getitem__(self, name):
"""Access items like ordinary dict."""
return self.config[name]

# verbosity refers to the level of detail
def get_logger(self, name, verbosity=2):
msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(
verbosity, self.log_levels.keys())
assert verbosity in self.log_levels, msg_verbosity
logger = logging.getLogger(name)
logger.setLevel(self.log_levels[verbosity])
@@ -131,6 +141,7 @@ def save_dir(self):
def log_dir(self):
return self._log_dir


# helper functions to update config dict with custom cli options
def _update_config(config, modification):
if modification is None:
@@ -141,17 +152,20 @@ def _update_config(config, modification):
_set_by_path(config, k, v)
return config


def _get_opt_name(flags):
for flg in flags:
if flg.startswith('--'):
return flg.replace('--', '')
return flags[0].replace('--', '')


def _set_by_path(tree, keys, value):
"""Set a value in a nested object in tree by sequence of keys."""
keys = keys.split(';')
_get_by_path(tree, keys[:-1])[keys[-1]] = value


def _get_by_path(tree, keys):
"""Access a nested object in tree by sequence of keys."""
return reduce(getitem, keys, tree)
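
The helper functions at the bottom are what let CLI flags override nested config keys: a target path is a ';'-separated chain of keys. A small illustrative sketch, using targets in the style of the template's CustomArgs (the 'optimizer;args;lr' target is an assumption here):

# Illustrative only: apply CLI-style overrides to a nested config dict.
config = {'optimizer': {'type': 'Adam', 'args': {'lr': 0.001}},
          'data_loader': {'args': {'batch_size': 128}}}
modification = {'optimizer;args;lr': 0.01, 'data_loader;args;batch_size': 256}

config = _update_config(config, modification)
print(config['optimizer']['args']['lr'])              # 0.01
print(config['data_loader']['args']['batch_size'])    # 256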