Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
JcccKing committed May 25, 2023
0 parents commit 689ebdd
Show file tree
Hide file tree
Showing 47 changed files with 5,541 additions and 0 deletions.
42 changes: 42 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise ValueError("In %s.py, there should be a subclass of BaseDataset "
"with class name that matches %s in lowercase." %
(dataset_filename, target_dataset_name))
return dataset


def get_option_setter(dataset_name):
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options


def create_dataloader(opt):
dataset = find_dataset_using_name(opt.dataset_mode)
instance = dataset()
instance.initialize(opt)
print("Dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=opt.batchSize,
shuffle=(opt.phase=='train'),
num_workers=int(opt.nThreads),
drop_last=(opt.phase=='train')
)
return dataloader
135 changes: 135 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random


class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()

@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
pass


def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess_mode == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess_mode == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w
elif opt.preprocess_mode == 'scale_shortside_and_crop':
ss, ls = min(w, h), max(w, h) # shortside and longside
width_is_shorter = w == ss
ls = int(opt.load_size * ls / ss)
new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)

x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
transform_list = []
if opt.dataset_mode == 'flickr' and method == Image.NEAREST:
transform_list.append(transforms.Lambda(lambda img: __add1(img)))
if 'resize' in opt.preprocess_mode:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, interpolation=method))
elif 'scale_width' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
elif 'scale_shortside' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))

if 'crop' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

if opt.preprocess_mode == 'none':
base = 32
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

if opt.preprocess_mode == 'fixed':
w = opt.crop_size
h = round(opt.crop_size / opt.aspect_ratio)
transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))

if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

if opt.isTrain and 'rotate' in params.keys():
transform_list.append(transforms.Lambda(lambda img: __rotate(img, params['rotate'], method)))

if toTensor:
transform_list += [transforms.ToTensor()]

if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)


def __resize(img, w, h, method=Image.BICUBIC):
return img.resize((w, h), method)


def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)


def __scale_shortside(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
width_is_shorter = ow == ss
if (ss == target_width):
return img
ls = int(target_width * ls / ss)
nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
return img.resize((nw, nh), method)


def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
return img.crop((x1, y1, x1 + tw, y1 + th))


def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img


def __rotate(img, deg, method=Image.BICUBIC):
return img.rotate(deg, resample=method)


def __add1(img):
return Image.fromarray(np.array(img) + 1)
145 changes: 145 additions & 0 deletions data/celebahqedge_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import cv2
import torch
import numpy as np
from PIL import Image
from skimage import feature
from data.pix2pix_dataset import Pix2pixDataset
from data.base_dataset import get_params, get_transform


class CelebAHQEdgeDataset(Pix2pixDataset):
# hair, skin, l_brow, r_blow, l_eye, r_eye, l_ear, r_ear, nose, u_lip, mouth, l_lip, neck,
@staticmethod
def modify_commandline_options(parser, is_train):
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
parser.set_defaults(preprocess_mode='resize_and_crop')
parser.set_defaults(no_pairing_check=True)
if is_train:
parser.set_defaults(load_size=286)
else:
parser.set_defaults(load_size=256)
parser.set_defaults(crop_size=256)
parser.set_defaults(display_winsize=256)
parser.set_defaults(label_nc=15)
parser.set_defaults(contain_dontcare_label=False)
parser.set_defaults(cache_filelist_read=False)
parser.set_defaults(cache_filelist_write=False)
return parser

def get_paths(self, opt):
if opt.phase == 'train':
fd = open(os.path.join(opt.dataroot, 'train.txt'))
lines = fd.readlines()
fd.close()
elif opt.phase == 'test':
fd = open(os.path.join(opt.dataroot, 'val.txt'))
lines = fd.readlines()
fd.close()

image_paths = []
label_paths = []
for i in range(len(lines)):
image_paths.append(os.path.join('CelebA-HQ-img', lines[i].strip() + '.jpg'))
subfolder = str(int(lines[i].strip()) // 2000)
label_paths.append(
os.path.join('CelebAMask-HQ-mask-anno', subfolder, lines[i].strip().zfill(5) + '_{}.png'))

return label_paths, image_paths

def get_ref(self, opt):
extra = ''
if opt.phase == 'test':
extra = '_test'
with open(os.path.join(opt.dataroot, 'celebahq_ref{}.txt'.format(extra))) as fd:
lines = fd.readlines()
ref_dict = {}
for i in range(len(lines)):
items = lines[i].strip().split(',')
key = items[0]
if opt.phase == 'test':
val = items[1:]
else:
val = [items[1], items[-1]]
ref_dict[key] = val
train_test_folder = ('', '')
return ref_dict, train_test_folder

def get_edges(self, edge, t):
edge[:, 1:] = edge[:, 1:] | (t[:, 1:] != t[:, :-1])
edge[:, :-1] = edge[:, :-1] | (t[:, 1:] != t[:, :-1])
edge[1:, :] = edge[1:, :] | (t[1:, :] != t[:-1, :])
edge[:-1, :] = edge[:-1, :] | (t[1:, :] != t[:-1, :])
return edge

def get_label_tensor(self, path):
inner_parts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'l_ear', 'r_ear', 'nose', 'u_lip', 'mouth',
'l_lip', 'eye_g', 'hair']
img_path = self.labelpath_to_imgpath(path)
img = Image.open(img_path).resize((self.opt.load_size, self.opt.load_size), resample=Image.BILINEAR)
params = get_params(self.opt, img.size)
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
transform_img = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)

inner_label = np.ones(img.size, dtype=np.uint8)
edges = np.zeros(img.size, dtype=np.uint8)
tensors_dist = 0
e = 1
for part in inner_parts:
edge = np.zeros(img.size, dtype=np.uint8) # this for distance transform map on each facial part
if os.path.exists(path.format(part)):
part_label = Image.open(path.format(part)).convert('L').resize((self.opt.load_size, self.opt.load_size),
resample=Image.NEAREST)
part_label = np.array(part_label)
if part == 'hair':
inner_label[part_label == 255] = 1
else:
inner_label[part_label == 255] = 0
edges = self.get_edges(edges, part_label)
edge = self.get_edges(edge, part_label)
im_dist = cv2.distanceTransform(255 - edge * 255, cv2.DIST_L1, 3)
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
tensor_dist = transform_img(Image.fromarray(im_dist))
tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
e += 1

# canny edge for background
canny_edges = feature.canny(np.array(img.convert('L')))
canny_edges = canny_edges * inner_label

edges_all = edges + canny_edges
edges_all[edges_all > 1] = 1
tensor_edges_all = transform_label(Image.fromarray(edges_all * 255))
edges[edges > 1] = 1
tensor_edges = transform_label(Image.fromarray(edges * 255))

label_tensor = torch.cat((tensor_edges_all, tensors_dist, tensor_edges), dim=0)
return label_tensor, params

def imgpath_to_labelpath(self, path):
root, name = path.split('CelebA-HQ-img/')
subfolder = str(int(name.split('.')[0]) // 2000)
label_path = os.path.join(root, 'CelebAMask-HQ-mask-anno', subfolder, name.split('.')[0].zfill(5) + '_{}.png')
return label_path

def labelpath_to_imgpath(self, path):
root = path.replace('\\', '/').split('CelebAMask-HQ-mask-anno/')[0]
name = os.path.basename(path).split('_')[0]
img_path = os.path.join(root, 'CelebA-HQ-img', str(int(name)) + '.jpg')
return img_path

# In ADE20k, 'unknown' label is of value 0.
# Change the 'unknown' label to the last label to match other datasets.
# def postprocess(self, input_dict):
# label = input_dict['label']
# label = label - 1
# label[label == -1] = self.opt.label_nc
# input_dict['label'] = label
# if input_dict['label_ref'] is not None:
# label_ref = input_dict['label_ref']
# label_ref = label_ref - 1
# label_ref[label_ref == -1] = self.opt.label_nc
# input_dict['label_ref'] = label_ref
Loading

0 comments on commit 689ebdd

Please sign in to comment.