Skip to content

Crispig/ASC-RACE

Repository files navigation

Task description

使用规定的预训练模型完成集群3000W功耗下的阅读理解任务

实现分布式训练,模型剪枝,模型量化,并行I/O和引入强化学习

代码结构

n_run.sh:多节点启动分布式训练脚本(因为集群搭建了NFS,所以需要区别参数)

cut_datasets.py:切分数据集

cut_pt.py:切分PT文件,用于并行I/O(为了加速读取数据,先将数据集处理后打包为PT文件)

run.py:(dev.py test.py类似)

​ 优化策略标注为“#n”(可用CTRL F定位)

​ “# 0” :并行I/O,读入数据

​ “# 1” :部署分布式以及混合精度训练(纯FP16模式)

​ “# 2” :模型剪枝,冻结前八层模型参数(使用预训练参数),降低显存,防止过拟合

​ “# 3” :去除停用词,自注意力集中在关键词上

./pytorch_pretrained_bert/modeling.py:模型代码

​ 优化策略标注为“#n”(可用CTRL F定位)

​ “# 4” :引入强化学习,在计算交叉熵之前,奖励正确选项权重(不用惩罚错的,效果不好)

About

Reading Comprehension

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published