From a51073b8ba56c7635546d8a032981f9858ec1c74 Mon Sep 17 00:00:00 2001 From: otenim Date: Thu, 14 Mar 2019 01:21:15 +0900 Subject: [PATCH 01/15] modified to calculate channel-wise mean pixel value --- predict.py | 4 ++-- train.py | 21 +++++++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/predict.py b/predict.py index 7daf5c5..35a46c5 100644 --- a/predict.py +++ b/predict.py @@ -2,7 +2,7 @@ import argparse import torch import json -import numpy +import numpy as np import torchvision.transforms as transforms from torchvision.utils import save_image from models import CompletionNetwork @@ -36,7 +36,7 @@ def main(args): # ============================================= with open(args.config, 'r') as f: config = json.load(f) - mpv = config['mpv'] + mpv = torch.tensor(config['mpv']).view(3,1,1) model = CompletionNetwork() model.load_state_dict(torch.load(args.model, map_location='cpu')) diff --git a/train.py b/train.py index 50dbc03..afc6649 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,7 @@ parser.add_argument('--bdivs', type=int, default=1) parser.add_argument('--data_parallel', action='store_true') parser.add_argument('--num_test_completions', type=int, default=16) -parser.add_argument('--mpv', type=float, default=None) +parser.add_argument('--mpv', nargs=3, type=float, default=None) parser.add_argument('--alpha', type=float, default=4e-4) parser.add_argument('--arc', type=str, choices=['celeba', 'places2'], default='celeba') @@ -87,26 +87,31 @@ def main(args): train_loader = DataLoader(train_dset, batch_size=(args.bsize // args.bdivs), shuffle=True) # compute mean pixel value of training dataset - mpv = 0. + mpv = np.zeros(shape=(3,)) if args.mpv == None: pbar = tqdm(total=len(train_dset.imgpaths), desc='computing mean pixel value for training dataset...') for imgpath in train_dset.imgpaths: img = Image.open(imgpath) x = np.array(img, dtype=np.float32) / 255. - mpv += x.mean() + mpv += x.mean(axis=(0,1)) pbar.update() mpv /= len(train_dset.imgpaths) pbar.close() else: - mpv = args.mpv - mpv = torch.tensor(mpv).to(gpu) - alpha = torch.tensor(args.alpha).to(gpu) - + mpv = np.array(args.mpv) + # save training config + mpv_json = [] + for i in range(3): + mpv_json.append(float(mpv[i])) # convert to json serializable type args_dict = vars(args) - args_dict['mpv'] = float(mpv) + args_dict['mpv'] = mpv_json with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) + + # make mpv & alpha tensor + mpv = torch.tensor(mpv.astype(np.float32).reshape(3, 1, 1)).to(gpu) + alpha = torch.tensor(args.alpha).to(gpu) # ================================================ From 2588b59484def0d0dd89cd221c93eee5bd548808 Mon Sep 17 00:00:00 2001 From: otenim Date: Thu, 14 Mar 2019 01:23:07 +0900 Subject: [PATCH 02/15] minor change --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 35a46c5..1546020 100644 --- a/predict.py +++ b/predict.py @@ -67,7 +67,7 @@ def main(args): output = model(input) inpainted = poisson_blend(input, output, msk) imgs = torch.cat((x, input, inpainted), dim=0) - imgs = save_image(imgs, args.output_img, nrow=3) + save_image(imgs, args.output_img, nrow=3) print('output img was saved as %s.' % args.output_img) From 6147e1778e36d9ff6bcb46b73e8577b2be5a6f02 Mon Sep 17 00:00:00 2001 From: otenim Date: Thu, 14 Mar 2019 07:35:18 +0900 Subject: [PATCH 03/15] add opencv-based poisson blending method --- train.py | 10 +++++----- utils.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index afc6649..6507c5a 100644 --- a/train.py +++ b/train.py @@ -99,7 +99,7 @@ def main(args): pbar.close() else: mpv = np.array(args.mpv) - + # save training config mpv_json = [] for i in range(3): @@ -108,7 +108,7 @@ def main(args): args_dict['mpv'] = mpv_json with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f: json.dump(args_dict, f) - + # make mpv & alpha tensor mpv = torch.tensor(mpv.astype(np.float32).reshape(3, 1, 1)).to(gpu) alpha = torch.tensor(args.alpha).to(gpu) @@ -170,7 +170,7 @@ def main(args): ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) - completed = poisson_blend(input, output, msk) + completed = poisson_blend(x, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) @@ -259,7 +259,7 @@ def main(args): ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) - completed = poisson_blend(input, output, msk) + completed = poisson_blend(x, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join(args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) @@ -349,7 +349,7 @@ def main(args): ).to(gpu) input = x - x * msk + mpv * msk output = model_cn(input) - completed = poisson_blend(input, output, msk) + completed = poisson_blend(x, output, msk) imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) diff --git a/utils.py b/utils.py index fd8ae24..744b250 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,7 @@ import random import torchvision.transforms as transforms import numpy as np +import cv2 from poissonblending import blend @@ -121,6 +122,47 @@ def sample_random_batch(dataset, batch_size=32): def poisson_blend(input, output, mask): + """ + * inputs: + - input (torch.Tensor, required) + Input image tensor of Completion Network. + - output (torch.Tensor, required) + Output tensor of Completion Network. + - mask (torch.Tensor, required) + Input mask tensor of Completion Network. + * returns: + Image tensor inpainted with poisson image editing method. + """ + input, output, mask = input.clone(), output.clone(), mask.clone() + num_samples = input.shape[0] + ret = [] + for i in range(num_samples): + dstimg = transforms.functional.to_pil_image(input[i].cpu()) + dstimg = np.array(dstimg)[:, :, [2, 1, 0]] + srcimg = transforms.functional.to_pil_image(output[i].cpu()) + srcimg = np.array(srcimg)[:, :, [2, 1, 0]] + msk = transforms.functional.to_pil_image(mask[i].cpu()) + msk = np.array(msk)[:, :, [2, 1, 0]] + # compute mask's center + xs, ys = [], [] + for i in range(msk.shape[0]): + for j in range(msk.shape[1]): + if msk[i,j,0] == 255: + ys.append(i) + xs.append(j) + xmin, xmax = min(xs), max(xs) + ymin, ymax = min(ys), max(ys) + center = ((xmax + xmin) // 2, (ymax + ymin) // 2) + out = cv2.seamlessClone(srcimg, dstimg, msk, center, cv2.NORMAL_CLONE) + out = out[:, :, [2, 1, 0]] + out = transforms.functional.to_tensor(out) + out = torch.unsqueeze(out, dim=0) + ret.append(out) + ret = torch.cat(ret, dim=0) + return ret + + +def poisson_blend_old(input, output, mask): """ * inputs: - input (torch.Tensor, required) From 145348e0a66e034746ba1487ff2816d900fcabb7 Mon Sep 17 00:00:00 2001 From: otenim Date: Thu, 14 Mar 2019 07:49:16 +0900 Subject: [PATCH 04/15] minor change --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 9168826..1d150b8 100644 --- a/predict.py +++ b/predict.py @@ -67,7 +67,7 @@ def main(args): with torch.no_grad(): input = x - x * msk + mpv * msk output = model(input) - inpainted = poisson_blend(input, output, msk) + inpainted = poisson_blend(x, output, msk) imgs = torch.cat((x, input, inpainted), dim=0) save_image(imgs, args.output_img, nrow=3) print('output img was saved as %s.' % args.output_img) From 98e5f3da8f4c0590fee8f76fafc40d5ad08f3f98 Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 15:50:50 +0900 Subject: [PATCH 05/15] change input channels from 3 to 4 --- models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 791aba3..5589383 100644 --- a/models.py +++ b/models.py @@ -7,8 +7,8 @@ class CompletionNetwork(nn.Module): def __init__(self): super(CompletionNetwork, self).__init__() - # input_shape: (None, 3, img_h, img_w) - self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) + # input_shape: (None, 4, img_h, img_w) + self.conv1 = nn.Conv2d(4, 64, kernel_size=5, stride=1, padding=2) self.bn1 = nn.BatchNorm2d(64) self.act1 = nn.ReLU() # input_shape: (None, 64, img_h, img_w) From d0b089dcea1a3cd2558d791c0a1e8cc0ddcea784 Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 16:03:16 +0900 Subject: [PATCH 06/15] add an image of mountain --- images/test_5.jpeg | Bin 0 -> 11560 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 images/test_5.jpeg diff --git a/images/test_5.jpeg b/images/test_5.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..5dc72684a35dc5dfc37446898afa8516fd61eaa7 GIT binary patch literal 11560 zcmZv?bx_>T4>0@zjzfwadiY_5LUAwduEnLriXUFw4qA%K@kNUj_m&o?6e!+8aXqX+ z@#CPlzx~ZS@60>TJlUE3BiYR+vzzQD*@yXuHGoW0RYMiP!U6!Se+_s*0dD{TFdjJt zITZy31vM2F4HQa8OG5)?f-+Gv(Lt%HX&9k&P-ZR|GaUy#4U|QYh4Tq7A0Hnxv#^9P zuQ(Sk9}gBGAt4Dd2|XDZJue50gZKXmJpU(_?|%r<|8@Ki-2en1Xu+bu!D0olAy_yN ztcN~;;h&7SSpO6C{{vhA8wUi&!o&YZl^_GKv2bv3uyL_LV4VLA!@`CDI4mG?L0km` zR$DNIkT)Xj1Dj&Q5v5_z#P4G~Dm!5xk$*Gkuz>%K|Nrs*Hx&nj3;svrg#g(901F2j zjD`FEF8hxy4hy-U0tKr9$W{o^FmZJ3O{u7q_Tga>Ai}}=Coc{JkONNHv_)CU=}t*} zgmpwtv_<+vrgS8xbO4b#p7J)zKA#TF43_x+=AIa0(f#qy064%l#>iJfm%L^4@sg&m3Wn!ok|@HVb{ zn#)r^Ih1xqDa6{G(C9Gxy?H)8UADJntCn4VnPX?bIZ9%YN0OUYNS~((+F`Y<`Obd0 zOvgYk9{9tJ_DAHfK8Yz~%~lCu*mQV~Uxi$lNbFNIM@HIkV@t{c={D%fRg;MVf2!n zOgcV;4^ViX9`8qO|^4>!Z zdaH^S)%Qj$B2(6KlnzP-L1{`j-QiM>fIxIuOGO5XaB^j92(fImy$3O-5rdYt^hGzb zM$g|m96xq0A64y+$c9{0J^(TJ1*E~h;;K|8!`SniWK*wVvMlYqh%L4D+u4{R$_N*n ziK+n7%8e6=e_L#HS)peeU-CXFj|F6ZJ38nwWzaACZn(r%;3&b7K~KZT^#~4A+@pj6P$l+z3qJqtd16hA82OLl3f|nrvZ)568(Q zJQ88bA8PYM<2Lk4#{K75OD5fTm#^u(1-$HUj!hIz9UOBTM{=DL;4+i}>t$^=vC`2? z`gf+r4W6X06^?}>6!cOJ!tcuZumvx2>kZ$1>tzndamGOamitCHJ>d<*~ms< z?vm7MpyxlwI5Pcg~2+(WH_hOV!W@mSIAF>RB+jk@uZd*F!R522JI>-MMTJK&f-;4==f`4-s*G zQKs9LFB@*pt38irvvI!oHU_KCFZsqvdkZs!XORD7sf7J{FJG0;ew>{WJQAvNoQdp( zf2Kyy2-tk*Z;%=LbQfYmQ*&0Q`(fd5WwGuy@*4QwM=^kPGe${stAgkf8u#vo^e{;~ ze%bz$@q)~^_N>#fPNTKYVwpP%I~ujrEuB+4W-`rPKeh#9*4U66z19wy4pXXE4DIk9 z!H|=v**FW>g((kTh`Q>e4?3W}J^)|m@R5!#2}oQG|GzbP+Cet~@r0~hJjqXHr>3N5 zLdG74PG=r!rd>ZL(^8n@w~Q>6|1}Q;3?~-2*-l*NC&Uv~FZb1GJCT=Gj?zmYI0J26rf!$ohdVYoU~~OhZ88vJA)7F_Knq`$$|&Obs~{71 zC}p1#eB4Q6OhHt=XW1TQzoK_NL4SSJVYlj5W%P z>H!#{!L4|&Zm#wK6bTXhg@3HXT6bBLyUmPQYAQBUe<~390Gth5O_@^=ivP&qt_%9} z@%EjUAW5cnq|p=MzD?~M2BPr5;%EHH&_lOO)hT-Hkrb(fAas0aD!qQ=TGM|`$NhDm zXIK65up&9h^_sMbFDIuw#3_pk^-k0;l1ZEBAE_2UF$oJ7n7nXg>eum)w{Hv`z9+|m z9m+5DWTWHv1J62ap$_y$rh6~X`M`y>SDa?bKNraGszc$3`7 z@YILeR=r-3MpHBWIZopglTZz$ZkCKya>w zz_{GBk)&XqN&;WIMtKbQP^#_g-@=#|a3cFKcM1zL~XuM>2Z{9 zZsZJAUKf;Qt{yQwYhXFY7m`ov9tIE+qKzC5B1V=Y-BQ9L>yp4_N&^)~8HZ8_ z^8!xvaT-dkme$U{_w{76jRlYGU4P=mmk~b4>v;g$E2?LN$pVPi-u^_ydtapp*{i~bS>(O_;@$7Oc%Tr zWJM|wpIM6?CS2L8f=V>Y-cYFPK^v5EZaJLT=c+~HPko6VGS zDmQ~FCwUTh6{H1{VwC<*gA}UwDQyCzex25h&AJO$LAVjQUP_!_qqivsE!E%FWf4sg zZZt3zDLaImi;{gya=;S@428t1Y#azY@M-VUsCZGi@Hd$&ArU;j6!Q%7xf+#9EErdq-wnF$zDUMGbD4{ zS2cOiX091ZA(gku19vB+4DzS zzx!iW_{+t_@>c7Lxv0!^@KBG9DN)tSE2#dRosghaGh6n3c}G}epUuUnovUURkpfbf zc0H6Xb2PHV57~K}(rM7VQ6RgN;%#Cl2$A>9a0I+7K@w9gx}=!iz15dF?E3&gkAW zcQi+CrxGwG9@*jdLt#vb2YlyCgO*g?~A(og~s+Q!y8e^)>@s&Zd2 zmWm=Q)!lTzU6aH1--%WxkNUf%tJ48H^$&or4(G9TuDpNfWUl8*9xgR@u7}I-*kj58 z*Qa=4)gg~C`o(r`)hiLMPaj>s&P0u*%o9k55SkKLMiLD&SFxSPrmhn=7{Z7FpHO=1_$c2g+h0k~a^)==0s+gn+C+ z@Uu^+kb9T3c1a{2+x2XV%Jfq&4&d$UKVs|n zUARmrTWaS;&lRm3nd(z3c^3_AO*c z|J`sTb2NWOl)%HRoE}px$y4adN)bAqCoa7WG2-H2>jm-}Ogp}i-0VA~#2i0!<&p@{ zRNeIrj=Bmd83*I9X?QGP;wSy;B$lF|+x_~D?C!PkFzjZmC)!SCt=zdz&e9%((<$jy zHqWY~JECNYL(fZ|gM(wGyO;Y9`!ITDqhJ?3AO!LnoEX6DP%T7A(6v-6Zigs%q zIEMay0SmI~HF2|fmbAi+h?h9N@39sAuuQ>j_@4to*9Sx?2+ApT7Q_P>& z#7Ab^$!-)l70zas4__-l?;n7<(J)RnLc9-58`7jJvF7RsA|?NnU5dMY(y4hz0kE6W z(Rp^vnJVYx2Hsa?+U#RZk5mr#rm0iq(NF5!J_OOwp1*R;xlAH^S&AM2jf)8cTb*1Z zna8QN;oJVp<_f1Dqy&+D!BH{8=oKPns>~Vp_5exN4LD4J@v}Bk?4Bo6lHPsRpr}P> zvI={V%kb6P(eZ9?h*Z(1?LjJ2<1)B@0HJXo_a@MBc5) zNxU@0o8d%piOjZ*A?5m`lGD<2esA+3b23Wt>0~xG*KU~FzLMxmq6J$o3lM_C+|rwj zh{Z%Q`@k2~2~XKW-fG5r3yyiEf{q6pr(EXJ{zJc7R4bg#u`XpKd}7M_L`6VKWI?71 zqJFH7M`36gTZzsd9+lQhj+l*Pf~s|`%jvo$j!VVw)k*sONyL^NC?8M^tIDuNXJd29 zb(+9t9<%DVRxC8J6VRTZPISb~6x^Sz&D@B=g{b+uW1|9Z$k-8>MHfl^pSiqByV+tU zB(r1O%MI7jThb=^XVjhQV(^hVxVtr z#7^q35V7^noxDr8I#QgsBe^u#C}v(idngFA*e0wY=YytUByQ;Rc-vk)H_9U=k8AD_ zNu}n)_7#j#`}Sm)US#7{=RX_XrI-e5%)?Zp88m49c8@APuSiYHS?VnIKk+{1onwIIo zExSfY)J7~Izm_~N%O1!Up?vN30o{S#^aLp&+HKUfqzpbL@F$2{yQ4T3K#iH?&>Yb5 zrb9}j3B7q$TaeLG>jZD>$bj$z@OsWCgrzh^eK*W4IQ^v83Z#96BeXu>6IrddZ+2gH zOkKP@XyC8mNFP9LAOz!|ic9HNIx6d1EIPbG^s;8Mm<>$YsM^QLI^($EOT`nUbG+I9 z8iTp44SZr{(I}~<@F_DxuBm3H@My34;%A4PA2WII){#!XQMK_S`j{8?j znXAW3-V-dIMuHzL6*C}>O)NZy-2eP;W{0m5t4pI|#A)wsv9vc&u5PqH51dZO*)TR= z`S&i`RJYBmbLg@YQ4C|yDhFql0gwND2BD;;5#&fLqLA%fS?BzZG5pW^drsJse@L9! z37&-yo(5j>`~g4ls^Uxr*i&s&+MwrEL4qTc4tCecd9IY5P5(|xWGFelV3b(OH)*=j zms9P`7JWxSg+sSmaI;^!ZsPoT$!?O_6yEKCx54t3LVTL!)F6} znC*?9mF3ry(!aS^CB48{FzQ5#z0Zcf5r+k9f0SM;lr{>}(oU@KIvoK=^;j6US)pMC zp<&hbS;i z{>l9G&RiIMZNL5BkePwzCaYN!Lr*ucXwSXF?^bm0OK<9{>Xo1O|I-#-OXLqH}B=snY5A9zOl3 zX?(O_5P=MXC)5lJA6v~T?DW^cc)L#Vmq2fjKq99PM^(^g0sF=~dN$v9$7?ouWLmMM5B z*v13!K~;s;^@V~waS;Dl?Gty3EQZKH1Io&)OQ7&$rK^^mW-8tqxlM=;_n!wq@8ad@ zkM(avGWFRWLSKYx$9y=UZ+qBIjO&~|QEh+u_(2RulXl9p2_-vp8zK<6Q6v!E;Wu)K zp@`?L+p(RuGUWZXxo8x{r?3_c**;!*S5POnqs3SwNO!X}Gx=6`O$vHCJIl@cp?-DzyUUnSx%f^_Spt(#n{u)YFBG%aHP>IWrbY2{IV<0?wp5Ox^ut_^+GRbg(Gq_~o;4E-NJ- zFT(HNo7sg52VCx$OO{nm-4~=jQ)%D{*$YE1>bPcboBJBTUO`bd@&&7%s@e@MQ>qP7 z1|^+RiZ|ier``AZUXtwFm^fo5f#48Bxo$s-Vn61zi9lLkH+V^gux-%v_*8NB$-D&!Um)Fje7ANa#ZRap5$*w zqCmlTPs7(-THPvJk^^?SzDy3silWwvH0hAdXjm+nM9$rQ;mzkFDIpMN++!&M%wuWe z>*~rn%BJf-A*SKF{csf3`xqO4MwIs2#f^#OzW}tTX4%HZiwF@(6#f8w-2XVYVlr*{ z{EIC+{u#B+hyS*t@FgfYE1NGul##{+h{rH(n?SZRjA@Ec+<(ZhY+^Tg*}w&I0ZHb{ z5A;kL(ar48!DhS0>u>)Vg$bsCNv9gcufxw?nwAqGc#%S?ST2#}GVf{*!q4_?eu9%7 za|*2DeTo?JuIiVkGT@~ShP3rDtZ;JC0sLx`pg0&w#7B(EGf(Aj7tlzk;&ua(dRmJ_ zVyxgw$&R&|85op<$L{C=8Du`-wp><8#WeZ4_OLEeB;$j;DK4U|ayX=1Sdj`kYP}Xn z-RYKg)%jE#e<9&rNHnFKmIH1pgL`WQF|9d_7DbLfhOvlv7mSiZH9JISEwkCcup09Q zEbvfEGpJP}ns0N;4h_P$Q}hbymv{RqbN_{X;Z?mBruK8U_)})@#e_q#OO$Awnz%8E z|5xSFUacl2{g#=x* z99xR|e5m_`$vdU%H6y*A<8n=$GZ&kC)l(dw&Xa;$eO}|AhetP&97Y7P#qyawQ2N9z z$*Lcn*0M=PqD-p3|#ZSXe&FTTrzo(#ccbF2}7ir$B&wDG60hG@$l% z)cNUY$53I9rI!K-Zi|*P%K?$g)j@X7zgZl1+07$ZX5t@!JM8zdg%K-XjH#-qdJimj z>2NjKy<6I`hT-L_XR6yTj|-wEVBao4k<9}qC?JvH zX8dOtbe~`C5)NywwXvH1r0nowMTBe5_H}bReGq#4obOg~Z*X(8I7n_JXh6zg<}3NfK8{TBi9@l)U#rDS zIDkB4DOQvtp(ejMR+yZUbZAMHhjct#@KO0yT}yDRO%3kQT6S{id2YwVz4RA<+!?XL zmdXrH$lf5v>OODwro_Wrw8ZLPO-1h~$^5bAo8`-1#nPG=<_y;i_iHNLT>;&lU61t? z>YvWJzxm>#6vpP|4Z5?FDr^zdo2@%iD*ardFh2>!A3B4Agf1{>3)xO@n!qogtblZH z5OWD$Q&!JiOY?(|-9(rq9ke&|bAL`$cnH9-R*oDJ;_`2J& z^-H2T3z_sV239jxjyf{^ZE(OAk3Kt$JtwVa3>*PIJ1lcWap!lxTY};=!sBQ(4)Ee) zH;qHYPFt)Cy<(geyFCExUnBi7AkDt>!r<>hE^xOC`h)ap*`NORN}udl^@u2=d?tEg zW&2c&!ot`0h}z+&>APSBCB~r~48(k*zN7WoJ7>^wH@E!T>_@Pc*i9RRxW-hv*OE8Z zNn@%vL!xi1E~z^SvFFZlQZt`w)O}q`53CZTZO_}7o|vlU>6&iou&lW6oi$CX^H~7E zk!s_d5zXG~>Ft({RJ9?k#u{cUaaUdoDp`EAQ-rMqXQk92QL_bU&{E9bY<>PuqKsYk zC3~lm#;rI<(q}{FKwj8^b+4y46F2{zYokr3b#q@ScJ&tlT`5}TGL;#fxvQ9&7_UiY z*M>(K|J)|FUN~08!sOm-3%5e9tgyY=9w|k@rzjn27-2BLmB^_T*6^Y?bDrK4_DP(T zH+9?#w9%*Spo2Sg)(F?xcbpDwv)VGTgCSzc!>1pBcYeY7&)x@+ule&tmc`&(uL@Fr zY(sJ;=5D+ewZzzM4bEd-Rr?p}Re}$KqzfgeN_x1eEZYByp>f)?^u;V~9=QUYNQWh*!b;}73X}EedAd)RPh{;F1t8+A3wdHk2B<&)ITwMMTtFB9D zmy7J0Do$UCHiyte1Lz|F1_MuwWBA#0_(wEC3#jJm;P>*WzT4LX@$0Cm&f?#+GMbM( zkwi@Wc7l&vtRsGdo*p3TE7^f)`=?~){@K4)^#yS?&TBDBJI(dN^8|hxhB5}j8ICx3 zAe*7hMD5@a zqC*g~7=n&xEa`ZqJ2}UIlm1qaoonH7(DO0j_^iwg7d$%0^bfZQVouJyseL_hndBN{ z(gbUl74w%Rfi^^qP^{lEs;hRNbZrJ(MD=u>biA0K12!=}TIxdcgMuxe`0sIs z_6$z1SCJJioRVX>$)Q-uXTAJ-q4lFwpE2fBWhmD}BjOJ^W`}10B?FXmjv*@S7^g88 ztU^-T_W-2Pks-03#!HcD|EtID`Yb6^PH58IH|j%4r~{l9FMWyG6?_-4B?uBT8jvh_?7JenE;TFs<}hh} z-L>mnop+cJI~!oI^~ULo(Cm9n?A7W0&C=SsmM}`do)z|MC4{iE%E05NY>~PN?JrO_ z3@5!^F5;K3&aujWmQkytnP6sTvqMnm0N!cckzHpR@^vvh{JqgyO-UQMt+jv(4HXnaDF?>;UG%G8Q876l0@kJZb1*XDNZ6P+)03=lftmNV;;BSQ zwUZ{KI_&edIe|vhY+E)g5vNFKIFDae7A9 z%az%vqkJjMp6$GYynk{_XoG)12OtvZysf-;OKQiAD zqO-?@lyAfB$?J}JLAf@Bmc^AlEZ*NUqWw%L%^syS7{i7fOh=&=gJ+`MlcZ>`v6;sG zGXo_-V;=Iirm74^|Ixo>7vF5C-zTJPMrpU&1u>73D{DFATu|N@PJ0M_t|Er;S$g!~ zJxXcnb6~#|PF529Rz93SsNe*Tcpwx~3W7O_SZl}H z6IN<=b22tB=Qk@pfGto81qm?phVzX56v?7aBL7Ra_mWhs zyqh3=$2ks!X3sk`gT`T!S?fckj#e+-Kv~PH`ESMg>N(VN&eWNl-Ck>*V^9v^H?W|QII(^0qGix^dwXlOTk0q;5;u#BRkUA#tBDC(lSKXEd{3uE&=;H}ecL-lnBGo<)w>r-OO*YEX!| zM~Cn8u-^K=(CUI4c^}HS4Y{S~1uyiXcCRQz??|gH->wxJ5R-vY{xAfHmQKF#GiU=Z zAR3SGhooA^=mWuDtm%t|EcIo}VB1oq(B;7vN{^ZM$^B-O%?*v|XGB9ck%IS4gK@uz zkAp=M`9IIRuN0?=EYg-2UUAq8N@=?>aaw08ae)$Z!`ihy7cvGZ7<+CEWWw{s-Y&@l z-jWD#giJMp3pI^hl=7U2u^WvWu@W+;_>=#F`nNR`Ji_FH|Mq7B%+mKW77Aya!1F!@ zk5<;%$z*)#^ACnV9x^2`wvW`<#kRHQ?|px}{jO2MKqBSZ&1ZMdqf^e~zjiXsmUrV! z4?(NQ2{cy9%6#N)NDm!bf=a--7CYO4tEqau*KfXftT~uEB)k!yek9-R9!EwLSSfn_ z`N!Bq6hokQ%Y0-k%>PEa`LrC={g@W6D z+Ffhx*>cSB3n~y$vNw!(d=tl@;#BXKup_H8dqu|B>Ewf~VNn461w|k7XiR*gpA)Tt zdyL}`dW9rXD+A@AyzT7m_*IVM^9Hrf_d3*>F^ zZA+^E;;&AHNm;rs_~Wy+XkYOPwp61WNEGJ@Rm{J7{cJtz_2aF5c!zUdtR#D^v{ zS1-}h?w5+(_;wNN5Njv>6vt(h`-ti^NHFQgmA*q-o^IgzhZdvI%*XR~T_f%ms7Pfh zE_?&1^l{$oql%{9*Wu6MtwdF+<*GImtEeujuRAa4WTu+WK6qbnV{qJSf-Qp)c#<1q zvGp@4Qx*#e!=neZQ;jqqQOcIc4lQ{pWw6a8fano~q-`bojlE^`z-@CU>x;umX3l}XT*ry);ql<6oP4>t)`G`lYZ0>ze*5+q{p|G{b65Kyo!#77sw8@p} zQliqLZ=2^w{#oiUMSfih#epT-xYm!x@+C)eeqC$qb>A@Yl8?IdHpC(%+Szm&te!{` z8ZIw=RrG9Bxof1c&pNW!7y)Ll^d?gnI`9ge_06Jk=N?(@XXLxv_xIb0OeOrb1J?g^ zO2fEG%i!Jr>h&`AeO}6Ymq=DD8>oGxItDLpy_2Hy?B)lezszw7uPojNr3JxD^uEu~ zGOlf_f=Pa!{WP###__W{zauA7g25Yu@fs;_)Iw~ZKi7plTk1K3x0Xng%P<6P#arw? zQ=>sK+4K({iWE1|8(mFcZujJ!uA})LYCg9=sSwh0|lf@Gvo+2b#Sl$)RZwX7?mt#Q< zd0YpQHSo!UmT^S$r<@#febt5748i*e&5Y0uIFIVi_t$8l(6P*FJ}Z6t?J#;Hq!AU~ zwEVwBq{v24jaER8>fuX=xGb7kLeCD3$rJYup|GP7b)|$TyeFGZY zCVBy$!{88;gxAd`Zk0m@EW<8LO^e;!Nc_bp5;@Ty#+k3C?gLauG2>fqs2Ds9*KLKiCC4`EXkUOpj zBXN~79U*QHLH;9Vas{_68NB&dt_`cLz zGefeo(E?<(M}2Q&VhGXvN&1JK94|d?jEeU#r+l&Y#RfPPyp`Be>wB@4!9B|~UwtX$ zF;db006fxYCs!loK+qa$V1od5wL9Z31iCXdGw+?L$-G0JmQ7R%j$DhaW`>c&AA-kR zbql!lmAYSDGRW6CT4=;$4c}11_}jRBMJV*f`6S(%^WCXUuEa)C5t^8HXH+T23Y~)qr^Wg4Q}_WA+{pH`|e||YpoXs2&d<5)4BvqAMJ5QOtuwfYM&t-f8%4L z7xB?}&bHe_4?yrA?pbfE=afu4qEn&=dnCI!-@+HVYX|R*zPL1&|85Z{Jwmqtv`f9t zf*D8xhJxuT6N7`vFeVL8^-;LB57bJQ?N5(O;k~)u_F~d0-m4a01L4MM-0Z6g8kPpc zqmh&gqnZ+)^Sq;j>7B-0DU3@k&tjXuHQyIuszqCLS{I_5RVrlWXPEM$IPXfUQZ;!W z*m~S1xGsI3p>H$rN13_WvFXVdwV90CP$iECpzOIXxs6B;fCO(W^&KY$P@0ze{r&M` zfiMfajvcAMkU~yGyVirItTGj!aouTscSgtIV1EY2E!rr&dbtwQ9J|!`(}fc53kw@! zwM^_u6V%KI`uF(&X|*@%kU`8yqYbCWxHRK{*K9jAq#PcAoB(*Ay Date: Sat, 23 Mar 2019 16:50:09 +0900 Subject: [PATCH 07/15] modify --- predict.py | 13 +++++++------ utils.py | 4 +++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/predict.py b/predict.py index 1d150b8..7d2d0ff 100644 --- a/predict.py +++ b/predict.py @@ -36,7 +36,7 @@ def main(args): # ============================================= with open(args.config, 'r') as f: config = json.load(f) - mpv = torch.tensor(config['mpv']).view(3,1,1) + mpv = torch.tensor(config['mpv']).view(1,3,1,1) model = CompletionNetwork() if config['data_parallel']: model = torch.nn.DataParallel(model) @@ -54,8 +54,8 @@ def main(args): x = torch.unsqueeze(x, dim=0) # create mask - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=(1, 1, x.shape[2], x.shape[3]), hole_size=( (args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h), @@ -65,10 +65,11 @@ def main(args): # inpaint with torch.no_grad(): - input = x - x * msk + mpv * msk + x_mask = x - x * mask + mpv * mask + input = torch.cat((x_mask, mask), dim=1) output = model(input) - inpainted = poisson_blend(x, output, msk) - imgs = torch.cat((x, input, inpainted), dim=0) + inpainted = poisson_blend(x, output, mask) + imgs = torch.cat((x, x_mask, inpainted), dim=0) save_image(imgs, args.output_img, nrow=3) print('output img was saved as %s.' % args.output_img) diff --git a/utils.py b/utils.py index 744b250..2fafc3c 100644 --- a/utils.py +++ b/utils.py @@ -141,7 +141,9 @@ def poisson_blend(input, output, mask): dstimg = np.array(dstimg)[:, :, [2, 1, 0]] srcimg = transforms.functional.to_pil_image(output[i].cpu()) srcimg = np.array(srcimg)[:, :, [2, 1, 0]] - msk = transforms.functional.to_pil_image(mask[i].cpu()) + msk = mask[i].cpu() + msk = torch.((msk,msk,msk), dim=0) # convert to 3-channel format + msk = transforms.functional.to_pil_image(msk) msk = np.array(msk)[:, :, [2, 1, 0]] # compute mask's center xs, ys = [], [] From 1dace6b70904df45074c56bcfd0da48615c1fa40 Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 16:50:16 +0900 Subject: [PATCH 08/15] wrote docs --- utils.py | 67 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/utils.py b/utils.py index 2fafc3c..b792f46 100644 --- a/utils.py +++ b/utils.py @@ -12,20 +12,20 @@ def gen_input_mask( """ * inputs: - shape (sequence, required): - Shape of output mask. - A 4D tuple (samples, c, h, w) is assumed. + Shape of a mask tensor to be generated. + A sequence of length 4 (N, C, H, W) is assumed. - hole_size (sequence or int, required): Size of holes created in a mask. - If a sequence of length 4 provided, - holes of size (w, h) = ( + If a sequence of length 4 is provided, + holes of size (W, H) = ( hole_size[0][0] <= hole_size[0][1], hole_size[1][0] <= hole_size[1][1], ) are generated. - All the pixel values within holes are filled with 1. + All the pixel values within holes are filled with 1.0. - hole_area (sequence, optional): This argument constraints the area where holes are generated. - hole_area[0] is the left corner (x, y) of the area, - while hole_area[1] is its width and height (w, h). + hole_area[0] is the left corner (X, Y) of the area, + while hole_area[1] is its width and height (W, H). This area is used as the input region of Local discriminator. The default value is None. - max_holes (int, optional): @@ -33,13 +33,12 @@ def gen_input_mask( The number of holes is randomly chosen from [1, max_holes]. The default value is 1. * returns: - Input mask tensor with holes. - All the pixel values within holes are filled with 1, - while the other pixel values are 0. + A mask tensor of shape [N, C, H, W] with holes. + All the pixel values within holes are filled with 1.0, + while the other pixel values are zeros. """ mask = torch.zeros(shape) bsize, _, mask_h, mask_w = mask.shape - masks = [] for i in range(bsize): n_holes = random.choice(list(range(1, max_holes+1))) for j in range(n_holes): @@ -72,11 +71,13 @@ def gen_hole_area(size, mask_size): """ * inputs: - size (sequence, required) - Size (w, h) of hole area. + A sequence of length 2 (W, H) is assumed. + (W, H) is the size of hole area. - mask_size (sequence, required) - Size (w, h) of input mask. + A sequence of length 2 (W, H) is assumed. + (W, H) is the size of input mask. * returns: - A sequence which is used for the input argument 'hole_area' of function 'gen_input_mask'. + A sequence used for the input argument 'hole_area' for function 'gen_input_mask'. """ mask_w, mask_h = mask_size harea_w, harea_h = size @@ -89,13 +90,13 @@ def crop(x, area): """ * inputs: - x (torch.Tensor, required) - A pytorch 4D tensor (samples, c, h, w). + A torch tensor of shape (N, C, H, W) is assumed. - area (sequence, required) - A sequence of length 2 ((x_min, y_min), (w, h)). - sequence[0] is the left corner of the area to be cropped. - sequence[1] is its width and height. + A sequence of length 2 ((X, Y), (W, H)) is assumed. + sequence[0] (X, Y) is the left corner of an area to be cropped. + sequence[1] (W, H) is its width and height. * returns: - A pytorch tensor cropped in the specified area. + A torch tensor of shape (N, C, H, W) cropped in the specified area. """ xmin, ymin = area[0] w, h = area[1] @@ -121,29 +122,31 @@ def sample_random_batch(dataset, batch_size=32): return torch.cat(batch, dim=0) -def poisson_blend(input, output, mask): +def poisson_blend(x, output, mask): """ * inputs: - - input (torch.Tensor, required) - Input image tensor of Completion Network. + - x (torch.Tensor, required) + Input image tensor of shape (N, 3, H, W). - output (torch.Tensor, required) - Output tensor of Completion Network. + Output tensor from Completion Network of shape (N, 3, H, W). - mask (torch.Tensor, required) - Input mask tensor of Completion Network. + Input mask tensor of shape (N, 1, H, W). * returns: - Image tensor inpainted with poisson image editing method. + An image tensor of shape (N, 3, H, W) inpainted + using poisson image editing method. """ - input, output, mask = input.clone(), output.clone(), mask.clone() - num_samples = input.shape[0] + x = x.clone().cpu() + output = output.clone().cpu() + mask = mask.clone().cpu() + mask = torch.cat((mask,mask,mask), dim=1) # convert to 3-channel format + num_samples = x.shape[0] ret = [] for i in range(num_samples): - dstimg = transforms.functional.to_pil_image(input[i].cpu()) + dstimg = transforms.functional.to_pil_image(x[i]) dstimg = np.array(dstimg)[:, :, [2, 1, 0]] - srcimg = transforms.functional.to_pil_image(output[i].cpu()) + srcimg = transforms.functional.to_pil_image(output[i]) srcimg = np.array(srcimg)[:, :, [2, 1, 0]] - msk = mask[i].cpu() - msk = torch.((msk,msk,msk), dim=0) # convert to 3-channel format - msk = transforms.functional.to_pil_image(msk) + msk = transforms.functional.to_pil_image(mask[i]) msk = np.array(msk)[:, :, [2, 1, 0]] # compute mask's center xs, ys = [], [] From b1980bf76ebb9b6c3eef28dde2a367568aa1ea00 Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 16:52:16 +0900 Subject: [PATCH 09/15] removed unused variables --- utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.py b/utils.py index b792f46..9d7c0c2 100644 --- a/utils.py +++ b/utils.py @@ -41,7 +41,7 @@ def gen_input_mask( bsize, _, mask_h, mask_w = mask.shape for i in range(bsize): n_holes = random.choice(list(range(1, max_holes+1))) - for j in range(n_holes): + for _ in range(n_holes): # choose patch width if isinstance(hole_size[0], tuple) and len(hole_size[0]) == 2: hole_w = random.randint(hole_size[0][0], hole_size[0][1]) @@ -115,7 +115,7 @@ def sample_random_batch(dataset, batch_size=32): """ num_samples = len(dataset) batch = [] - for i in range(min(batch_size, num_samples)): + for _ in range(min(batch_size, num_samples)): index = random.choice(range(0, num_samples)) x = torch.unsqueeze(dataset[index], dim=0) batch.append(x) From e98dab33c484888ada71a6cd090b8805be4777af Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 17:15:00 +0900 Subject: [PATCH 10/15] apply changes of utils.py and models.py --- train.py | 63 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/train.py b/train.py index 6507c5a..45ef4f1 100644 --- a/train.py +++ b/train.py @@ -110,7 +110,7 @@ def main(args): json.dump(args_dict, f) # make mpv & alpha tensor - mpv = torch.tensor(mpv.astype(np.float32).reshape(3, 1, 1)).to(gpu) + mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu) alpha = torch.tensor(args.alpha).to(gpu) @@ -136,14 +136,16 @@ def main(args): # forward x = x.to(gpu) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) - output = model_cn(x - x * msk + mpv * msk) - loss = completion_network_loss(x, output, msk) + x_mask = x - x * mask + mpv * mask + input = torch.cat((x_mask, mask)) + output = model_cn(input) + loss = completion_network_loss(x, output, mask) # backward loss.backward() @@ -162,16 +164,17 @@ def main(args): if pbar.n % args.snaperiod_1 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) - input = x - x * msk + mpv * msk + x_mask = x - x * mask + mpv * mask + input = torch.cat((x_mask, mask)) output = model_cn(input) - completed = poisson_blend(x, output, msk) - imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) + completed = poisson_blend(x, output, mask) + imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_1', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_1', 'model_cn_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) @@ -210,14 +213,16 @@ def main(args): # fake forward x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) - output_cn = model_cn(x - x * msk + mpv * msk) + x_mask = x - x * mask + mpv * mask + input_cn = torch.cat((x_mask, mask)) + output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake.to(gpu), input_gd_fake.to(gpu))) @@ -251,16 +256,17 @@ def main(args): if pbar.n % args.snaperiod_2 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) - input = x - x * msk + mpv * msk + x_mask = x - x * mask + mpv * mask + input = torch.cat((x_mask, mask)) output = model_cn(input) - completed = poisson_blend(x, output, msk) - imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) + completed = poisson_blend(x, output, mask) + imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_2', 'step%d.png' % pbar.n) model_cd_path = os.path.join(args.result_dir, 'phase_2', 'model_cd_step%d' % pbar.n) save_image(imgs, imgpath, nrow=len(x)) @@ -283,8 +289,8 @@ def main(args): # forward model_cd x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, @@ -292,7 +298,9 @@ def main(args): # fake forward fake = torch.zeros((len(x), 1)).to(gpu) - output_cn = model_cn(x - x * msk + mpv * msk) + x_mask = x - x * mask + mpv * mask + input_cn = torch.cat((x_mask, mask)) + output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, input_gd_fake)) @@ -313,7 +321,7 @@ def main(args): loss_cd.backward() # forward model_cn - loss_cn_1 = completion_network_loss(x, output_cn, msk) + loss_cn_1 = completion_network_loss(x, output_cn, mask) input_gd_fake = output_cn input_ld_fake = crop(input_gd_fake, hole_area_fake) output_fake = model_cd((input_ld_fake, (input_gd_fake))) @@ -341,16 +349,17 @@ def main(args): if pbar.n % args.snaperiod_3 == 0: with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) - msk = gen_input_mask( - shape=x.shape, + mask = gen_input_mask( + shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, ).to(gpu) - input = x - x * msk + mpv * msk + x_mask = x - x * mask + mpv * mask + input = torch.cat((x_mask, mask)) output = model_cn(input) - completed = poisson_blend(x, output, msk) - imgs = torch.cat((x.cpu(), input.cpu(), completed.cpu()), dim=0) + completed = poisson_blend(x, output, mask) + imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) imgpath = os.path.join(args.result_dir, 'phase_3', 'step%d.png' % pbar.n) model_cn_path = os.path.join(args.result_dir, 'phase_3', 'model_cn_step%d' % pbar.n) model_cd_path = os.path.join(args.result_dir, 'phase_3', 'model_cd_step%d' % pbar.n) From 4c2052ee42b3ebff52b3198723d0323d57b5296a Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 17:16:40 +0900 Subject: [PATCH 11/15] fixed bugs --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 45ef4f1..aad1f4f 100644 --- a/train.py +++ b/train.py @@ -165,7 +165,7 @@ def main(args): with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( - shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, @@ -214,7 +214,7 @@ def main(args): x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( - shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, @@ -257,7 +257,7 @@ def main(args): with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( - shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, From d168e5f67d0367e19a0e88784bb030a978613ac6 Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 17:17:28 +0900 Subject: [PATCH 12/15] fixed bugs --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index aad1f4f..5b62578 100644 --- a/train.py +++ b/train.py @@ -290,7 +290,7 @@ def main(args): x = x.to(gpu) hole_area_fake = gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])) mask = gen_input_mask( - shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=hole_area_fake, max_holes=args.max_holes, @@ -350,7 +350,7 @@ def main(args): with torch.no_grad(): x = sample_random_batch(test_dset, batch_size=args.num_test_completions).to(gpu) mask = gen_input_mask( - shape=shape=(x.shape[0], 1, x.shape[2], x.shape[3]), + shape=(x.shape[0], 1, x.shape[2], x.shape[3]), hole_size=((args.hole_min_w, args.hole_max_w), (args.hole_min_h, args.hole_max_h)), hole_area=gen_hole_area((args.ld_input_size, args.ld_input_size), (x.shape[3], x.shape[2])), max_holes=args.max_holes, From d731d6220d332ef62fcecf878ab6510a911ac4dd Mon Sep 17 00:00:00 2001 From: otenim Date: Sat, 23 Mar 2019 17:26:13 +0900 Subject: [PATCH 13/15] fixed bugs --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 5b62578..4a89f30 100644 --- a/train.py +++ b/train.py @@ -143,7 +143,7 @@ def main(args): max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask - input = torch.cat((x_mask, mask)) + input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) loss = completion_network_loss(x, output, mask) @@ -171,7 +171,7 @@ def main(args): max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask - input = torch.cat((x_mask, mask)) + input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) @@ -221,7 +221,7 @@ def main(args): ).to(gpu) fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask - input_cn = torch.cat((x_mask, mask)) + input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) @@ -263,7 +263,7 @@ def main(args): max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask - input = torch.cat((x_mask, mask)) + input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) @@ -299,7 +299,7 @@ def main(args): # fake forward fake = torch.zeros((len(x), 1)).to(gpu) x_mask = x - x * mask + mpv * mask - input_cn = torch.cat((x_mask, mask)) + input_cn = torch.cat((x_mask, mask), dim=1) output_cn = model_cn(input_cn) input_gd_fake = output_cn.detach() input_ld_fake = crop(input_gd_fake, hole_area_fake) @@ -356,7 +356,7 @@ def main(args): max_holes=args.max_holes, ).to(gpu) x_mask = x - x * mask + mpv * mask - input = torch.cat((x_mask, mask)) + input = torch.cat((x_mask, mask), dim=1) output = model_cn(input) completed = poisson_blend(x, output, mask) imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0) From 60f4fb2e134a6cc9a186d395feacaace7ade1e0a Mon Sep 17 00:00:00 2001 From: otenim Date: Mon, 25 Mar 2019 04:32:01 +0900 Subject: [PATCH 14/15] bug fixed --- train.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 4a89f30..97c9b5e 100644 --- a/train.py +++ b/train.py @@ -320,6 +320,14 @@ def main(args): # backward model_cd loss_cd.backward() + cnt_bdivs += 1 + print(cnt_bdivs) + if cnt_bdivs >= args.bdivs: + # optimize + opt_cd.step() + # clear grads + opt_cd.zero_grad() + # forward model_cn loss_cn_1 = completion_network_loss(x, output_cn, mask) input_gd_fake = output_cn @@ -332,15 +340,12 @@ def main(args): # backward model_cn loss_cn.backward() - cnt_bdivs += 1 if cnt_bdivs >= args.bdivs: cnt_bdivs = 0 # optimize opt_cn.step() - opt_cn.step() # clear grads - opt_cd.zero_grad() opt_cn.zero_grad() # update progbar pbar.set_description('phase 3 | train loss (cd): %.5f (cn): %.5f' % (loss_cd.cpu(), loss_cn.cpu())) From 108050e43f547e2cb6d4f25094e84adc962815a9 Mon Sep 17 00:00:00 2001 From: otenim Date: Mon, 25 Mar 2019 05:00:36 +0900 Subject: [PATCH 15/15] remove print lines --- train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train.py b/train.py index 97c9b5e..505c5fa 100644 --- a/train.py +++ b/train.py @@ -321,7 +321,6 @@ def main(args): loss_cd.backward() cnt_bdivs += 1 - print(cnt_bdivs) if cnt_bdivs >= args.bdivs: # optimize opt_cd.step()