Knowledge-Aware Languege Model Pretraining

文章研究向 PLM 注入知识的方法,提出一种不需更改 transformer layer,不需添加额外的 knowledge-aware layer 的知识注入方法:KALM。

Introduction

PLM 生成的语言表征被认为含有丰富的词法语法知识,但是在缺乏显示知识性任务训练时,PLM 常常生成语法上正确,但事实上错误的文本。这说明 PLM 中缺乏语义,以及事实知识。

作者提出一种促使 PLM 注意输入文本中的实体以及其在文本中的角色的预训练模式,在没有使模型变得更大的情况下向 PLM 注入知识。

作者首先将输入语句中的 word span 连接到其指代的实体上,然后同时为 word 生成 word embedding 和 entity embedding。在输出层,除去 PLM 的语言建模目标之外,作者添加了一个 entity prediction task 引导模型从干扰项中分辨出 word 所指向的实体。这两个训练目标综合起来即显示地引导模型不仅要预测出正确的词(语法,句法,语义知识),还要预测出这些词所指代的实体(事实知识)。

KALM

作者在自回归模型的基础上设计 KALM,对于一个 nn 个 tokens 的序列 X={w1,w2,...,wn}X = \{w_1, w_2, ..., w_n\}, 自回归模型由以下语言概率分解描述:

p(X)=ip(wiw<i)p(X) = \prod_i p(w_i|w_{<i})

在 PLM 中,通常由 transformer layer 计算上述概率: p(wiw<i)=transformer(wiw<i)p(w_i|w_{<i}) = \text{transformer}(w_i|w_{<i})

上述过程中,PLM 间接通过词之间的共现模式捕捉语义知识。作者则通过向模型提供一个信号提示输入/输出中实体的存在以促使模型对知识的注意,期望 PLM 能够从输入语句中捕捉事实知识

Entity Tokenizer

作者首先使用一个 Entity Tokenizer 将输入文本中所有的 token 与其所指代的最常出现的实体连接:wi:i+keiw_{i:i+k}\rightarrow e_i, 其中 eie_i 是 word span wi:i+kw_{i:i+k} 所最经常指代的实体。当 wiw_i 不属于任何已知的实体时:ei=nulle_i = null. 上述过程通过在一个预定义的实体词典上进行文本匹配进行。

经过 tokenize 后,输入语句被分成词-实体两个 token 序列:

Xduet={{w1,w2,,wT}Word Sequence{e1,e2,,eT}Entity SequenceX_{duet} = \begin{cases} \{w_1, w_2, \dots, w_T\} &\text{Word Sequence}\\ \{e_1, e_2, \dots, e_T\} &\text{Entity Sequence} \end{cases}

上述两个序列逐位置对齐,当有多个词对应同一个实体时,如 wi:i+kw_{i:i+k} 对应一个实体,则 eie_iei+ke_{i+k} 是相同的。

Knowledge-Aware Input

经过 tokenize 后,作者为两个 token 序列分别生成嵌入:

ei=Embeddinge(ei)Rdewi=Embeddingw(wi)Rdw\begin{aligned} \mathbf{e}_i &= \text{Embedding}_e(e_i)\in\mathbb{R}^{d_e}\\ \mathbf{w}_i &= \text{Embedding}_w(w_i)\in\mathbb{R}^{d_w} \end{aligned}

两个嵌入线性相加作为模型的输入嵌入:ti=wi+Lineart(ei)\mathbf{t}_i = \mathbf{w}_i + \text{Linear}_t(\mathbf{e}_i), 其中 LineartRde×dw\text{Linear}_t\in\mathbb{R}^{d_e\times d_w}

Knowledge-Aware Output

在输出层上,除自回归模型的 next-word prediction 任务外,作者添加了一个 next-entity prediction 任务。

具体来说,作者添加了一个 output head 进行实体辨别。记 LL 层 transformer 层输出的第 ii 个 token 的表征为 hiL\mathbf{h}_i^L, 则第 ii 个位置的实体损失计算为:

le(eit<i)=max(0,s(hiL,ei)s(hiL,e)+λ)s(hiL,ej)=cos(Linear(hiL),ej)hiL=transformerL(t<i)\begin{aligned} l_e(e_i|t_{<i}) &= \max(0, s(\mathbf{h}_i^L, \mathbf{e}_i) - s(\mathbf{h}_i^L, \mathbf{e}_-)+\lambda)\\ s(\mathbf{h}_i^L, \mathbf{e}_j) &= \cos(\text{Linear}(\mathbf{h}_i^L), \mathbf{e}_j)\\ \mathbf{h}_i^L &= \text{transformer}^L(t_{<i}) \end{aligned}

上式中,eie_i 指第 ii 个 token 所指代的实体,ee_- 指作者从除 eie_i 之外的实体中采样得到的负例,该损失促使模型分辨 tokenitoken_i 所指代的实体。

值得注意的是,在实验中,作者使用的负例采样策略为:1% 的 nullnull , 49% 的随机采样实体,50% 的从目标实体的 Trans-E 嵌入空间中最近的 100 个实体中采样(被认为是难以分辨的负例)。

Pretraining

KALM 的总体损失是:

lKALM(Xduet=ilw(p(wit<i))+αle(eit<i)l_\text{KALM}(X_{duet} = \sum_i l_w(p(w_i|t_{<i})) + \alpha l_e(e_i|t_{<i})

Inference

推理时,仅使用 word prediction head。

实验

作者进行了知识嗅探评测和 zero-shot QA 任务评测。