Skip to content

Commit

Permalink
add xception support
Browse files Browse the repository at this point in the history
  • Loading branch information
weiaicunzai committed Nov 8, 2018
1 parent 9083877 commit 58be5bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ $ cd pytorch-cifar100
```

### 2. change cifar100 dataset path in conf/global_settings.py
```CIFAR100_PATH``` is the path to cifar100 dataset, you can download cifar100 by clicking [here](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz), or download from the offical website [here](https://www.cs.toronto.edu/~kriz/cifar.html). Noet that please download the python version cifar100 dataset.
```CIFAR100_PATH``` is the path to cifar100 dataset, you can download cifar100 by clicking [here](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz), or download from the offical website [here](https://www.cs.toronto.edu/~kriz/cifar.html). Note that please download the python version cifar100 dataset.

### 3. run tensorbard
```bash
Expand All @@ -31,9 +31,17 @@ You need to specify the net you want to train using arg -net
```bash
$ python train.py -net vgg16
```
the supported net args are:
The supported net args are:
```
vgg16, densenet121, densenet161, densenet201, googlenet, inceptionv3, inceptionv4
vgg16
densenet121
densenet161
densenet201
googlenet
inceptionv3
inceptionv4
xception
resnet101
```
Normally, the weights file with the best accuracy would be written to the disk(default in checkpoint folder).

Expand Down
6 changes: 6 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
elif args.net == 'inceptionv4':
from models.inceptionv4 import inceptionv4
net = inceptionv4().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'xception':
from models.xception import xception
net = xception().cuda()
else:
print('the network name you have entered is not supported yet')

Expand Down
10 changes: 9 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import os
import sys
import argparse
from datetime import datetime

Expand Down Expand Up @@ -155,9 +156,16 @@ def main(net_name, checkpoint_path, epochs, milestones):
elif args.net == 'inceptionv4':
from models.inceptionv4 import inceptionv4
net = inceptionv4().cuda()
elif args.net == 'resnet101':
from models.resnet import resnet101
net = resnet101().cuda()
elif args.net == 'xception':
from models.xception import xception
net = xception().cuda()
else:
print('the network name you have entered is not supported yet')

sys.exit()

#data preprocessing:
transform_train = transforms.Compose([
transforms.ToPILImage(),
Expand Down

0 comments on commit 58be5bd

Please sign in to comment.