Skip to content

word2vec

hankcs edited this page May 21, 2020 · 25 revisions

word2vec

快速上手

请参考调用示例:

训练

语料格式

语料格式为单个txt文本文件,文件中每行一个句子,句子是分词后的单词列表,单词与单词之间用空格分割。

就英文测试而言,可以直接用著名的text8语料:http://mattmahoney.net/dc/text8.zip ,下载后解压,得到一个纯文本格式的text8文件。

中文语料需要提前分词,用空格分割。最好直接将分词语料的标签去掉得到“黄金”语料,这样可以减小分词带来的误差(比如著名的Sighan05分词语料)。也可以预先收集大量文档,用HanLP分词后输出为纯文本。一个例子是搜狗文本分类语料库mini版已分词.txt.zip,下载前请先阅读搜狗实验室数据使用许可协议。这些语料量级较小,训练出来的词向量覆盖词汇量较小,也不够准确,仅用于此处的快速演示。

在服务器允许的条件下,语料越大越好。一般取维基百科全站dump转为简体后训练,这样量级的语料才可能用于生产环境。有一些开源项目提供了预训练的词向量下载,比如:

词向量 语言 词汇量 维度 体积
HanLP 简体 69681 300 199MB
fastText 简繁混合 332647 300 861MB
polyglot 简体 100004 64 82MB

命令行训练

为了方便测试,可以用命令行调用训练程序:

$ java -cp hanlp.jar com.hankcs.hanlp.mining.word2vec.Train
word2vec Java toolkit v 0.1c

Options:
Parameters for training:
	-output <file>
		Use <file> to save the resulting word vectors / word clusters
	-size <int>
		Set size of word vectors; default is 100
	-window <int>
		Set max skip length between words; default is 5
	-sample <float>
		Set threshold for occurrence of words. Those that appear with higher frequency in the training data will be randomly down-sampled; default is 0.001, useful range is (0, 0.00001)
	-hs <int>
		Use Hierarchical Softmax; default is 0 (not used)
	-negative <int>
		Number of negative examples; default is 5, common values are 3 - 10 (0 = not used)
	-threads <int>
		Use <int> threads (default is the cores of local machine)
	-iter <int>
		Run more training iterations (default 5)
	-min-count <int>
		This will discard words that appear less than <int> times; default is 5
	-alpha <float>
		Set the starting learning rate; default is 0.025 for skip-gram and 0.05 for CBOW
	-cbow <int>
		Use the continuous bag of words model; default is 1 (use 0 for skip-gram model)
	-input <file>
		Use text data from <file> to train the model

Examples:
java com.hankcs.hanlp.mining.word2vec.Train -input corpus.txt -output vectors.txt -size 200 -window 5 -sample 0.0001 -negative 5 -hs 0 -binary -cbow 1 -iter 3

参数与原版C程序兼容,至少需要指定训练语料的路径和模型保存路径,例子:

$ java -cp hanlp.jar com.hankcs.hanlp.mining.word2vec.Train -input msr_training.utf8 -output msr.txt

输出:

加载训练语料:100.00%
词表大小:24845
训练词数:2269300
语料词数:2368390
学习率:0.000163  进度:99.67%  剩余时间:01 s
训练结束,一共耗时:1 m 54 s 
正在保存模型到磁盘中……
模型已保存到:msr.txt

以上数据在一台普通的IBM兼容机上得到:

head /proc/cpuinfo | grep -i name
model name      : Intel(R) Xeon(R) CPU E3-1220 v5 @ 3.00GHz
head /proc/meminfo -n 1
MemTotal:        7968764 kB

视机器性能不同,大致能在数分钟内结束。若语料很大,则会耗时较长。

输出词向量的格式为文本形式,与原版C程序、以及大多数开源词向量程序的文本格式兼容。即满足:

n d
word_1 [v_1 ... v_d]
...
word_n [v_1 ... v_d]

API训练

用户也可以写代码完成训练,其接口如下:

    /**
     * 执行训练
     *
     * @param trainFileName     输入语料文件
     * @param modelFileName     输出模型路径
     * @return 词向量模型
     */
    public WordVectorModel train(String trainFileName, String modelFileName)

调用示例:

Word2VecTrainer trainerBuilder = new Word2VecTrainer();
WordVectorModel wordVectorModel = trainerBuilder.train("data/msr_training.utf8.txt", "data/msr_vectors.txt");
wordVectorModel.nearest("中国");

其中,Word2VecTrainer有许多参数,但默认的就能满足基本需求。 常用的参数有:

/**
 * 词向量的维度(等同于神经网络模型隐藏层的大小)
 * <p>
 * 默认 100
 */
public Word2VecTrainer setLayerSize(int layerSize)
/**
 * 设置迭代次数
 */
public Word2VecTrainer setNumIterations(int iterations)
/**
 * 并行化训练线程数
 * <p>
 * 默认 {@link Runtime#availableProcessors()}
 */
public Word2VecTrainer useNumThreads(int numThreads)

如果要修改更高级的参数,需要先对word2vec算法原理具备足够的理解。请参考《word2vec原理推导与代码分析》。 高级参数的修改接口已经全部封装好,此处不再一一列举,详见源码hanlp-source.jar

训练回调

由于训练有时候耗时较长,用户可以注册回调函数来接受训练进度的消息:

/**
 * 设置训练回调
 *
 * @param callback 回调接口
 */
public void setCallback(TrainingCallback callback)

其中,回调接口如下:

public interface TrainingCallback
{
    /**
     * 语料加载中
     * @param percent 已加载的百分比(0-100)
     */
    void corpusLoading(float percent);

    /**
     * 语料加载完毕
     * @param vocWords 词表大小(不是词频,而是语料中有多少种词)
     * @param trainWords 实际训练用到的词的总词频(有些词被停用词过滤掉)
     * @param totalWords 全部词语的总词频
     */
    void corpusLoaded(int vocWords, int trainWords, int totalWords);

    /**
     * 训练过程的回调
     * @param alpha 学习率
     * @param progress 训练完成百分比(0-100)
     */
    void training(float alpha, float progress);
}

词向量应用

介绍词向量模型训练完毕后,能够利用模型干什么。

加载模型

WordVectorModel wordVectorModel = new WordVectorModel("msr.txt");

如果发生加载错误,内部会抛出IOException,用户请注意捕获。

计算两个词语的语义距离

/**
 * Cosine similarity
 *
 * @param what 一个词
 * @param with 另一个词
 * @return 余弦距离
 */
public float similarity(String what, String with)

调用示例:

System.out.println(wordVectorModel.similarity("山东", "江苏"));
System.out.println(wordVectorModel.similarity("山东", "上班"));

输出:

0.81871825
0.25067142

找出与某个词语最相似的N个词语

/**
 * 查询与词语最相似的词语
 *
 * @param word 词语
 * @param size topN个
 * @return 键值对列表, 键是相似词语, 值是相似度, 按相似度降序排列
 */
public List<Map.Entry<String, Float>> nearest(String word, int size)

示例:

System.out.println(wordVectorModel.nearest("山东"));

输出:

[江苏=0.81871825, 辽宁=0.8186185, 河北=0.8115349, 河南=0.8013508, 黑龙江=0.7941345, 陕西=0.78571993, 吉林=0.7780351, 广西=0.77572066, 山西=0.77110046, 宁夏=0.7684624]

词语类比

给定三个词语A、B、C,返回与(A - B + C)语义距离最近的词语及其相似度列表。

/**
 * 返回跟 A - B + C 最相似的词语,比如 中国 - 北京 + 东京 = 日本。输入顺序按照 中国 北京 东京
 *
 * @param A 做加法的词语
 * @param B 做减法的词语
 * @param C 做加法的词语
 * @return 与(A - B + C)语义距离最近的词语及其相似度列表
 */

示例:

System.out.println(wordVectorModel.analogy("日本", "自民党", "共和党"));

返回:

[美国=0.7542177, 劳拉公司=0.69007, 金里奇=0.6892426, 白俄罗斯=0.6786959, 利比亚=0.6729529, 斯卡尔法罗=0.6684207, 金大中=0.6647056, 副总统=0.66446203, 众议员=0.6615143, 民主党=0.6608608]

该接口的记忆方法是,A对于B来讲,就像D对于C一样。D是返回结果。比如对于 中国 - 北京 + 东京 = 日本,说明中国之于北京,就如同日本之于东京一样。对于日本 - 自民党 + 共和党 = 美国也许可以理解为,自民党在日本的地位就如同共和党在美国的地位一样;但受语料影响,不一定能得到此结果,也不要过度解读。

线性运算

除了执行A-B+C之外,还可以执行任意线性运算。第一步需要先取得词语的向量:

WordVectorModel wordVectorModel = new WordVectorModel("en-vectors.txt");
System.out.println(wordVectorModel.nearest("china"));
System.out.println(wordVectorModel.analogy("king", "man", "woman"));

Vector king = wordVectorModel.vector("king");
Vector man = wordVectorModel.vector("man");
Vector woman = wordVectorModel.vector("woman");

这里的en-vectors.txt是在英文语料上训练得到的词向量,所以可以处理英文。

第二步执行线性运算,输送到nearest接口中获取跟运算结果向量最相似的词向量对应的词语及其相似度:

System.out.println(wordVectorModel.nearest(king.minus(man).add(woman)));
	输出:
[king=0.84794855, woman=0.7684506, queen=0.738009, daughter=0.7225961, betrothed=0.7158251, prince=0.71347046, isabella=0.7129812, heiress=0.7067297, wife=0.7036921, anjou=0.6993726]

排在前面的词语可能与参与运算的词语相同,这是正常现象。

文档向量

文档向量是基于词向量,将一个文档转换成向量的模型(词袋模型)。可以用于短文本的相似度计算,是一个较强的基线模型。

构造文档向量模型

只需要一个预先训练好的词向量模型即可:

DocVectorModel docVectorModel = new DocVectorModel(new WordVectorModel("data/msr_vectors.txt"));

然后加载待查询的文档(不需要分词,内部会用NotionalTokenizer分词):

String[] documents = new String[]{
        "山东苹果丰收",
        "农民在江苏种水稻",
        "奥运会女排夺冠",
        "世界锦标赛胜出",
        "中国足球失败",
};

for (int i = 0; i < documents.length; i++)
{
    docVectorModel.addDocument(i, documents[i]);
}

语义查询

可以用一个词或多个词的查询语句来得到文档库中与查询语句语义上最相似的top N个文档(id及其相似度):

System.out.println("============体育=============");
List<Map.Entry<Integer, Float>> entryList = docVectorModel.nearest("体育");
for (Map.Entry<Integer, Float> entry : entryList)
{
    System.out.printf("%d %s %.2f\n", entry.getKey(), documents[entry.getKey()], entry.getValue());
}

System.out.println("============农业=============");
entryList = docVectorModel.nearest("农业");
for (Map.Entry<Integer, Float> entry : entryList)
{
    System.out.printf("%d %s %.2f\n", entry.getKey(), documents[entry.getKey()], entry.getValue());
}

输出:

============体育=============
3 世界 锦标赛 胜出 0.42
4 中国 足球 失败 0.41
2 奥运会 女排 夺冠 0.40
============农业=============
1 农民 在 江苏 种 水稻 0.55
0 山东 苹果 丰收 0.42

在上述例子中,虽然没有任何文档出现“体育”“农业”等字样,但文档向量模型就是能找出哪些文档是与“体育”相关,哪些文档是与“农业”相关的。

文档相似度计算

本模块也可以用于文章相似度计算,只需将文章输入即可。文档相似度归结为向量夹角的计算,DocVectorModel已经封装了相应接口:

/**
 * 余弦相似度
 *
 * @param docA
 * @param docB
 * @return
 */
public float similarity(String docA, String docB)
/**
 * 余弦相似度
 *
 * @param a
 * @param b
 * @return
 */
public float similarity(Vector a, Vector b)

调用示例:

public void testSimilarity() throws Exception
{
    System.out.println(docVectorModel.similarity("山西副省长贪污腐败开庭", "陕西村干部受贿违纪"));
    System.out.println(docVectorModel.similarity("山西副省长贪污腐败开庭", "股票基金增长"));
}

输出:

0.83750665
-0.027349018

这说明第一个文档是比较相似的(接近1),而第二个文档是不太相似的(相似度为负数)。