使用规定的预训练模型完成集群3000W功耗下的阅读理解任务
实现分布式训练,模型剪枝,模型量化,并行I/O和引入强化学习
n_run.sh:多节点启动分布式训练脚本(因为集群搭建了NFS,所以需要区别参数)
cut_datasets.py:切分数据集
cut_pt.py:切分PT文件,用于并行I/O(为了加速读取数据,先将数据集处理后打包为PT文件)
run.py:(dev.py test.py类似)
优化策略标注为“#n”(可用CTRL F定位)
“# 0” :并行I/O,读入数据
“# 1” :部署分布式以及混合精度训练(纯FP16模式)
“# 2” :模型剪枝,冻结前八层模型参数(使用预训练参数),降低显存,防止过拟合
“# 3” :去除停用词,自注意力集中在关键词上
./pytorch_pretrained_bert/modeling.py:模型代码
优化策略标注为“#n”(可用CTRL F定位)
“# 4” :引入强化学习,在计算交叉熵之前,奖励正确选项权重(不用惩罚错的,效果不好)