Enriching Pre-trained Language Model with Entity Information for Relation Classification
Abstract
- Realation Classification을 수행하기 위해서는 문장에 대한 정보와 두 엔티티에 대한 정보가 필요하다.
- 문장에 대한 정보는 BERT의 last hidden states의 출력값 [CLS]토큰에 담겨있다.
- 이 논문에서 두 엔티티에 대한 정보는, 1) 타겟 엔티티들을 찾고, 2) 해당 정보를 사전학습된 언어 모델에 전달하고, 3) 두 엔티티의 해당 인코딩을 통합한다.
- SOTA 모델의 성능을 개선한다
Introduction
- Relation Classification이란, 주어진 시퀀스 s가 있고, 명사쌍 e1과 e2가 있을 때, e2 사이의 관계를 알아내는 것이다.
- 텍스트가 BERT를 거치기 전, 타겟 엔티티의 양 옆에 스페셜 토큰을 삽입해서 두 타겟 엔티티의 위치를 포착할 수 있게 된다.
- BERT를 거치고 난 위, 출력되는 임베딩에서 타겟 엔티티 위치를 찾는다. 이 임베딩을 sentence encoding(CLS Token) 처럼, Multi-Layer neural network의 입력값으로 사용
- 문장 의미, 두 타겟 엔티티 잘 포착, RE task SOTA 달성
Methodology
- 문장 의미 포착하기 위해, 문장 맨 처음 [CLS] 토큰 삽입
- 두 엔티티 위치 정보 포착하기 위해, entity1에 $, entity2에 # special token 사용
두 타겟 엔티티에 대한 연산
- BERT의 마지막 hidden state output을 H라고 할때, H(i)부터 H(j)까지의 벡터는 e1에 대한 벡터, H(k)부터 H(m)까지는 e2에 대한 벡터
- 두 타겟 엔티티에 대해 각각 Average operation을 적용
- 시그마 H(t)는 i부터 j까지의 e1의 벡터 합
- j-i+1은 i부터 j까지 e1의 길이
- 활성화 함수 tanh 거침
- tanh를 거친 벡터는 -1에서 1사이의 값을 갖게 됨
- Fully-connected layer를 통과
- W1와 W2, 그리고 b1과 b2는 같은 파라미터를 공유. W1=W2, b1=b2
[CLS] 토큰에 대한 연산
- [CLS] 토큰은 활성화 함수인 tanh를 거쳐 fully-connected layer를 통과
[CLS], e1, e2에 대한 출력값 H'1, H'2, H'3에 대한 처리
- 세 벡터를 concatenate
- Fully-connected layer 통과
- Softmax를 거쳐 예측값 p 출력
R-RoBERTa Pytorch Lightning 구현
class FCLayer(pl.LightningModule):
def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
super().__init__()
self.save_hyperparameters()
self.use_activation = use_activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.linear = torch.nn.Linear(input_dim, output_dim)
self.tanh = torch.nn.Tanh()
torch.nn.init.xavier_uniform_(self.linear.weight)
def forward(self, x):
x = self.dropout(x)
if self.use_activation:
x = self.tanh(x)
return self.linear(x)
- Fully-connected layer
- 활성화 함수 tanh 사용
class Model(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
self.model_name = config.model.model_name
self.lr = config.train.learning_rate
self.lr_sch_use = config.train.lr_sch_use
self.lr_decay_step = config.train.lr_decay_step
self.scheduler_name = config.train.scheduler_name
self.lr_weight_decay = config.train.lr_weight_decay
self.dr_rate = 0
self.hidden_size = 1024
self.num_classes = 30
# 사용할 모델을 호출합니다.
self.plm = transformers.RobertaModel.from_pretrained(self.model_name, add_pooling_layer=False)
self.cls_fc = FCLayer(self.hidden_size, self.hidden_size // 2, self.dr_rate)
self.sentence_fc = FCLayer(self.hidden_size, self.hidden_size // 2, self.dr_rate)
self.label_classifier = FCLayer(self.hidden_size // 2 * 3, self.num_classes, self.dr_rate, False)
# Loss 계산을 위해 사용될 CE Loss를 호출합니다.
self.loss_func = criterion_entrypoint(config.train.loss_name)
self.optimizer_name = config.train.optimizer_name
def forward(self, x):
out = self.plm(
input_ids=x["input_ids"],
attention_mask=x["attention_mask"],
token_type_ids=x["token_type_ids"],
)[0]
sentence_end_position = torch.where(x["input_ids"] == 2)[1]
sent1_end, sent2_end = sentence_end_position[0], sentence_end_position[1]
cls_vector = out[:, 0, :] # take <s> token (equiv. to [CLS])
prem_vector = out[:, 1:sent1_end] # Get Premise vector
hypo_vector = out[:, sent1_end + 1 : sent2_end] # Get Hypothesis vector
prem_vector = torch.mean(prem_vector, dim=1) # Average
hypo_vector = torch.mean(hypo_vector, dim=1)
# Dropout -> tanh -> fc_layer (Share FC layer for premise and hypothesis)
cls_embedding = self.cls_fc(cls_vector)
prem_embedding = self.sentence_fc(prem_vector)
hypo_embedding = self.sentence_fc(hypo_vector)
# Concat -> fc_layer
concat_embedding = torch.cat([cls_embedding, prem_embedding, hypo_embedding], dim=-1)
return self.label_classifier(concat_embedding)
- hidden_size 1024는 RoBERTa LARGE의 Hidden size
- label_classifeier에서 label 개수 30개 지정
'AI > NLP' 카테고리의 다른 글
Passage retrieval(문서 검색) - Sparse Embedding, Dense Embedding, Scaling up with FAISS (0) | 2023.01.05 |
---|---|
MRC(기계독해, Machine Reading Comprehension), Extraction-based MRC, Generation-based MRC (0) | 2023.01.05 |
Relation Extraction Pytorch Lightning Refactoring (0) | 2022.12.07 |
[자연어처리 모델 정리] BERT, RoBERTa, ALBERT, XLNet, ELECTRA (0) | 2022.11.05 |
[논문리뷰] Attention Is All You Need, Transformer (2) | 2022.10.13 |