본문 바로가기

AI/NLP

[논문리뷰+구현] R-BERT: Enriching Pre-trained Language Model with Entity Information for Relation Classification + R-RoBERTa Pytorch Lightning

Enriching Pre-trained Language Model with Entity Information for Relation Classification

Abstract

  • Realation Classification을 수행하기 위해서는 문장에 대한 정보두 엔티티에 대한 정보가 필요하다.
  • 문장에 대한 정보는 BERT의 last hidden states의 출력값 [CLS]토큰에 담겨있다.
  • 이 논문에서 두 엔티티에 대한 정보는, 1) 타겟 엔티티들을 찾고, 2) 해당 정보를 사전학습된 언어 모델에 전달하고, 3) 두 엔티티의 해당 인코딩을 통합한다.
  • SOTA 모델의 성능을 개선한다

Introduction

  • Relation Classification이란, 주어진 시퀀스 s가 있고, 명사쌍 e1과 e2가 있을 때, e2 사이의 관계를 알아내는 것이다.
  • 텍스트가 BERT를 거치기 전, 타겟 엔티티의 양 옆에 스페셜 토큰을 삽입해서 두 타겟 엔티티의 위치를 포착할 수 있게 된다.
  • BERT를 거치고 난 위, 출력되는 임베딩에서 타겟 엔티티 위치를 찾는다. 이 임베딩을 sentence encoding(CLS Token) 처럼, Multi-Layer neural network의 입력값으로 사용
  • 문장 의미, 두 타겟 엔티티 잘 포착, RE task SOTA 달성

Methodology

  • 문장 의미 포착하기 위해, 문장 맨 처음 [CLS] 토큰 삽입
  • 두 엔티티 위치 정보 포착하기 위해, entity1에 $, entity2에 # special token 사용

두 타겟 엔티티에 대한 연산

  • BERT의 마지막 hidden state output을 H라고 할때, H(i)부터 H(j)까지의 벡터는 e1에 대한 벡터, H(k)부터 H(m)까지는 e2에 대한 벡터
  • 두 타겟 엔티티에 대해 각각 Average operation을 적용
    • 시그마 H(t)는 i부터 j까지의 e1의 벡터 합
    • j-i+1은 i부터 j까지 e1의 길이
  • 활성화 함수 tanh 거침
    • tanh를 거친 벡터는 -1에서 1사이의 값을 갖게 됨
  • Fully-connected layer를 통과
    • W1와 W2, 그리고 b1과 b2는 같은 파라미터를 공유. W1=W2, b1=b2

[CLS] 토큰에 대한 연산

  • [CLS] 토큰은 활성화 함수인 tanh를 거쳐 fully-connected layer를 통과

[CLS], e1, e2에 대한 출력값 H'1, H'2, H'3에 대한 처리

  • 세 벡터를 concatenate
  • Fully-connected layer 통과
  • Softmax를 거쳐 예측값 p 출력

R-RoBERTa Pytorch Lightning 구현

class FCLayer(pl.LightningModule):
    def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
        super().__init__()
        self.save_hyperparameters()

        self.use_activation = use_activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.tanh = torch.nn.Tanh()

        torch.nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, x):
        x = self.dropout(x)
        if self.use_activation:
            x = self.tanh(x)
        return self.linear(x)
  • Fully-connected layer
  • 활성화 함수 tanh 사용

class Model(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.model_name = config.model.model_name
        self.lr = config.train.learning_rate
        self.lr_sch_use = config.train.lr_sch_use
        self.lr_decay_step = config.train.lr_decay_step
        self.scheduler_name = config.train.scheduler_name
        self.lr_weight_decay = config.train.lr_weight_decay
        self.dr_rate = 0
        self.hidden_size = 1024
        self.num_classes = 30

        # 사용할 모델을 호출합니다.
        self.plm = transformers.RobertaModel.from_pretrained(self.model_name, add_pooling_layer=False)
        self.cls_fc = FCLayer(self.hidden_size, self.hidden_size // 2, self.dr_rate)
        self.sentence_fc = FCLayer(self.hidden_size, self.hidden_size // 2, self.dr_rate)
        self.label_classifier = FCLayer(self.hidden_size // 2 * 3, self.num_classes, self.dr_rate, False)

        # Loss 계산을 위해 사용될 CE Loss를 호출합니다.
        self.loss_func = criterion_entrypoint(config.train.loss_name)
        self.optimizer_name = config.train.optimizer_name

    def forward(self, x):
        out = self.plm(
            input_ids=x["input_ids"],
            attention_mask=x["attention_mask"],
            token_type_ids=x["token_type_ids"],
        )[0]

        sentence_end_position = torch.where(x["input_ids"] == 2)[1]
        sent1_end, sent2_end = sentence_end_position[0], sentence_end_position[1]

        cls_vector = out[:, 0, :]  # take <s> token (equiv. to [CLS])
        prem_vector = out[:, 1:sent1_end]  # Get Premise vector
        hypo_vector = out[:, sent1_end + 1 : sent2_end]  # Get Hypothesis vector

        prem_vector = torch.mean(prem_vector, dim=1)  # Average
        hypo_vector = torch.mean(hypo_vector, dim=1)

        # Dropout -> tanh -> fc_layer (Share FC layer for premise and hypothesis)
        cls_embedding = self.cls_fc(cls_vector)
        prem_embedding = self.sentence_fc(prem_vector)
        hypo_embedding = self.sentence_fc(hypo_vector)

        # Concat -> fc_layer
        concat_embedding = torch.cat([cls_embedding, prem_embedding, hypo_embedding], dim=-1)

        return self.label_classifier(concat_embedding)
  • hidden_size 1024는 RoBERTa LARGE의 Hidden size
  • label_classifeier에서 label 개수 30개 지정