-
Notifications
You must be signed in to change notification settings - Fork 11
/
CIFAR.py
134 lines (114 loc) · 5.73 KB
/
CIFAR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torchvision.transforms as transforms
from datasets import SubDataset, AbstractDomainInterface, ExpandRGBChannels
from torchvision import datasets
def filter_indices(dataset, indices, filter_label):
    """Return the subset of ``indices`` whose label is not in ``filter_label``.

    Each candidate is fetched with ``dataset[i]`` and unpacked as
    ``(sample, label)``; only the label is inspected.

    Args:
        dataset: Indexable object whose items are ``(sample, label)`` pairs.
        indices: Iterable of integer indices into ``dataset`` (plain ints or
            integer tensor elements, e.g. from ``torch.arange(...).int()``).
        filter_label: Collection (or any iterable) of labels to exclude.

    Returns:
        A ``torch.IntTensor`` with the accepted indices in their original
        order. It is empty when ``indices`` is empty or every index is rejected.
    """
    # Materialize once so a one-shot iterable (e.g. a generator) is not
    # exhausted by the first membership test. A tuple (not a set) keeps the
    # original ``==``-based ``in`` semantics, which matters for unhashable or
    # tensor-valued labels.
    rejected = tuple(filter_label)
    # int() turns 0-d tensor elements into plain ints. That is a safe index
    # for lists and arrays, and a plain element type for the legacy
    # torch.IntTensor(list) constructor below.
    accept = [int(ind) for ind in indices if dataset[int(ind)][1] not in rejected]
    return torch.IntTensor(accept)
class CIFAR10(AbstractDomainInterface):
    """
    CIFAR10 domain: 50,000 train + 10,000 test images (3x32x32).

    Index ranges used by the getters:
      D1: train = train split [0, 40000), valid = train split [40000, 50000),
          test = test split [0, 10000).
      D2 (Dv, Dt): valid = whole train split [0, 50000),
          test = whole test split [0, 10000).
    """

    def __init__(self):
        super(CIFAR10, self).__init__()
        to_tensor = transforms.Compose([transforms.ToTensor()])
        cache_dir = './workspace/datasets/cifar10'

        def _span(lo, hi):
            # int32 index tensor covering the half-open range [lo, hi).
            return torch.arange(lo, hi).int()

        self.D1_train_ind = _span(0, 40000)
        self.D1_valid_ind = _span(40000, 50000)
        self.D1_test_ind = _span(0, 10000)
        self.D2_valid_ind = _span(0, 50000)
        self.D2_test_ind = _span(0, 10000)

        # Build the train split first, then the test split (downloaded on demand).
        self.ds_train, self.ds_test = [
            datasets.CIFAR10(cache_dir,
                             train=is_train,
                             transform=to_tensor,
                             download=True)
            for is_train in (True, False)
        ]

    def get_D1_train(self):
        # No explicit label is passed here, unlike the other getters.
        return SubDataset(self.name, self.ds_train, self.D1_train_ind)

    def get_D1_valid(self):
        # In-distribution samples carry label 0.
        return SubDataset(self.name, self.ds_train, self.D1_valid_ind, label=0)

    def get_D1_test(self):
        return SubDataset(self.name, self.ds_test, self.D1_test_ind, label=0)

    def get_D2_valid(self, D1):
        # NOTE(review): assert-based check is stripped under `python -O`.
        assert self.is_compatible(D1)
        # Samples are reshaped to match D1's expected input format.
        conform = D1.conformity_transform()
        return SubDataset(self.name, self.ds_train, self.D2_valid_ind, label=1, transform=conform)

    def get_D2_test(self, D1):
        assert self.is_compatible(D1)
        conform = D1.conformity_transform()
        return SubDataset(self.name, self.ds_test, self.D2_test_ind, label=1, transform=conform)

    def conformity_transform(self):
        # Transform applied to foreign (D2) samples when this class is D1.
        # ExpandRGBChannels presumably lifts 1-channel input to 3 channels
        # (TODO confirm against its definition). The result is then resized
        # to 32x32 via PIL.
        steps = [
            ExpandRGBChannels(),
            transforms.ToPILImage(),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ]
        return transforms.Compose(steps)
class CIFAR100(AbstractDomainInterface):
    """
    CIFAR100 domain: 50,000 train + 10,000 test images (3x32x32).

    Index ranges used by the getters:
      D1: train = train split [0, 40000), valid = train split [40000, 50000),
          test = test split [0, 10000).
      D2 (Dv, Dt): valid = whole train split [0, 50000),
          test = whole test split [0, 10000).

    When D1 has an entry in ``self.filter_rules``, samples whose CIFAR100
    label is listed there are dropped from D2. D2 can then be smaller than
    the sizes above.
    """

    def __init__(self):
        super(CIFAR100, self).__init__()
        im_transformer = transforms.Compose([transforms.ToTensor()])
        root_path = './workspace/datasets/cifar100'
        self.D1_train_ind = torch.arange(0, 40000).int()
        self.D1_valid_ind = torch.arange(40000, 50000).int()
        self.D1_test_ind = torch.arange(0, 10000).int()
        self.D2_valid_ind = torch.arange(0, 50000).int()
        self.D2_test_ind = torch.arange(0, 10000).int()
        self.ds_train = datasets.CIFAR100(root_path,
                                          train=True,
                                          transform=im_transformer,
                                          download=True)
        self.ds_test = datasets.CIFAR100(root_path,
                                         train=False,
                                         transform=im_transformer,
                                         download=True)
        # CIFAR100 labels excluded from D2, keyed by D1 domain name.
        # Pairs are "CIFAR100 label:name with <D1 class>:name"; the right-hand
        # number is presumably the TinyImagenet class index (per the original note).
        # TinyImagenet:
        #   6:bee with 38:bee
        #   21:chimpanzee with 55:chimpanzee, chimp, Pan troglodytes
        #   24:cockroach with 41:cockroach, roach
        #   43:lion with 34:lion, king of beasts, Panthera leo
        #   51:mushroom with 185:mushroom
        #   53:orange with 186:orange
        #   61:plate with 177:plate
        #   77:snail with 15:snail
        #   89:tractor with 164:tractor
        self.filter_rules = {
            'TinyImagenet': [6, 21, 24, 43, 51, 53, 61, 77, 89]
        }

    def _filtered_D2_indices(self, dataset, indices, D1):
        """Return ``indices`` minus samples whose label is excluded for ``D1``.

        If ``D1.name`` has no entry in ``self.filter_rules``, ``indices`` is
        returned unchanged.
        """
        # `in` replaces dict.has_key(), which does not exist in Python 3.
        if D1.name in self.filter_rules:
            return filter_indices(dataset, indices, self.filter_rules[D1.name])
        return indices

    def get_D1_train(self):
        return SubDataset(self.name, self.ds_train, self.D1_train_ind)

    def get_D1_valid(self):
        # In-distribution samples carry label 0.
        return SubDataset(self.name, self.ds_train, self.D1_valid_ind, label=0)

    def get_D1_test(self):
        return SubDataset(self.name, self.ds_test, self.D1_test_ind, label=0)

    def get_D2_valid(self, D1):
        # NOTE(review): assert-based check is stripped under `python -O`.
        assert self.is_compatible(D1)
        target_indices = self._filtered_D2_indices(self.ds_train, self.D2_valid_ind, D1)
        return SubDataset(self.name, self.ds_train, target_indices, label=1, transform=D1.conformity_transform())

    def get_D2_test(self, D1):
        assert self.is_compatible(D1)
        target_indices = self._filtered_D2_indices(self.ds_test, self.D2_test_ind, D1)
        return SubDataset(self.name, self.ds_test, target_indices, label=1, transform=D1.conformity_transform())

    def conformity_transform(self):
        # Transform applied to foreign (D2) samples when this class is D1:
        # ExpandRGBChannels (presumably to 3 channels — TODO confirm), then
        # a PIL round-trip that resizes to 32x32.
        return transforms.Compose([ExpandRGBChannels(),
                                   transforms.ToPILImage(),
                                   transforms.Resize((32, 32)),
                                   transforms.ToTensor(),
                                   ])