diff --git a/.gitignore b/.gitignore
index 6522d78b5..4c40f1924 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,3 +59,4 @@ docs/_build/
 
 # PyBuilder
 target/
+*.dat
diff --git a/examples/ResNet/README.md b/examples/ResNet/README.md
index 86629478b..3b5324af1 100644
--- a/examples/ResNet/README.md
+++ b/examples/ResNet/README.md
@@ -4,4 +4,7 @@ Implements the paper "Deep Residual Learning for Image Recognition", [http://arxiv.org/abs/1512.03385](http://arxiv.org/abs/1512.03385)
 with the variants proposed in "Identity Mappings in Deep Residual Networks", [https://arxiv.org/abs/1603.05027](https://arxiv.org/abs/1603.05027).
 
+The train error shown here is a moving average of the error rate of each batch in training.
+The validation error here is computed on the test set.
+
 ![cifar10](https://github.com/ppwwyyxx/tensorpack/raw/master/examples/ResNet/cifar10-resnet.png)
diff --git a/requirements.txt b/requirements.txt
index 8a054c562..7986dfc9f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ pillow
 scipy
 tqdm
 h5py
+nltk
diff --git a/tensorpack/dataflow/dataset/visualqa.py b/tensorpack/dataflow/dataset/visualqa.py
new file mode 100644
index 000000000..98574606c
--- /dev/null
+++ b/tensorpack/dataflow/dataset/visualqa.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: visualqa.py
+# Author: Yuxin Wu
+
+from ..base import DataFlow
+from six.moves import zip, map
+from collections import Counter
+import json
+
+__all__ = ['VisualQA']
+
+# TODO shuffle
+class VisualQA(DataFlow):
+    """
+    Visual QA dataset. See http://visualqa.org/
+    Simply read q/a json file and produce q/a pairs in their original format.
+    """
+    def __init__(self, question_file, annotation_file):
+        """
+        :param question_file: path to the VQA questions json file
+        :param annotation_file: path to the VQA annotations json file
+        """
+        # use context managers so the file handles are closed promptly
+        with open(question_file) as f:
+            qobj = json.load(f)
+        self.task_type = qobj['task_type']
+        self.questions = qobj['questions']
+        self._size = len(self.questions)
+
+        with open(annotation_file) as f:
+            aobj = json.load(f)
+        self.anno = aobj['annotations']
+        assert len(self.anno) == len(self.questions), \
+            "{}!={}".format(len(self.anno), len(self.questions))
+        self._clean()
+
+    def _clean(self):
+        # per-answer ids carry no information for q/a pairs; drop them
+        for a in self.anno:
+            for aa in a['answers']:
+                del aa['answer_id']
+
+    def size(self):
+        return self._size
+
+    def get_data(self):
+        for q, a in zip(self.questions, self.anno):
+            assert q['question_id'] == a['question_id']
+            yield [q, a]
+
+    def get_common_answer(self, n):
+        """ Get the n most common answers (could be phrases) """
+        cnt = Counter(a['multiple_choice_answer'] for a in self.anno)
+        return [k[0] for k in cnt.most_common(n)]
+
+    def get_common_question_words(self, n):
+        """
+        Get the n most common words in questions
+        """
+        from nltk.tokenize import word_tokenize  # will need to download 'punkt'
+        cnt = Counter()
+        for q in self.questions:
+            cnt.update(word_tokenize(q['question'].lower()))
+        del cnt['?']  # probably don't need this
+        ret = cnt.most_common(n)
+        return [k[0] for k in ret]
+
+if __name__ == '__main__':
+    vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
+        '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
+    for k in vqa.get_data():
+        #print json.dumps(k)
+        break
+    vqa.get_common_question_words(100)
+    #from IPython import embed; embed()