Skip to content

Commit

Permalink
add resent18, resnet34, resnet50, resnet152 support
Browse files Browse the repository at this point in the history
  • Loading branch information
weiaicunzai committed Nov 8, 2018
1 parent 58be5bd commit d90059a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ googlenet
inceptionv3
inceptionv4
xception
resnet18
resnet34
resnet50
resnet101
resnet150
```
Normally, the weights file with the best accuracy would be written to the disk(default in checkpoint folder).

Expand Down
18 changes: 15 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,24 @@
elif args.net == 'inceptionv4':
from models.inceptionv4 import inceptionv4
net = inceptionv4().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'xception':
from models.xception import xception
net = xception().cuda()
elif args.net == 'resnet18':
from models.resnet import resnet18
net = resnet18().cuda()
elif args.net == 'resnet34':
from models.resnet import resnet34
net = resnet34().cuda()
elif args.net == 'resnet50':
from models.resnet import resnet50
net = resnet50().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'resnet152':
from models.resnet import resnet152
net = resnet152().cuda()
else:
print('the network name you have entered is not supported yet')

Expand Down
18 changes: 15 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,24 @@ def main(net_name, checkpoint_path, epochs, milestones):
elif args.net == 'inceptionv4':
from models.inceptionv4 import inceptionv4
net = inceptionv4().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'xception':
from models.xception import xception
net = xception().cuda()
elif args.net == 'resnet18':
from models.resnet import resnet18
net = resnet18().cuda()
elif args.net == 'resnet34':
from models.resnet import resnet34
net = resnet34().cuda()
elif args.net == 'resnet50':
from models.resnet import resnet50
net = resnet50().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'resnet152':
from models.resnet import resnet152
net = resnet152().cuda()
else:
print('the network name you have entered is not supported yet')
sys.exit()
Expand Down

0 comments on commit d90059a

Please sign in to comment.