Skip to content

Commit 5ff5573

Browse files
authored
Cleaner implementation of the classification ml_type
1 parent d08c480 commit 5ff5573

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

label_maker/label.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,27 +128,21 @@ def make_labels(dest_folder, zoom, country, classes, ml_type, bounding_box, spar
128128
# write out labels as GeoJSON or PNG
129129
if ml_type == 'classification':
130130
features = []
131-
label_area = []
132-
label_bool = []
133-
i = 0
134-
for tile, label in tile_results.items():
135-
label_bool.append([int(bool(l)) for l in label])
136-
label_area.append([float(l) for l in label])
137-
if 'label_bool2' in locals() and 'label_area2' in locals():
138-
label_bool[i] = [int(bool(label_bool[i] + label_bool2[i])) for label_bool[i], label_bool2[i] in zip(label_bool[i], label_bool2[i])]
139-
label_area[i] = [label_area[i] + label_area2[i] for label_area[i], label_area2[i] in zip(label_area[i], label_area2[i])]
131+
if ctr_idx == 0:
132+
label_area = np.zeros((len(kwargs['classes'])+1,len(tile_results),len(country)),dtype=float)
133+
label_bool = np.zeros((len(kwargs['classes'])+1,len(tile_results),len(country)),dtype=bool)
134+
for i, tile, label in enumerate(tile_results.items()):
135+
label_bool[:,i,ctr_idx] = np.asarray([bool(l) for l in label])
136+
label_area[:,i,ctr_idx] = np.asarray([float(l) for l in label])
140137
# if there are no classes, activate the background
141138
if ctr == country[-1]:
142-
if all(v == 0 for v in label_bool[i]):
143-
label_bool[i][0] = 1
139+
if all(v == 0 for v in label_bool[:,i,ctr_idx]):
140+
label_bool[0,i,ctr_idx] = 1
144141
feat = feature(Tile(*[int(t) for t in tile.split('-')]))
145142
features.append(Feature(geometry=feat['geometry'],
146143
properties=dict(feat_id=str(tile),
147-
label=label_bool[i],
148-
label_area=label_area[i])))
149-
i += 1
150-
label_bool2 = label_bool
151-
label_area2 = label_area
144+
label=np.any(label_bool[:,i,:],axis=1).astype(int).tolist(),
145+
label_area=np.sum(label_area[:,i,:],axis=1).tolist()))
152146
if ctr == country[-1]:
153147
json.dump(fc(features), open(op.join(dest_folder, f'classification_{zoom}.geojson'), 'w'))
154148
elif ml_type == 'object-detection':

0 commit comments

Comments
 (0)