import torch
from torch import nn
from torch import autograd
from torch.distributions.dirichlet import Dirichlet
from torchvision.models import resnet50, densenet121, vgg16
from src.architectures.linear_sequential import linear_sequential
from src.architectures.convolution_linear_sequential import convolution_linear_sequential
from src.posterior_networks.NormalizingFlowDensity import NormalizingFlowDensity
from src.posterior_networks.BatchedNormalizingFlowDensity import BatchedNormalizingFlowDensity
from src.posterior_networks.MixtureDensity import MixtureDensity

# Budget functions map the per-class training counts N to the pseudo-count
# budget used to scale the class-conditional densities into Dirichlet parameters.
__budget_functions__ = {'one': lambda N: torch.ones_like(N),
                        'log': lambda N: torch.log(N + 1.),
                        'id': lambda N: N,
                        'id_normalized': lambda N: N / N.sum(),
                        'exp': lambda N: torch.exp(N),
                        'parametrized': lambda N: torch.nn.Parameter(torch.ones_like(N).to(N.device))}
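
# Illustration (added comment, not in the original file): with counts
# N = torch.tensor([100., 10.]), the 'log' budget yields log(N + 1) ≈ [4.62, 2.40],
# damping class imbalance, while 'id' keeps the raw counts and 'parametrized'
# makes the budget a learnable parameter.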

class PosteriorNetwork(nn.Module):
    def __init__(self, N,  # Count of data from each class in the training set. tensor of ints
                 input_dims,  # Input dimension. list of ints
                 output_dim,  # Output dimension. int
                 hidden_dims=[64, 64, 64],  # Hidden dimensions. list of ints
                 kernel_dim=None,  # Kernel dimension if conv architecture. int
                 latent_dim=6,  # Latent dimension. int
                 architecture='linear',  # Encoder architecture name. string
                 k_lipschitz=None,  # Lipschitz constant. float or None (no Lipschitz constraint)
                 no_density=False,  # Use density estimation or not. boolean
                 density_type='radial_flow',  # Density type. string
                 n_density=8,  # Number of density components. int
                 budget_function='id',  # Budget function name applied on class counts. string
                 batch_size=64,  # Batch size. int
                 lr=1e-3,  # Learning rate. float
                 loss='UCE',  # Loss name. string
                 regr=1e-5,  # Regularization factor in Bayesian loss. float
                 seed=0,  # Random seed for init. int
                 drop_prob=0.5):  # Dropout probability. float
        super().__init__()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        # Architecture parameters
        self.input_dims, self.output_dim, self.hidden_dims, self.kernel_dim, self.latent_dim = input_dims, output_dim, hidden_dims, kernel_dim, latent_dim
        self.k_lipschitz = k_lipschitz
        self.no_density, self.density_type, self.n_density = no_density, density_type, n_density
        if budget_function in __budget_functions__:
            self.N, self.budget_function = __budget_functions__[budget_function](N), budget_function
        else:
            raise NotImplementedError(f'Unknown budget function: {budget_function}')

        # Training parameters
        self.batch_size, self.lr = batch_size, lr
        self.loss, self.regr = loss, regr
        # Encoder -- Feature extraction
        if architecture == 'linear':
            self.sequential = linear_sequential(input_dims=self.input_dims,
                                                hidden_dims=self.hidden_dims,
                                                output_dim=self.latent_dim,
                                                k_lipschitz=self.k_lipschitz)
            # The linear encoder already outputs latent_dim features; kl_linear is a
            # single linear map so that forward() can treat all architectures uniformly.
            self.kl_linear = linear_sequential(input_dims=self.latent_dim,
                                               hidden_dims=[],
                                               output_dim=self.latent_dim,
                                               k_lipschitz=self.k_lipschitz)
        elif architecture == 'conv':
            assert len(input_dims) == 3
            self.sequential = convolution_linear_sequential(input_dims=self.input_dims,
                                                            linear_hidden_dims=self.hidden_dims,
                                                            conv_hidden_dims=[64, 64, 64],
                                                            output_dim=self.latent_dim,
                                                            kernel_dim=self.kernel_dim,
                                                            k_lipschitz=self.k_lipschitz)
            self.kl_linear = linear_sequential(input_dims=self.latent_dim,
                                               hidden_dims=[],
                                               output_dim=self.latent_dim,
                                               k_lipschitz=self.k_lipschitz)
        elif architecture == 'vgg':
            # ImageNet-pretrained VGG-16 backbone followed by a 128-d embedding head
            self.sequential = nn.Sequential(vgg16(pretrained=True),
                                            nn.Dropout(drop_prob),
                                            nn.Linear(1000, 128),
                                            nn.ReLU(True))
            self.kl_linear = linear_sequential(input_dims=128,
                                               hidden_dims=self.hidden_dims,
                                               output_dim=self.latent_dim,
                                               k_lipschitz=self.k_lipschitz)
        elif architecture == 'resnet':
            # ImageNet-pretrained ResNet-50; replace its final fc layer with a
            # two-layer head producing a 128-d embedding
            self.sequential = resnet50(pretrained=True)
            num_ftrs = self.sequential.fc.in_features
            self.sequential.fc = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True),
                                               nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True))
            self.kl_linear = linear_sequential(input_dims=128,
                                               hidden_dims=self.hidden_dims,
                                               output_dim=self.latent_dim,
                                               k_lipschitz=self.k_lipschitz)
        elif architecture == 'densenet':
            # ImageNet-pretrained DenseNet-121; replace its classifier with the
            # same 128-d embedding head
            self.sequential = densenet121(pretrained=True)
            num_ftrs = self.sequential.classifier.in_features
            self.sequential.classifier = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True),
                                                       nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True))
            self.kl_linear = linear_sequential(input_dims=128,
                                               hidden_dims=self.hidden_dims,
                                               output_dim=self.latent_dim,
                                               k_lipschitz=self.k_lipschitz)
        else:
            raise NotImplementedError(f'Unknown architecture: {architecture}')
        self.batch_norm = nn.BatchNorm1d(num_features=self.latent_dim)
        self.cls_fc = nn.Linear(128, self.output_dim)
        self.projection = nn.Sequential(nn.Linear(128, 64, bias=False), nn.BatchNorm1d(64),
                                        nn.ReLU(inplace=True), nn.Linear(64, 32, bias=True))
        # Normalizing Flow -- Normalized density on latent space
        if self.density_type in ('planar_flow', 'radial_flow', 'iaf_flow'):
            self.density_estimation = nn.ModuleList([NormalizingFlowDensity(dim=self.latent_dim,
                                                                            flow_length=n_density,
                                                                            flow_type=self.density_type)
                                                     for c in range(self.output_dim)])
        elif self.density_type == 'batched_radial_flow':
            self.density_estimation = BatchedNormalizingFlowDensity(c=self.output_dim,
                                                                    dim=self.latent_dim,
                                                                    flow_length=n_density,
                                                                    flow_type=self.density_type.replace('batched_', ''))
        elif self.density_type == 'normal_mixture':
            self.density_estimation = nn.ModuleList([MixtureDensity(dim=self.latent_dim,
                                                                    n_components=n_density,
                                                                    mixture_type=self.density_type)
                                                     for c in range(self.output_dim)])
        else:
            raise NotImplementedError(f'Unknown density type: {self.density_type}')
        self.softmax = nn.Softmax(dim=-1)
        # Optimizer -- all parameters (encoder, heads, and density model) are
        # optimized jointly
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    def forward(self, input, label, return_output='hard', compute_loss=True):
        batch_size = input.size(0)
        if self.N.device != input.device:
            self.N = self.N.to(input.device)
        if self.budget_function == 'parametrized':
            N = self.N / self.N.sum()
        else:
            N = self.N

        # Forward
        zk = self.sequential(input)  # embedding fed to the linear classifier / projection head / normalizing flow
        if self.no_density:  # Ablated model without density estimation
            logits = self.cls_fc(zk)
            alpha = torch.exp(logits)
            prob_pred = self.softmax(logits)
        else:  # Full model with density estimation
            zk2 = self.kl_linear(zk)
            zk2 = self.batch_norm(zk2)
            log_q_zk = torch.zeros((batch_size, self.output_dim)).to(zk2.device)
            alpha = torch.zeros((batch_size, self.output_dim)).to(zk2.device)
            if isinstance(self.density_estimation, nn.ModuleList):
                # One density model per class: alpha_c = 1 + N_c * q(z | c)
                for c in range(self.output_dim):
                    log_q_zk[:, c] = self.density_estimation[c].log_prob(zk2)  # should not be large-negative
                    alpha[:, c] = 1. + (N[c] * torch.exp(log_q_zk[:, c]))
            else:
                # Batched density model evaluates all classes at once
                log_q_zk = self.density_estimation.log_prob(zk2)
                alpha = 1. + (N[:, None] * torch.exp(log_q_zk)).permute(1, 0)
            prob_pred = torch.nn.functional.normalize(alpha, p=1, dim=-1)
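
        # Note (added comment, not in the original file): since all entries of
        # alpha are positive, the L1 normalization above computes alpha / alpha_0,
        # i.e. the mean of the Dirichlet distribution Dir(alpha).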
        output_pred = self.predict(prob_pred)

        # Loss
        if compute_loss:
            if self.loss == 'CE':
                self.grad_loss = self.CE_loss(prob_pred, label)
            elif self.loss == 'UCE':
                self.grad_loss = self.UCE_loss(alpha, label)
            elif self.loss == 'CL':
                pass  # contrastive loss is computed outside this module
            else:
                raise NotImplementedError(f'Unknown loss: {self.loss}')

        if return_output == 'hard':
            return output_pred
        elif return_output == 'soft':
            return prob_pred
        elif return_output == 'alpha':
            return alpha
        elif return_output == 'latent':
            return zk
        elif self.loss == 'CL' and return_output == 'projection':
            return self.projection(zk)
        else:
            raise AssertionError(f'Unknown return_output: {return_output}')
    def CE_loss(self, prob_pred, label):
        # Cross-entropy loss; `label` is expected to be one-hot encoded
        with autograd.detect_anomaly():
            CE_loss = - torch.sum(label.squeeze() * torch.log(prob_pred))
        return CE_loss

    def UCE_loss(self, alpha, label):
        # Uncertain cross-entropy: expected CE under Dir(alpha), minus an
        # entropy regularizer (see the formula sketch below)
        with autograd.detect_anomaly():
            alpha_0 = alpha.sum(1).unsqueeze(-1).repeat(1, self.output_dim)
            entropy_reg = Dirichlet(alpha).entropy()
            UCE_loss = torch.sum(label * (torch.digamma(alpha_0) - torch.digamma(alpha))) - self.regr * torch.sum(entropy_reg)
        return UCE_loss
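
    # For reference (added comment, not in the original file): with one-hot labels
    # y and Dirichlet parameters alpha, where alpha_0_i = sum_c alpha_{i,c}, the
    # loss above is
    #   UCE = sum_{i,c} y_{i,c} * (digamma(alpha_0_i) - digamma(alpha_{i,c}))
    #         - regr * sum_i H[Dir(alpha_i)],
    # i.e. the expected cross-entropy under Dir(alpha) minus an entropy regularizer.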
    def step(self):
        # Single optimization step on the most recently computed loss
        self.optimizer.zero_grad()
        self.grad_loss.backward()
        self.optimizer.step()

    def predict(self, p):
        # Hard prediction: index of the most probable class
        return torch.max(p, dim=-1)[1]
    def feature_list(self, input):
        # Only used for the deterministic (no-density) model
        if self.N.device != input.device:
            self.N = self.N.to(input.device)
        out_list = []
        out = self.sequential(input)
        out_list.append(out)
        out = self.cls_fc(out)
        out_list.append(out)
        return out_list
    # Extract the output of a specific layer
    def intermediate_forward(self, input, layer_index):
        # Only used for the deterministic (no-density) model
        if self.N.device != input.device:
            self.N = self.N.to(input.device)
        if layer_index == 0:
            out = self.sequential(input)
        elif layer_index == 1:
            out = self.cls_fc(self.sequential(input))
        else:
            raise ValueError(f'Unknown layer_index: {layer_index}')
        return out
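
# Minimal usage sketch (added for illustration, not part of the original file).
# It assumes the `src.*` modules imported above are available and that the
# 'linear' encoder accepts flat 10-d inputs; all shapes and hyperparameters
# below are illustrative.
if __name__ == '__main__':
    n_classes = 3
    N = torch.tensor([100., 50., 25.])  # per-class training counts
    model = PosteriorNetwork(N=N,
                             input_dims=[10],
                             output_dim=n_classes,
                             hidden_dims=[64, 64],
                             latent_dim=6,
                             architecture='linear',
                             density_type='radial_flow',
                             loss='UCE')
    x = torch.randn(8, 10)  # batch of 8 flat inputs
    y = torch.nn.functional.one_hot(torch.randint(0, n_classes, (8,)),
                                    num_classes=n_classes).float()
    alpha = model(x, y, return_output='alpha')  # Dirichlet concentration parameters
    model.step()  # one optimizer step on the UCE loss computed in forward()
    print(alpha.shape)  # expected: torch.Size([8, 3])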