forked from albertlai/deep-style-transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatches.py
More file actions
85 lines (79 loc) · 2.99 KB
/
batches.py
File metadata and controls
85 lines (79 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import utils
import skimage
import numpy as np
# Sentinel value for ``image_dir`` meaning "no image directory": batches are
# allocated as zeros and never filled from disk.
DUMMY = 'NULL'


class BatchGenerator:
    """ Loads images in a directory into batches

    Up to ``max_batches`` batches are kept in memory as reusable numpy
    buffers. ``get_batch`` hands them out one at a time and refills all of
    them once the last one has been handed out. Files are visited in sorted
    name order and the walk cycles around the file list, so a generator can
    be drawn from indefinitely.

    Args:
        batch_size: Number of items per batch
        image_h: Height of image
        image_w: Width of image
        image_dir: Image directory to load batches from. With the default
            ``DUMMY`` nothing is read and every batch stays all zeros.
        max_batches: Max number of batches to load into memory at any given
            moment; must be at least 1
        valid: Whether or not this is a validation set. A validation set
            ignores ``batch_index``, starts at the last file and walks the
            file list backwards.
        logging: Logging object with an ``info`` method; when falsy,
            messages are printed instead
        batch_index: Index into the sorted file list to start loading from
            (e.g. a value previously returned by ``get_last_load``)
        loader: Optional callable ``loader(path, image_h, image_w)`` that
            returns one image array; defaults to ``utils.load_image``

    Raises:
        ValueError: if ``max_batches`` is less than 1 or ``image_dir`` is
            empty.
    """

    def __init__(self, batch_size, image_h, image_w, image_dir=DUMMY, max_batches=100,
                 valid=False, logging=None, batch_index=0, loader=None):
        if max_batches < 1:
            # get_batch indexes self.batches[0]; an empty buffer list can never work.
            raise ValueError("max_batches must be at least 1, got %r" % (max_batches,))
        self.image_dir = image_dir
        self.batch_size = batch_size
        self.max_batches = max_batches
        self.image_h = image_h
        self.image_w = image_w
        self.logging = logging
        self.valid = valid
        self.loader = loader
        self._log("setting up batches %d at a time starting from %d" % (self.max_batches, batch_index))
        # os.listdir returns entries in arbitrary order; sorting once makes
        # batch_index / get_last_load refer to the same files across runs.
        self.file_list = sorted(os.listdir(image_dir)) if image_dir != DUMMY else []
        if image_dir != DUMMY and not self.file_list:
            raise ValueError("image directory %r is empty" % (image_dir,))
        self.last_load = batch_index if not valid else len(self.file_list) - 1
        self.index = 0
        self.batches = None
        self.load_batches()

    def _log(self, message):
        """Send ``message`` to the configured logging object, or print it."""
        if self.logging:
            self.logging.info(message)
        else:
            print(message)

    def load_batches(self):
        """ Load max_batches batches into memory

        On the first call this allocates ``max_batches`` zero-filled float
        arrays of shape (batch_size, image_h, image_w, 3); later calls
        overwrite those same arrays in place. With ``image_dir == DUMMY``
        nothing is read from disk.

        Raises:
            RuntimeError: if a full pass over the file list loads no image.
        """
        if self.batches is None:
            self.batches = [
                np.zeros((self.batch_size, self.image_h, self.image_w, 3))
                for _ in range(self.max_batches)
            ]
        if self.image_dir == DUMMY:
            return
        for arr in self.batches:
            self._fill_batch(arr)

    def _fill_batch(self, arr):
        """Load ``batch_size`` images into ``arr`` in place, advancing ``last_load``."""
        file_list = self.file_list
        count = len(file_list)
        step = -1 if self.valid else 1
        # Resolved lazily so the default loader is only needed when reading files.
        load = self.loader if self.loader is not None else utils.load_image
        n = self.last_load
        failures = 0  # consecutive files that failed to load
        i = 0
        while i < self.batch_size:
            file_name = file_list[n % count]
            try:
                arr[i] = load(os.path.join(self.image_dir, file_name), self.image_h, self.image_w)
            except Exception as err:
                # Best-effort: skip non-image or unreadable entries, but say so.
                failures += 1
                self._log("skipping %s: %s" % (file_name, err))
                if failures >= count:
                    # Every entry was tried in a row without success; cycling
                    # further would loop forever.
                    raise RuntimeError("no loadable images in %r" % (self.image_dir,))
            else:
                i += 1
                failures = 0
            # Wrap around so the list is cycled instead of raising IndexError
            # (or silently re-reading via negative indices in validation mode).
            n = (n + step) % count
        self.last_load = n

    def get_batch(self):
        """ Returns the next batch. Starts loading the next set of batches into memory
        if we reach the end

        The batch is copied *before* any reload, because ``load_batches``
        overwrites the buffers in place and would otherwise replace the batch
        being returned.
        """
        batch = np.array(self.batches[self.index])
        self.index += 1
        if self.index >= len(self.batches):
            self.index = 0
            self.load_batches()
        return batch

    def get_last_load(self):
        """Return the file-list index the next load will start from."""
        return self.last_load