68 changes: 68 additions & 0 deletions model/seg/data.py
@@ -0,0 +1,68 @@
from io import BytesIO
import json
import random

import numpy as np
import requests
from torch.utils.data import Dataset
from torchvision.transforms import RandomCrop
from PIL import Image

from model.seg.imutils import vstrips_process, hstrips_process


class BoundingBoxDataset(Dataset):
def __init__(self, json_file: str):
with open(json_file, "r") as f:
self.data_json = json.load(f)

    def __len__(self):
        # The JSON is a COCO-style dict, so len(self.data_json) would count
        # its top-level keys, not the samples.
        return len(self.data_json["annotations"])

    def __getitem__(self, idx):
        # TODO: This should actually be multiple bboxes, not just a single one
        bbox = self.data_json["annotations"][idx]["bbox"]
        response = requests.get(self.data_json["images"][idx]["coco_url"])
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # RandomCrop.get_params returns (top, left, height, width), while
        # PIL's crop expects a (left, upper, right, lower) box.
        top, left, height, width = RandomCrop.get_params(img, (256, 256))
        img = img.crop((left, top, left + width, top + height))

        # COCO bboxes are [x, y, w, h]; shift the origin into crop coordinates.
        shifted_bbox = [
            bbox[0] - left,
            bbox[1] - top,
            bbox[2],
            bbox[3],
        ]
is_vert = random.choice([True, False])
strips, labels = self.create_strips_and_labels(img, shifted_bbox, is_vert)
return is_vert, strips, labels

    def create_strips_and_labels(self, img, bbox, is_vert):
        # Horizontal strips stack vertically (and vice versa), hence the
        # cross-naming here; see the __main__ block below.
        if is_vert:
            return hstrips_process(img, bbox)
        return vstrips_process(img, bbox)


if __name__ == "__main__":
dataset = BoundingBoxDataset("instances_minitrain2017.json")
is_vertical, strips, labels = dataset[0]

bound_strips = []
for strip, label in zip(strips, labels):
if label:
bound_strips.append(np.zeros_like(strip))
else:
bound_strips.append(strip)

if is_vertical:
img = Image.fromarray(np.vstack(bound_strips))
else:
img = Image.fromarray(np.hstack(bound_strips))
img.show()
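The TODO above points at a related indexing hazard: `annotations[idx]` and `images[idx]` are treated as parallel lists, but COCO annotations reference their image by `image_id`, not by list position. A minimal lookup sketch (a hypothetical helper, not part of this diff) could resolve the correct image record:

    def _image_for_annotation(self, annotation: dict) -> dict:
        # COCO image records carry an "id"; annotations point to it via "image_id".
        image_id = annotation["image_id"]
        for record in self.data_json["images"]:
            if record["id"] == image_id:
                return record
        raise KeyError(f"no image record with id {image_id}")

Building an id-to-record dict once in __init__ would avoid the linear scan per sample.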
17 changes: 17 additions & 0 deletions model/seg/dloader.py
@@ -0,0 +1,17 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from model.seg.data import BoundingBoxDataset


class SegmentationDataModule(pl.LightningDataModule):
def __init__(self, json_file: str, batch_size: int = 32):
super().__init__()
self.json_file_path = json_file
self.batch_size = batch_size

def setup(self, stage: str):
self.train_data = BoundingBoxDataset(self.json_file_path)

    def train_dataloader(self):
        # NOTE: the default collate_fn cannot batch variable-length lists of
        # differently sized strips; see the collate sketch below.
        return DataLoader(self.train_data, batch_size=self.batch_size)
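Given that constraint, a minimal pass-through collate (an assumed workaround, not in this diff) keeps each sample intact so iteration works even before strips are embedded to a fixed size:

    def list_collate(batch):
        # Each element is the raw (is_vert, strips, labels) tuple from the dataset.
        return batch

    # Hypothetical usage inside train_dataloader:
    # DataLoader(self.train_data, batch_size=self.batch_size, collate_fn=list_collate)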
60 changes: 60 additions & 0 deletions model/seg/imutils.py
@@ -0,0 +1,60 @@
import numpy as np
from PIL import Image


def vstrips_process(
    img: Image.Image | np.ndarray, bbox: list[float], pad: int = 5
) -> tuple[list[np.ndarray], list[int]]:
    """Cut the image into vertical strips roughly 2*pad wide; label a strip 1
    when a vertical edge of the bbox falls inside it, else 0."""
    if isinstance(img, Image.Image):
        img = np.array(img)

    strips, labels = [], []
    for idx in range(0, img.shape[1], pad * 2):
        # bbox is COCO-style [x, y, w, h]; if this strip overlaps a vertical
        # edge of the box, the model should predict a segmentation here.
        if (
            max(idx - pad, 0) < bbox[0] < idx + pad
            or max(idx - pad, 0) < bbox[0] + bbox[2] < idx + pad
        ):
            labels.append(1)
        else:
            labels.append(0)
        strips.append(img[:, max(idx - pad, 0) : idx + pad])
    return strips, labels


def hstrips_process(
    img: Image.Image | np.ndarray, bbox: list[float], pad: int = 5
) -> tuple[list[np.ndarray], list[int]]:
    """Cut the image into horizontal strips roughly 2*pad tall; label a strip
    1 when a horizontal edge of the bbox falls inside it, else 0."""
    if isinstance(img, Image.Image):
        img = np.array(img)

    strips, labels = [], []
    for idx in range(0, img.shape[0], pad * 2):
        # bbox is COCO-style [x, y, w, h]; if this strip overlaps a horizontal
        # edge of the box, the model should predict a segmentation here.
        if (
            max(idx - pad, 0) < bbox[1] < idx + pad
            or max(idx - pad, 0) < bbox[1] + bbox[3] < idx + pad
        ):
            labels.append(1)
        else:
            labels.append(0)
        strips.append(img[max(idx - pad, 0) : idx + pad])
    return strips, labels


def debug_strips(img: np.ndarray, bbox: list[float]) -> None:
    """Visual sanity check: reassemble the strips with matplotlib."""
    import matplotlib.pyplot as plt

    vstrips, _ = vstrips_process(img, bbox)
    hstrips, _ = hstrips_process(img, bbox)
    plt.imshow(np.hstack(vstrips).astype(np.uint8))
    plt.show()
    plt.imshow(np.vstack(hstrips).astype(np.uint8))
    plt.show()
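A quick illustrative check of the labeling logic (values chosen for the example, not from the PR): with pad=5, a 20x40 image yields strips at idx = 0, 10, 20, 30, and a bbox [20, 5, 10, 10] has vertical edges at x=20 and x=30, which should flag the last two strips:

    dummy = np.zeros((20, 40, 3), dtype=np.uint8)
    strips, labels = vstrips_process(dummy, [20, 5, 10, 10], pad=5)
    print(labels)  # [0, 0, 1, 1]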
24 changes: 24 additions & 0 deletions model/seg/rseg.py
@@ -0,0 +1,24 @@
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch


class SegmentationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Inputs are expected as (seq_len, batch, 1024) strip embeddings.
        self.attention = nn.MultiheadAttention(embed_dim=1024, num_heads=1)
        self.out_layer = nn.Linear(1024, 1)

    def forward(self, x):
        # MultiheadAttention takes (query, key, value) and returns
        # (output, weights); self-attention feeds x as all three.
        x, _ = self.attention(x, x, x)
        return self.out_layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # out_layer produces raw logits, so use the with-logits BCE.
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
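Nothing in this diff turns the raw strips into the 1024-dim tokens the attention layer expects. One plausible bridge, purely a sketch under assumed shapes (StripEmbedder, strip_hw, and the resize-then-project scheme are all hypothetical, not part of the PR), resizes each strip to a fixed size and projects it linearly:

    import numpy as np
    import torch
    import torch.nn as nn
    from PIL import Image


    class StripEmbedder(nn.Module):
        def __init__(self, strip_hw=(16, 16), embed_dim=1024):
            super().__init__()
            self.strip_hw = strip_hw
            self.proj = nn.Linear(strip_hw[0] * strip_hw[1] * 3, embed_dim)

        def forward(self, strips):
            # strips: list of HxWx3 uint8 arrays of varying size.
            resized = [
                np.array(Image.fromarray(s).resize(self.strip_hw[::-1]))
                for s in strips
            ]
            x = torch.from_numpy(np.stack(resized)).float().flatten(1) / 255.0
            return self.proj(x)  # (num_strips, embed_dim)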
11 changes: 11 additions & 0 deletions model/train.py
@@ -0,0 +1,11 @@
import pytorch_lightning as pl

from model.seg.dloader import SegmentationDataModule
from model.seg.rseg import SegmentationModel

# Training
model = SegmentationModel()
trainer = pl.Trainer()

# SegmentationDataModule requires the annotation JSON path; the data is COCO
# minitrain, not ImageNet as the old variable name suggested.
coco = SegmentationDataModule("instances_minitrain2017.json")
trainer.fit(model, datamodule=coco)
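Before a full run, a one-batch wiring check can be done with Lightning's fast_dev_run flag (an optional workflow suggestion, not part of this diff):

    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=coco)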
283 changes: 260 additions & 23 deletions nb.ipynb

Large diffs are not rendered by default.