Skip to content

Commit 2d9f8a6

Browse files
authored
Merge pull request #134 from fastlabel/fix-mask-image
fix mask image logic
2 parents dcc5cb0 + ad48906 commit 2d9f8a6

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

fastlabel/__init__.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,10 +2308,16 @@ def __export_index_color_image(
23082308
if count == 0:
23092309
cv_draw_points = []
23102310
if utils.is_clockwise(points):
2311-
cv_draw_points = self.__get_cv_draw_points(points)
2311+
cv_draw_points = self.__get_cv_draw_points(
2312+
utils.sort_segmentation_points(points)
2313+
)
23122314
else:
2315+
reverse_points = utils.reverse_points(points)
2316+
sorted_points = utils.sort_segmentation_points(
2317+
reverse_points
2318+
)
23132319
cv_draw_points = self.__get_cv_draw_points(
2314-
utils.reverse_points(points)
2320+
sorted_points
23152321
)
23162322
cv2.fillPoly(
23172323
seg_mask_image,
@@ -2323,9 +2329,11 @@ def __export_index_color_image(
23232329
else:
23242330
# Reverse hollow points for opencv because these points are
23252331
# counterclockwise
2326-
cv_draw_points = self.__get_cv_draw_points(
2327-
utils.reverse_points(points)
2332+
reverse_points = utils.reverse_points(points)
2333+
sorted_points = utils.sort_segmentation_points(
2334+
reverse_points
23282335
)
2336+
cv_draw_points = self.__get_cv_draw_points(sorted_points)
23292337
cv2.fillPoly(
23302338
seg_mask_image,
23312339
[cv_draw_points],
@@ -2362,7 +2370,8 @@ def __export_index_color_image(
23622370

23632371
def __get_cv_draw_points(self, points: List[int]) -> List[int]:
23642372
"""
2365-
Convert points to pillow draw points. Diagonal points are not supported.
2373+
Convert points to pillow draw points. Diagonal points are not supported
2374+
Annotation clockwise draw.
23662375
"""
23672376
x_points = []
23682377
x_points.append(points[0])

fastlabel/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,31 @@ def reverse_points(points: List[int]) -> List[int]:
8585
return reversed_points
8686

8787

88+
def sort_segmentation_points(points: List[int]) -> List[int]:
89+
"""
90+
e.g.)
91+
[1, 2, 1, 1, 2, 1, 2, 2, 1, 2] => [1, 1, 2, 1, 2, 2, 1, 2, 1, 1]
92+
"""
93+
points_array = np.array(points).reshape((-1, 2))[1:]
94+
base_point_index = 0
95+
points_list = points_array.tolist()
96+
for index, val in enumerate(points_list):
97+
if index == 0:
98+
continue
99+
if (
100+
val[1] <= points_list[base_point_index][1] and val[0] <= points_list[base_point_index][0]
101+
):
102+
base_point_index = index
103+
new_points_array = np.vstack(
104+
[
105+
points_array[base_point_index:],
106+
points_array[:base_point_index],
107+
np.array([points_array[base_point_index]]),
108+
]
109+
)
110+
return new_points_array.ravel().tolist()
111+
112+
88113
def is_clockwise(points: list) -> bool:
89114
"""
90115
points: [x1, y1, x2, y2, x3, y3, ... xn, yn]

0 commit comments

Comments
 (0)