642 changes: 642 additions & 0 deletions Query_playground.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion datasets/__init__.py
@@ -2,7 +2,7 @@
import torch.utils.data
import torchvision

-from .coco import build as build_coco
+from .coco import build_coco, build_lwll


def get_coco_api_from_dataset(dataset):
@@ -22,4 +22,6 @@ def build_dataset(image_set, args):
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
+    if args.dataset_file == 'lwll':
+        return build_lwll(image_set, args)
    raise ValueError(f'dataset {args.dataset_file} not supported')
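For context, the new branch simply routes `'lwll'` to the query-based builder. A minimal sketch of exercising the dispatch; the `Namespace` fields and the `'train_0'` split name are assumptions inferred from how `build_lwll` parses `image_set` below, not values taken from the PR:

```python
# Hypothetical usage sketch of the new dispatch; all values are placeholders.
from argparse import Namespace
from datasets import build_dataset

args = Namespace(dataset_file='lwll', coco_path='/data/problem_x/', masks=False)
dataset = build_dataset('train_0', args)  # falls through to build_lwll
```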
67 changes: 65 additions & 2 deletions datasets/coco.py
@@ -4,7 +4,11 @@

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
+import os
from pathlib import Path
+from PIL import Image

+import numpy as np

import torch
import torch.utils.data
@@ -20,16 +24,64 @@ def __init__(self, img_folder, ann_file, transforms, return_masks):
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)


    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)

        return img, target


+class CocoDetection_query(torchvision.datasets.CocoDetection):
+    """
+    Torchvision-style COCO dataset that, on each iteration, keeps the annotations
+    of a single randomly chosen category and additionally returns a query image:
+    a crop of one instance of that category taken from a different image.
+    """
+    def __init__(self, img_folder, ann_file, transforms, return_masks):
+        super(CocoDetection_query, self).__init__(img_folder, ann_file)
+        self.image_transforms = transforms
+        self.query_transforms = make_coco_transforms('query')
+        self.prepare = ConvertCocoPolysToMask(return_masks)
+
+        # Query set: map each category id to all annotations of that category
+        self.query_set = {k: [] for k in self.coco.cats.keys()}
+        for k, v in self.coco.anns.items():
+            self.query_set[v['category_id']].append({'image_id': v['image_id'], 'bbox': v['bbox'], 'id': v['id']})
+
+    def __getitem__(self, idx):
+        img, target = super(CocoDetection_query, self).__getitem__(idx)
+
+        # pick one class per image and limit targets to that class
+        # (assumes every image carries at least one annotation)
+        cat_id = np.random.choice([t['category_id'] for t in target])
+        target = [t for t in target if t['category_id'] == cat_id]
+        # sample a query annotation of the same class from a different image
+        image_id = self.ids[idx]
+        annot_id = image_id
+        while annot_id == image_id:
+            query = np.random.choice(self.query_set[cat_id])
+            annot_id = query['image_id']
+        path = self.coco.loadImgs(annot_id)[0]['file_name']
+        query_img = np.array(Image.open(os.path.join(self.root, path)).convert('RGB'))
+
+        # crop the query image to its bounding box
+        x, y, w, h = query['bbox']
+        query_img = query_img[int(y):int(y + h), int(x):int(x + w), :]
+
+        # process the original image and target
+        target = {'image_id': image_id, 'annotations': target}
+        img, target = self.prepare(img, target)
+        if self.image_transforms is not None:
+            img, target = self.image_transforms(img, target)
+        if self.query_transforms is not None:
+            query_img, _ = self.query_transforms(query_img, target)
+        return img, target, query_img
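To make the new return signature concrete, here is a minimal sketch of iterating the dataset; the paths are placeholders and this snippet is not part of the PR:

```python
# Each item is (image, single-category target, query crop).
dataset = CocoDetection_query('path/to/images', 'path/to/annotations.json',
                              transforms=make_coco_transforms('train'),
                              return_masks=False)
img, target, query_img = dataset[0]
print(target['labels'].unique())  # exactly one category id per sample
print(query_img.shape)            # C x H x W tensor after the 'query' transforms
```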


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
@@ -135,16 +187,20 @@ def make_coco_transforms(image_set):
            normalize,
        ])

-    if image_set == 'val':
+    if (image_set == 'val' or image_set == 'test'):
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

+    if image_set == 'query':
+        return T.Compose([
+            normalize,
+        ])
    raise ValueError(f'unknown {image_set}')


-def build(image_set, args):
+def build_coco(image_set, args):
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
@@ -156,3 +212,10 @@ def build(image_set, args):
    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
    return dataset

+def build_lwll(image_set, args):
+    prob = os.path.basename(os.path.dirname(args.coco_path))
+    img_folder = os.path.join(args.coco_path, f'{prob}_full', str(image_set.split('_')[0]))
+    ann_file = os.path.join(args.coco_path, 'labels_full', 'coco', f'coco_{image_set}.json')
+    dataset = CocoDetection_query(img_folder, ann_file, transforms=make_coco_transforms(image_set.split('_')[0]), return_masks=args.masks)
+    return dataset
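The path arithmetic in `build_lwll` implies a directory layout like the one traced below. This is inferred from the `os.path.join` calls above; `problem_x` and `train_0` are assumed example names:

```python
# Tracing the paths build_lwll derives; all values are hypothetical.
import os

coco_path = '/data/problem_x/'   # assumed args.coco_path (note the trailing slash)
image_set = 'train_0'            # assumed split name

prob = os.path.basename(os.path.dirname(coco_path))
# -> 'problem_x'
img_folder = os.path.join(coco_path, f'{prob}_full', str(image_set.split('_')[0]))
# -> '/data/problem_x/problem_x_full/train'
ann_file = os.path.join(coco_path, 'labels_full', 'coco', f'coco_{image_set}.json')
# -> '/data/problem_x/labels_full/coco/coco_train_0.json'
```

One caveat: `os.path.dirname` only yields the problem directory itself when `args.coco_path` ends with a separator; without the trailing slash, `prob` would resolve to the parent directory's name instead.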
33 changes: 15 additions & 18 deletions engine.py
@@ -25,24 +25,22 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

-    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
+    for samples, targets, query in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
+        query = query.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

-        outputs = model(samples)
-        loss_dict = criterion(outputs, targets)
+        outputs = model(samples, query)
+        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
-        loss_dict_reduced = utils.reduce_dict(loss_dict)
-        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
-                                      for k, v in loss_dict_reduced.items()}
-        loss_dict_reduced_scaled = {k: v * weight_dict[k]
-                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
-        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
-
-        loss_value = losses_reduced_scaled.item()
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
+        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
+        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
@@ -85,20 +83,19 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

-    for samples, targets in metric_logger.log_every(data_loader, 10, header):
+    for samples, targets, query in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
+        query = query.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

-        outputs = model(samples)
+        outputs = model(samples, query)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
-        loss_dict_reduced = utils.reduce_dict(loss_dict)
-        loss_dict_reduced_scaled = {k: v * weight_dict[k]
-                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
-        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
-                                      for k, v in loss_dict_reduced.items()}
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
+        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
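Both loops now unpack `(samples, targets, query)`, so the `DataLoader` collate function must batch the query crops as well. The PR does not show that change; below is a minimal sketch, assuming DETR's existing `nested_tensor_from_tensor_list` helper from `util/misc.py` is reused for the variable-sized crops:

```python
from util.misc import nested_tensor_from_tensor_list

def collate_fn(batch):
    # batch: list of (img, target, query_img) triples from CocoDetection_query
    imgs, targets, queries = list(zip(*batch))
    samples = nested_tensor_from_tensor_list(list(imgs))    # pads images to a common size
    query = nested_tensor_from_tensor_list(list(queries))   # query crops vary in size too
    return samples, targets, query
```

The resulting `NestedTensor` supports `.to(device)`, which matches the `query = query.to(device)` calls added in both loops.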