基于paddle框架的Show, Attend and Tell: Neural Image Caption Generation with Visual Attention实现
本项目使用paddle框架复现Show, Attend and Tell模型。该论文首次将注意力机制引入到image captioning
任务中,使得模型在生成不同单词的过程中,能够关注图像的不同区域,取得了不错的效果。
注: AI Studio项目地址: https://aistudio.baidu.com/aistudio/projectdetail/2288384.
您可以使用AI Studio平台在线运行该项目!
论文:
- [1] K. Xu, J. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhutdinov, R. Zemel, Y. Bengio, "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention", ICML, 2015.
参考项目:
- a-PyTorch-Tutorial-to-Image-Captioning [Pytorch实现]
所有指标均为模型在Flickr8K的测试集评估而得
指标 | BlEU-1 | BlEU-2 | BlEU-3 | BlEU-4 |
---|---|---|---|---|
原论文 | 0.670 | 0.457 | 0.314 | 0.213 |
复现精度 | 0.677 | 0.494 | 0.350 | 0.243 |
本项目所使用的数据集为Flickr8K。该数据集共包含8000张图像,每张图像对应5个标题。训练集、验证集和测试集分别为6000、1000、1000张图像及其对应的标题(我们提供了脚本下载该数据集的标题以及图像特征,见download_dataset.sh)。
-
硬件:CPU、GPU
-
软件:
- Python 3.8
- Java 1.8.0
- PaddlePaddle == 2.1.0
# clone this repo
git clone https://github.com/fuqianya/show-attend-and-tell-paddle.git
cd show-attend-and-tell-paddle
pip install -r requirements.txt
# 下载数据集
bash ./download_dataset.sh
python prepro.py
python train.py
python eval.py --eval_model ./checkpoint/epoch_27.pth
模型下载: 谷歌云盘
将下载的模型权重以及训练信息放到checkpoint
目录下, 运行step6
的指令进行测试。
├── checkpoint # 存储训练的模型
├── config
│ └── config.py # 模型的参数设置
├── data # 预处理的数据
├── images # 数据集图像
├── model
│ └── encoder.py # 编码器
│ └── decoder.py # 解码器
│ └── dataloader.py # 加载训练数据
│ └── loss.py # 定义损失函数
├── pyutils
│ └── cap_eval # 计算评价指标工具
├── result # 存放生成的标题
├── utils
│ └── eval_utils.py # 测试工具
├── download_dataset.sh # 数据集下载脚本
├── prepro.py # 数据预处理
├── train.py # 训练主函数
├── eval.py # 测试主函数
└── requirement.txt # 依赖包
模型、训练的所有参数信息都在config.py
中进行了详细注释,详情见config/config.py
。
关于模型的其他信息,可以参考下表:
信息 | 说明 |
---|---|
发布者 | fuqianya |
时间 | 2021.08 |
框架版本 | Paddle 2.1.0 |
应用场景 | 多模态 |
支持硬件 | GPU、CPU |
下载链接 | 预训练模型 | 训练日志 |