Skip to content

Commit

Permalink
change train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ucaswangls committed Aug 24, 2022
1 parent bfc56b1 commit 69feae9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 21 deletions.
9 changes: 3 additions & 6 deletions image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ def shift(inputs, step=2):
for i in range(nC):
output[i,:,step*i:step*i+col] = inputs[i,:,:]
return output

def batch_shift(inputs, step=2):
[b,nC, row, col] = inputs.shape
output = torch.zeros(b,nC, row, col+(nC-1)*step).to(inputs.device)
for i in range(nC):
output[:,i,:,step*i:step*i+col] = inputs[:,i,:,:]
return output
def batch_shift_back(inputs,step=2): # input [bs,256,310] output [bs, 28, 256, 256]
def batch_shift_back(inputs,step=2):
[b,c,row, col] = inputs.shape
output = torch.zeros(b,c, row, col-(c-1)*step).to(inputs.device)
for i in range(c):
Expand All @@ -23,11 +24,7 @@ def batch_shift_back(inputs,step=2): # input [bs,256,310] output [bs,
def gen_meas(data, mask3d):
nC = data.shape[0]
temp = shift(mask3d *data, 2)
# meas = torch.sum(temp, 0)/nC*2 # meas scale
meas = torch.sum(temp, 0) # meas scale

# y_temp = shift_back(meas,nC)
# meas = torch.mul(y_temp, mask3d)
meas = torch.sum(temp, 0)
return meas

def shuffle_crop(train_data,size=660):
Expand Down
3 changes: 0 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def test(args,network,mask,mask_s,src_mask,logger,writer=None,epoch=1):

mask, mask_s = generate_masks(args.mask_path)
src_mask = mask.to(args.device)
# mask_s = mask_s.to(args.device)
mask = shift(src_mask,2).to(args.device)
mask_s = torch.sum(mask,dim=0)
network = network.eval()
Expand All @@ -33,13 +32,11 @@ def test(args,network,mask,mask_s,src_mask,logger,writer=None,epoch=1):
batch_size,frames,height,width = gt.shape

gt = gt.float().numpy()
# Phi_s = mask_s.expand([batch_size,height,width])
Phi = mask.repeat([batch_size, 1, 1, 1])
Phi_s = mask_s.repeat([batch_size, 1, 1])

with torch.no_grad():
out_pic_list = network(meas, Phi, Phi_s)
# out_pic_list = network(meas)
out_pic = out_pic_list[-1].cpu().numpy()
psnr_t = 0
ssim_t = 0
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def train(args,network,logger,mask,mask_s,writer=None):
logger.info("epoch: {}, avg_loss: {:.5f}, time: {:.2f}s.\n".format(epoch,epoch_loss/(iteration+1),end_time-start_time))

if rank==0 and (epoch % args.save_model_step) == 0:
if not osp.exists(args.checkpoint):
os.makedirs(args.checkpoint)
if not osp.exists(args.checkpoints):
os.makedirs(args.checkpoints)
if args.distributed:
torch.save(network.module.state_dict(),osp.join(args.checkpoint,"epoch_"+str(epoch)+".pth"))
torch.save(network.module.state_dict(),osp.join(args.checkpoints,"epoch_"+str(epoch)+".pth"))
else:
torch.save(network.state_dict(),osp.join(args.checkpoint,"epoch_"+str(epoch)+".pth"))
torch.save(network.state_dict(),osp.join(args.checkpoints,"epoch_"+str(epoch)+".pth"))
if rank==0 and args.test_flag:
logger.info("epoch: {}, psnr and ssim test results:".format(epoch))
if args.distributed:
Expand Down
8 changes: 0 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ def compare_ssim(img1, img2):
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
# def compare_ssim(img1,img2):
# img1 = torch.from_numpy(img1)
# img2 = torch.from_numpy(img2)
# img1 = torch.unsqueeze(img1,0)
# img2 = torch.unsqueeze(img2,0)
# return ssim(torch.unsqueeze(img1,0), torch.unsqueeze(img2,0))



def compare_psnr(img1, img2, shave_border=0):
height, width = img1.shape[:2]
Expand Down

0 comments on commit 69feae9

Please sign in to comment.