diff --git a/.gitignore b/.gitignore
index f393884..ff94ae9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
.ipynb_checkpoints
data/
+model/
__pycache__
KR*
tempo/
@@ -8,3 +9,5 @@ tempo/
qualitative/
outputs/
*.ipynb
+
+*.json
\ No newline at end of file
diff --git a/README.md b/README.md
index 7f8f286..95ed095 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,36 @@
# SOM-DST
+Convert code pytorch-transformers to huggingface transformers
+
+```
+# Fixed Requirements
+
+# pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
+torch==1.7.1+cu110
+transformers==3.0.2
+wget==3.2
+jsonlines
+tqdm
+```
+
+```
+# 동일한 Parameter setting (MW 2.1) 성능 소폭 감소 (0.5309 -> 0.5275)
+------------------------------
+op_code: 4, is_gt_op: False, is_gt_p_state: False, is_gt_gen: False
+Epoch 0 joint accuracy : 0.5275515743756786
+Epoch 0 slot turn accuracy : 0.9732401375316211
+Epoch 0 slot turn F1: 0.9175307139165523
+Epoch 0 op accuracy : 0.9737830256966589
+Epoch 0 op F1 : {'delete': 0.018656716417910446, 'update': 0.8015826338020638, 'dontcare': 0.3235668789808917, 'carryover': 0.9862940159245958}
+Epoch 0 op hit count : {'delete': 15, 'update': 7496, 'dontcare': 127, 'carryover': 207607}
+Epoch 0 op all count : {'delete': 1576, 'update': 10595, 'dontcare': 581, 'carryover': 208288}
+Final Joint Accuracy : 0.3713713713713714
+Final slot turn F1 : 0.9101975987924662
+Latency Per Prediction : 24.244383 ms
+-----------------------------
+```
+
+## The original readme.md is as follows
This code is the official pytorch implementation of [Efficient Dialogue State Tracking by Selectively Overwriting Memory](https://arxiv.org/abs/1911.03906).
> [Sungdong Kim](https://github.com/dsksd), [Sohee Yang](https://github.com/soheeyang), [Gyuwan Kim](mailto:gyuwan.kim@navercorp.com), [Sang-woo Lee](https://scholar.google.co.kr/citations?user=TMTTMuQAAAAJ)
@@ -96,10 +127,10 @@ taxi 0.5903426791277259 0.9803219106957396
### Main results on MultiWOZ dataset (Joint Goal Accuracy)
-|Model |MultiWOZ 2.0 |MultWOZ 2.1|
-|-------------|------------|------------|
-|SOM-DST Base | 51.72 | 53.01 |
-|SOM-DST Large| 52.32 | 53.68 |
+| Model | MultiWOZ 2.0 | MultWOZ 2.1 |
+| ------------- | ------------ | ----------- |
+| SOM-DST Base | 51.72 | 53.01 |
+| SOM-DST Large | 52.32 | 53.68 |
## Citation
diff --git a/evaluation.py b/evaluation.py
index 8ccbb5e..560bb1c 100644
--- a/evaluation.py
+++ b/evaluation.py
@@ -5,9 +5,17 @@
"""
from utils.data_utils import prepare_dataset, MultiWozDataset
-from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
+from utils.data_utils import (
+ make_slot_meta,
+ domain2id,
+ OP_SET,
+ make_turn_label,
+ postprocessing,
+)
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
-from pytorch_transformers import BertTokenizer, BertConfig
+
+# from pytorch_transformers import BertTokenizer, BertConfig
+from transformers import BertTokenizer, BertConfig
from model import SomDST
import torch.nn as nn
@@ -23,51 +31,82 @@
import json
from copy import deepcopy
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
ontology = json.load(open(os.path.join(args.data_root, args.ontology_data)))
slot_meta, _ = make_slot_meta(ontology)
tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)
- data = prepare_dataset(os.path.join(args.data_root, args.test_data),
- tokenizer,
- slot_meta, args.n_history, args.max_seq_length, args.op_code)
+ data = prepare_dataset(
+ os.path.join(args.data_root, args.test_data),
+ tokenizer,
+ slot_meta,
+ args.n_history,
+ args.max_seq_length,
+ args.op_code,
+ )
model_config = BertConfig.from_json_file(args.bert_config_path)
model_config.dropout = 0.1
op2id = OP_SET[args.op_code]
- model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
- ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
+ model = SomDST(model_config, len(op2id), len(domain2id), op2id["update"])
+ ckpt = torch.load(args.model_ckpt_path, map_location="cpu")
model.load_state_dict(ckpt)
model.eval()
model.to(device)
if args.eval_all:
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- False, False, False)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- False, False, True)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- False, True, False)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- False, True, True)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- True, False, False)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- True, True, False)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- True, False, True)
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- True, True, True)
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, False, False, False
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, False, False, True
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, False, True, False
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, False, True, True
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, True, False, False
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, True, True, False
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, True, False, True
+ )
+ model_evaluation(
+ model, data, tokenizer, slot_meta, 0, args.op_code, True, True, True
+ )
else:
- model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
- args.gt_op, args.gt_p_state, args.gt_gen)
-
-
-def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
- is_gt_op=False, is_gt_p_state=False, is_gt_gen=False):
+ model_evaluation(
+ model,
+ data,
+ tokenizer,
+ slot_meta,
+ 0,
+ args.op_code,
+ args.gt_op,
+ args.gt_p_state,
+ args.gt_gen,
+ )
+
+
+def model_evaluation(
+ model,
+ test_data,
+ tokenizer,
+ slot_meta,
+ epoch,
+ op_code="4",
+ is_gt_op=False,
+ is_gt_p_state=False,
+ is_gt_gen=False,
+):
model.eval()
op2id = OP_SET[op_code]
id2op = {v: k for k, v in op2id.items()}
@@ -91,19 +130,20 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
if is_gt_p_state is False:
i.last_dialog_state = deepcopy(last_dialog_state)
- i.make_instance(tokenizer, word_dropout=0.)
+ i.make_instance(tokenizer, word_dropout=0.0)
else: # ground-truth previous dialogue state
last_dialog_state = deepcopy(i.gold_p_state)
i.last_dialog_state = deepcopy(last_dialog_state)
- i.make_instance(tokenizer, word_dropout=0.)
+ i.make_instance(tokenizer, word_dropout=0.0)
input_ids = torch.LongTensor([i.input_id]).to(device)
input_mask = torch.FloatTensor([i.input_mask]).to(device)
segment_ids = torch.LongTensor([i.segment_id]).to(device)
state_position_ids = torch.LongTensor([i.slot_position]).to(device)
- d_gold_op, _, _ = make_turn_label(slot_meta, last_dialog_state, i.gold_state,
- tokenizer, op_code, dynamic=True)
+ d_gold_op, _, _ = make_turn_label(
+ slot_meta, last_dialog_state, i.gold_state, tokenizer, op_code, dynamic=True
+ )
gold_op_ids = torch.LongTensor([d_gold_op]).to(device)
start = time.perf_counter()
@@ -111,12 +151,14 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
with torch.no_grad():
# ground-truth state operation
gold_op_inputs = gold_op_ids if is_gt_op else None
- d, s, g = model(input_ids=input_ids,
- token_type_ids=segment_ids,
- state_positions=state_position_ids,
- attention_mask=input_mask,
- max_value=MAX_LENGTH,
- op_ids=gold_op_inputs)
+ d, s, g = model(
+ input_ids=input_ids,
+ token_type_ids=segment_ids,
+ state_positions=state_position_ids,
+ attention_mask=input_mask,
+ max_value=MAX_LENGTH,
+ op_ids=gold_op_inputs,
+ )
_, op_ids = s.view(-1, len(op2id)).max(-1)
@@ -133,20 +175,29 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
if is_gt_gen:
# ground_truth generation
- gold_gen = {'-'.join(ii.split('-')[:2]): ii.split('-')[-1] for ii in i.gold_state}
+ gold_gen = {
+ "-".join(ii.split("-")[:2]): ii.split("-")[-1] for ii in i.gold_state
+ }
else:
gold_gen = {}
- generated, last_dialog_state = postprocessing(slot_meta, pred_ops, last_dialog_state,
- generated, tokenizer, op_code, gold_gen)
+ generated, last_dialog_state = postprocessing(
+ slot_meta,
+ pred_ops,
+ last_dialog_state,
+ generated,
+ tokenizer,
+ op_code,
+ gold_gen,
+ )
end = time.perf_counter()
wall_times.append(end - start)
pred_state = []
for k, v in last_dialog_state.items():
- pred_state.append('-'.join([k, v]))
+ pred_state.append("-".join([k, v]))
if set(pred_state) == set(i.gold_state):
joint_acc += 1
- key = str(i.id) + '_' + str(i.turn_id)
+ key = str(i.id) + "_" + str(i.turn_id)
results[key] = [pred_state, i.gold_state]
# Compute prediction slot accuracy
@@ -159,7 +210,9 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
slot_F1_count += count
# Compute operation accuracy
- temp_acc = sum([1 if p == g else 0 for p, g in zip(pred_ops, gold_ops)]) / len(pred_ops)
+ temp_acc = sum([1 if p == g else 0 for p, g in zip(pred_ops, gold_ops)]) / len(
+ pred_ops
+ )
op_acc += temp_acc
if i.is_last_turn:
@@ -191,14 +244,20 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
tp = tp_dic[k]
fn = fn_dic[k]
fp = fp_dic[k]
- precision = tp / (tp+fp) if (tp+fp) != 0 else 0
- recall = tp / (tp+fn) if (tp+fn) != 0 else 0
- F1 = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0
+ precision = tp / (tp + fp) if (tp + fp) != 0 else 0
+ recall = tp / (tp + fn) if (tp + fn) != 0 else 0
+ F1 = (
+ 2 * precision * recall / float(precision + recall)
+ if (precision + recall) != 0
+ else 0
+ )
op_F1_score[k] = F1
print("------------------------------")
- print('op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s' % \
- (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen)))
+ print(
+ "op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s"
+ % (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen))
+ )
print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
@@ -210,31 +269,39 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
print("Final slot turn F1 : ", final_slot_F1_score)
print("Latency Per Prediction : %f ms" % latency)
print("-----------------------------\n")
- json.dump(results, open('preds_%d.json' % epoch, 'w'))
+ json.dump(results, open("preds_%d.json" % epoch, "w"))
per_domain_join_accuracy(results, slot_meta)
- scores = {'epoch': epoch, 'joint_acc': joint_acc_score,
- 'slot_acc': turn_acc_score, 'slot_f1': slot_F1_score,
- 'op_acc': op_acc_score, 'op_f1': op_F1_score, 'final_slot_f1': final_slot_F1_score}
+ scores = {
+ "epoch": epoch,
+ "joint_acc": joint_acc_score,
+ "slot_acc": turn_acc_score,
+ "slot_f1": slot_F1_score,
+ "op_acc": op_acc_score,
+ "op_f1": op_F1_score,
+ "final_slot_f1": final_slot_F1_score,
+ }
return scores
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--data_root", default='data/mwz2.1', type=str)
- parser.add_argument("--test_data", default='test_dials.json', type=str)
- parser.add_argument("--ontology_data", default='ontology.json', type=str)
- parser.add_argument("--vocab_path", default='assets/vocab.txt', type=str)
- parser.add_argument("--bert_config_path", default='assets/bert_config_base_uncased.json', type=str)
- parser.add_argument("--model_ckpt_path", default='outputs/model_best.bin', type=str)
+ parser.add_argument("--data_root", default="data/mwz2.1", type=str)
+ parser.add_argument("--test_data", default="test_dials.json", type=str)
+ parser.add_argument("--ontology_data", default="ontology.json", type=str)
+ parser.add_argument("--vocab_path", default="assets/vocab.txt", type=str)
+ parser.add_argument(
+ "--bert_config_path", default="assets/bert_config_base_uncased.json", type=str
+ )
+ parser.add_argument("--model_ckpt_path", default="outputs/model_best.bin", type=str)
parser.add_argument("--n_history", default=1, type=int)
parser.add_argument("--max_seq_length", default=256, type=int)
parser.add_argument("--op_code", default="4", type=str)
- parser.add_argument("--gt_op", default=False, action='store_true')
- parser.add_argument("--gt_p_state", default=False, action='store_true')
- parser.add_argument("--gt_gen", default=False, action='store_true')
- parser.add_argument("--eval_all", default=False, action='store_true')
+ parser.add_argument("--gt_op", default=False, action="store_true")
+ parser.add_argument("--gt_p_state", default=False, action="store_true")
+ parser.add_argument("--gt_gen", default=False, action="store_true")
+ parser.add_argument("--eval_all", default=False, action="store_true")
args = parser.parse_args()
main(args)
diff --git a/model.py b/model.py
index 7665c91..dd65daf 100644
--- a/model.py
+++ b/model.py
@@ -6,7 +6,9 @@
import torch
import torch.nn as nn
-from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel
+
+# from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel
+from modeling_bert import BertPreTrainedModel, BertModel
class SomDST(BertPreTrainedModel):
@@ -14,23 +16,47 @@ def __init__(self, config, n_op, n_domain, update_id, exclude_domain=False):
super(SomDST, self).__init__(config)
self.hidden_size = config.hidden_size
self.encoder = Encoder(config, n_op, n_domain, update_id, exclude_domain)
- self.decoder = Decoder(config, self.encoder.bert.embeddings.word_embeddings.weight)
+ self.decoder = Decoder(
+ config, self.encoder.bert.embeddings.word_embeddings.weight
+ )
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids,
- state_positions, attention_mask,
- max_value, op_ids=None, max_update=None, teacher=None):
-
- enc_outputs = self.encoder(input_ids=input_ids,
- token_type_ids=token_type_ids,
- state_positions=state_positions,
- attention_mask=attention_mask,
- op_ids=op_ids,
- max_update=max_update)
-
- domain_scores, state_scores, decoder_inputs, sequence_output, pooled_output = enc_outputs
- gen_scores = self.decoder(input_ids, decoder_inputs, sequence_output,
- pooled_output, max_value, teacher)
+ def forward(
+ self,
+ input_ids,
+ token_type_ids,
+ state_positions,
+ attention_mask,
+ max_value,
+ op_ids=None,
+ max_update=None,
+ teacher=None,
+ ):
+
+ enc_outputs = self.encoder(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ state_positions=state_positions,
+ attention_mask=attention_mask,
+ op_ids=op_ids,
+ max_update=max_update,
+ )
+
+ (
+ domain_scores,
+ state_scores,
+ decoder_inputs,
+ sequence_output,
+ pooled_output,
+ ) = enc_outputs
+ gen_scores = self.decoder(
+ input_ids,
+ decoder_inputs,
+ sequence_output,
+ pooled_output,
+ max_value,
+ teacher,
+ )
return domain_scores, state_scores, gen_scores
@@ -49,9 +75,15 @@ def __init__(self, config, n_op, n_domain, update_id, exclude_domain=False):
self.n_domain = n_domain
self.update_id = update_id
- def forward(self, input_ids, token_type_ids,
- state_positions, attention_mask,
- op_ids=None, max_update=None):
+ def forward(
+ self,
+ input_ids,
+ token_type_ids,
+ state_positions,
+ attention_mask,
+ op_ids=None,
+ max_update=None,
+ ):
bert_outputs = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output, pooled_output = bert_outputs[:2]
state_pos = state_positions[:, :, None].expand(-1, -1, sequence_output.size(-1))
@@ -75,13 +107,23 @@ def forward(self, input_ids, token_type_ids,
n = v.size(1)
gap = max_update - n
if gap > 0:
- zeros = torch.zeros(1, 1*gap, self.hidden_size, device=input_ids.device)
+ zeros = torch.zeros(
+ 1, 1 * gap, self.hidden_size, device=input_ids.device
+ )
v = torch.cat([v, zeros], 1)
else:
- v = torch.zeros(1, max_update, self.hidden_size, device=input_ids.device)
+ v = torch.zeros(
+ 1, max_update, self.hidden_size, device=input_ids.device
+ )
gathered.append(v)
decoder_inputs = torch.cat(gathered)
- return domain_scores, state_scores, decoder_inputs, sequence_output, pooled_output.unsqueeze(0)
+ return (
+ domain_scores,
+ state_scores,
+ decoder_inputs,
+ sequence_output,
+ pooled_output.unsqueeze(0),
+ )
class Decoder(nn.Module):
@@ -90,22 +132,26 @@ def __init__(self, config, bert_model_embedding_weights):
self.pad_idx = 0
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
- self.embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.pad_idx)
+ self.embed = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=self.pad_idx
+ )
self.embed.weight = bert_model_embedding_weights
self.gru = nn.GRU(config.hidden_size, config.hidden_size, 1, batch_first=True)
- self.w_gen = nn.Linear(config.hidden_size*3, 1)
+ self.w_gen = nn.Linear(config.hidden_size * 3, 1)
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout(config.dropout)
for n, p in self.gru.named_parameters():
- if 'weight' in n:
+ if "weight" in n:
p.data.normal_(mean=0.0, std=config.initializer_range)
def forward(self, x, decoder_input, encoder_output, hidden, max_len, teacher=None):
mask = x.eq(self.pad_idx)
batch_size, n_update, _ = decoder_input.size() # B,J',5 # long
state_in = decoder_input
- all_point_outputs = torch.zeros(n_update, batch_size, max_len, self.vocab_size).to(x.device)
+ all_point_outputs = torch.zeros(
+ n_update, batch_size, max_len, self.vocab_size
+ ).to(x.device)
result_dict = {}
for j in range(n_update):
w = state_in[:, j].unsqueeze(1) # B,1,D
@@ -119,13 +165,17 @@ def forward(self, x, decoder_input, encoder_output, hidden, max_len, teacher=Non
attn_history = nn.functional.softmax(attn_e, -1) # B,T
# B,D * D,V => B,V
- attn_v = torch.matmul(hidden.squeeze(0), self.embed.weight.transpose(0, 1)) # B,V
+ attn_v = torch.matmul(
+ hidden.squeeze(0), self.embed.weight.transpose(0, 1)
+ ) # B,V
attn_vocab = nn.functional.softmax(attn_v, -1)
# B,1,T * B,T,D => B,1,D
context = torch.bmm(attn_history.unsqueeze(1), encoder_output) # B,1,D
- p_gen = self.sigmoid(self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))) # B,1
+ p_gen = self.sigmoid(
+ self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
+ ) # B,1
p_gen = p_gen.squeeze(-1)
p_context_ptr = torch.zeros_like(attn_vocab).to(x.device)
diff --git a/modeling_bert.py b/modeling_bert.py
new file mode 100644
index 0000000..6c60bf8
--- /dev/null
+++ b/modeling_bert.py
@@ -0,0 +1,1281 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model. """
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import json
+import logging
+import math
+import os
+import sys
+from io import open
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+
+# from pytorch_transformers.modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel,
+# prune_linear_layer, add_start_docstrings)
+from transformers.modeling_utils import PretrainedConfig, PreTrainedModel, prune_linear_layer
+from transformers.file_utils import add_start_docstrings
+
+logger = logging.getLogger(__name__)
+
+BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
+}
+
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
+}
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+ """ Load tf checkpoints in a pytorch model.
+ """
+ try:
+ import re
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions.")
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split('/')
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
+ l = re.split(r'_(\d+)', m_name)
+ else:
+ l = [m_name]
+ if l[0] == 'kernel' or l[0] == 'gamma':
+ pointer = getattr(pointer, 'weight')
+ elif l[0] == 'output_bias' or l[0] == 'beta':
+ pointer = getattr(pointer, 'bias')
+ elif l[0] == 'output_weights':
+ pointer = getattr(pointer, 'weight')
+ elif l[0] == 'squad':
+ pointer = getattr(pointer, 'classifier')
+ else:
+ try:
+ pointer = getattr(pointer, l[0])
+ except AttributeError:
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ if len(l) >= 2:
+ num = int(l[1])
+ pointer = pointer[num]
+ if m_name[-11:] == '_embeddings':
+ pointer = getattr(pointer, 'weight')
+ elif m_name == 'kernel':
+ array = np.transpose(array)
+ try:
+ assert pointer.shape == array.shape
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info("Initialize PyTorch weight {}".format(name))
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+def gelu(x):
+ """Implementation of the gelu activation function.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+class BertConfig(PretrainedConfig):
+ r"""
+ :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
+ `BertModel`.
+
+
+ Arguments:
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The sttdev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ layer_norm_eps: The epsilon used by LayerNorm.
+ """
+ pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
+
+ def __init__(self,
+ vocab_size_or_config_json_file=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ **kwargs):
+ super(BertConfig, self).__init__(**kwargs)
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+ and isinstance(vocab_size_or_config_json_file, unicode)):
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+ json_config = json.loads(reader.read())
+ for key, value in json_config.items():
+ self.__dict__[key] = value
+ elif isinstance(vocab_size_or_config_json_file, int):
+ self.vocab_size = vocab_size_or_config_json_file
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ else:
+ raise ValueError("First argument must be either a vocabulary size (int)"
+ "or the path to a pretrained model config file (str)")
+
+
+
+try:
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
+except (ImportError, AttributeError) as e:
+ logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
+ class BertLayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-12):
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
+ """
+ super(BertLayerNorm, self).__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ u = x.mean(-1, keepdim=True)
+ s = (x - u).pow(2).mean(-1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+ return self.weight * x + self.bias
+
+# for graph fix BertEmbeddings
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings.
+ """
+ def __init__(self, config, type_vocab_size=None):
+ super(BertEmbeddings, self).__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ if type_vocab_size is not None:
+ config.type_vocab_size = type_vocab_size
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
+ seq_length = input_ids.size(1)
+ if position_ids is None:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ words_embeddings = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super(BertSelfAttention, self).__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.output_attentions = config.output_attentions
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_states=None):
+ if history_states is None:
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+ else:
+ x_states = torch.cat((history_states, hidden_states), dim=1)
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(x_states)
+ mixed_value_layer = self.value(x_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+
+ try:
+ attention_scores = attention_scores + attention_mask
+ except RuntimeError:
+ print("---RuntimeError---")
+ print("attention_scores", attention_scores.shape)
+ print("attention_mask", attention_mask.shape)
+ exit()
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super(BertSelfOutput, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super(BertAttention, self).__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
+ for head in heads:
+ mask[head] = 0
+ mask = mask.view(-1).contiguous().eq(1)
+ index = torch.arange(len(mask))[mask].long()
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+ # Update hyper params
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+
+ def forward(self, input_tensor, attention_mask, head_mask=None, history_states=None):
+ self_outputs = self.self(input_tensor, attention_mask, head_mask, history_states=history_states)
+ attention_output = self.output(self_outputs[0], input_tensor)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super(BertIntermediate, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super(BertOutput, self).__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config):
+ super(BertLayer, self).__init__()
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_states=None):
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_states=history_states)
+ attention_output = attention_outputs[0]
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super(BertEncoder, self).__init__()
+
+ # self.output_attentions = config.output_attentions
+ # self.output_hidden_states = config.output_hidden_states
+
+ self.output_attentions = False
+ self.output_hidden_states = True
+
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+ def forward(self, hidden_states, attention_mask, head_mask=None, prev_embedding=None, prev_encoded_layers=None):
+ assert (prev_embedding is None) == (prev_encoded_layers is None)
+
+ all_hidden_states = ()
+
+ if (prev_embedding is not None) and (prev_encoded_layers is not None):
+ history_states = prev_embedding
+ for i, layer_module in enumerate(self.layer):
+ layer_outputs = layer_module(
+ hidden_states, attention_mask, head_mask[i],
+ history_states=history_states)
+ hidden_states = layer_outputs[0]
+
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ history_states = prev_encoded_layers[i]
+
+ else:
+ for i, layer_module in enumerate(self.layer):
+
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+ hidden_states = layer_outputs[0]
+
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # all_hidden_states: exclude embedding
+
+ return hidden_states, all_hidden_states
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super(BertPooler, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super(BertPredictionHeadTransform, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config, bert_model_embedding_weights=None):
+ super(BertLMPredictionHead, self).__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ if bert_model_embedding_weights is None:
+ self.decoder = nn.Linear(config.hidden_size,
+ config.vocab_size,
+ bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ else:
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
+ bert_model_embedding_weights.size(0),
+ bias=False)
+ self.decoder.weight = bert_model_embedding_weights
+ self.bias = nn.Parameter(torch.zeros(
+ bert_model_embedding_weights.size(0)))
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states) + self.bias
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super(BertOnlyMLMHead, self).__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super(BertOnlyNSPHead, self).__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super(BertPreTrainingHeads, self).__init__()
+ self.predictions = BertLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """ An abstract class to handle weights initialization and
+ a simple interface for dowloading and loading pretrained models.
+ """
+ config_class = BertConfig
+ pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+
+ def __init__(self, *inputs, **kwargs):
+ super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
+
+ def init_weights(self, module):
+ """ Initialize the weights.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, BertLayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+BERT_START_DOCSTRING = r""" The BERT model was proposed in
+ `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
+ by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
+ pre-trained using a combination of masked language modeling objective and next sentence prediction
+ on a large corpus comprising the Toronto Book Corpus and Wikipedia.
+
+ This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
+ refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+ .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
+ https://arxiv.org/abs/1810.04805
+
+ .. _`torch.nn.Module`:
+ https://pytorch.org/docs/stable/nn.html#module
+
+ Parameters:
+ config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+ Inputs:
+ **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Indices of input sequence tokens in the vocabulary.
+ To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
+
+ (a) For sequence pairs:
+
+ ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
+
+ ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
+
+ (b) For single sequences:
+
+ ``tokens: [CLS] the dog is hairy . [SEP]``
+
+ ``token_type_ids: 0 0 0 0 0 0 0``
+
+ Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
+ See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
+ :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+ **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
+ **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Segment token indices to indicate first and second portions of the inputs.
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
+ corresponds to a `sentence B` token
+ (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
+ **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+ Mask to avoid performing attention on padding token indices.
+ Mask values selected in ``[0, 1]``:
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+ Mask to nullify selected heads of the self-attention modules.
+ Mask values selected in ``[0, 1]``:
+ ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+"""
+
+@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertModel(BertPreTrainedModel):
+ r"""
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
+ Sequence of hidden-states at the output of the last layer of the model.
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
+ Last layer hidden-state of the first token of the sequence (classification token)
+ further processed by a Linear layer and a Tanh activation function. The Linear
+ layer weights are trained from the next sentence prediction (classification)
+ objective during Bert pretraining. This output is usually *not* a good summary
+ of the semantic content of the input, you're often better with averaging or pooling
+ the sequence of hidden-states for the whole input sequence.
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertModel.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids)
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
+
+ """
+ def __init__(self, config, type_vocab_size=None):
+ super(BertModel, self).__init__(config)
+
+ self.embeddings = BertEmbeddings(config, type_vocab_size=type_vocab_size)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config)
+
+ self.apply(self.init_weights)
+
+ def _resize_token_embeddings(self, new_num_tokens):
+ old_embeddings = self.embeddings.word_embeddings
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+ self.embeddings.word_embeddings = new_embeddings
+ return self.embeddings.word_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """ Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ See base class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None, prev_embedding=None, prev_encoded_layers=None):
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ if attention_mask.dim() == 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ elif attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask.unsqueeze(1)
+ else:
+ raise NotImplementedError
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+ else: # TODO: This
+ head_mask = [None] * self.config.num_hidden_layers
+
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+
+ sequence_output, all_hidden_states = self.encoder(embedding_output,
+ extended_attention_mask,
+ head_mask=head_mask, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers)
+
+ pooled_output = self.pooler(sequence_output)
+
+ return sequence_output, pooled_output, embedding_output, all_hidden_states
+
+
+@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
+ a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForPreTraining(BertPreTrainedModel):
+ r"""
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Labels for computing the masked language modeling loss.
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+ in ``[0, ..., config.vocab_size]``
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
+ Indices should be in ``[0, 1]``.
+ ``0`` indicates sequence B is a continuation of sequence A,
+ ``1`` indicates sequence B is a random sequence.
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForPreTraining.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids)
+ prediction_scores, seq_relationship_scores = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForPreTraining, self).__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertPreTrainingHeads(config)
+
+ self.apply(self.init_weights)
+ self.tie_weights()
+
+ def tie_weights(self):
+ """ Make sure we are sharing the input and output embeddings.
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+ """
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
+ self.bert.embeddings.word_embeddings)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+ next_sentence_label=None, position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
+
+ if masked_lm_labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+ outputs = (total_loss,) + outputs
+
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForMaskedLM(BertPreTrainedModel):
+ r"""
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Labels for computing the masked language modeling loss.
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+ in ``[0, ..., config.vocab_size]``
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Masked language modeling loss.
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids, masked_lm_labels=input_ids)
+ loss, prediction_scores = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForMaskedLM, self).__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.apply(self.init_weights)
+ self.tie_weights()
+
+ def tie_weights(self):
+ """ Make sure we are sharing the input and output embeddings.
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+ """
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
+ self.bert.embeddings.word_embeddings)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+ position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
+ if masked_lm_labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
+ outputs = (masked_lm_loss,) + outputs
+
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForNextSentencePrediction(BertPreTrainedModel):
+ r"""
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
+ Indices should be in ``[0, 1]``.
+ ``0`` indicates sequence B is a continuation of sequence A,
+ ``1`` indicates sequence B is a random sequence.
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Next sequence prediction (classification) loss.
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids)
+ seq_relationship_scores = outputs[0]
+
+ """
+ def __init__(self, config):
+ super(BertForNextSentencePrediction, self).__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyNSPHead(config)
+
+ self.apply(self.init_weights)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
+ position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+ pooled_output = outputs[1]
+
+ seq_relationship_score = self.cls(pooled_output)
+
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
+ if next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ outputs = (next_sentence_loss,) + outputs
+
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
+ the pooled output) e.g. for GLUE tasks. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForSequenceClassification(BertPreTrainedModel):
+ r"""
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for computing the sequence classification/regression loss.
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
+ If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+ If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Classification (or regression if config.num_labels==1) loss.
+ **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids, labels=labels)
+ loss, logits = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForSequenceClassification, self).__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
+
+ self.apply(self.init_weights)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+ position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+
+ if labels is not None:
+ if self.num_labels == 1:
+ # We are doing regression
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1), labels.view(-1))
+ else:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+ BERT_START_DOCSTRING)
+class BertForMultipleChoice(BertPreTrainedModel):
+ r"""
+ Inputs:
+ **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+ Indices of input sequence tokens in the vocabulary.
+ The second dimension of the input (`num_choices`) indicates the number of choices to score.
+ To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
+
+ (a) For sequence pairs:
+
+ ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
+
+ ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
+
+ (b) For single sequences:
+
+ ``tokens: [CLS] the dog is hairy . [SEP]``
+
+ ``token_type_ids: 0 0 0 0 0 0 0``
+
+ Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
+ See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
+ :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+ **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+ Segment token indices to indicate first and second portions of the inputs.
+ The second dimension of the input (`num_choices`) indicates the number of choices to score.
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
+ corresponds to a `sentence B` token
+ (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
+ **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+ Mask to avoid performing attention on padding token indices.
+ The second dimension of the input (`num_choices`) indicates the number of choices to score.
+ Mask values selected in ``[0, 1]``:
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+ Mask to nullify selected heads of the self-attention modules.
+ Mask values selected in ``[0, 1]``:
+ ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for computing the multiple choice classification loss.
+ Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+ of the input tensors. (see `input_ids` above)
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Classification loss.
+ **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
+ of the input tensors. (see `input_ids` above).
+ Classification scores (before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
+ choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+ input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
+ labels = torch.tensor(1).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids, labels=labels)
+ loss, classification_scores = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForMultipleChoice, self).__init__(config)
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ self.apply(self.init_weights)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+ position_ids=None, head_mask=None):
+ num_choices = input_ids.shape[1]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask, head_mask=head_mask)
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
+
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForTokenClassification(BertPreTrainedModel):
+ r"""
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Labels for computing the token classification loss.
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Classification loss.
+ **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+ Classification scores (before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForTokenClassification.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids, labels=labels)
+ loss, scores = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForTokenClassification, self).__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.apply(self.init_weights)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+ position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)[active_loss]
+ active_labels = labels.view(-1)[active_loss]
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+ the hidden-states output to compute `span start logits` and `span end logits`). """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+class BertForQuestionAnswering(BertPreTrainedModel):
+ r"""
+ **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`).
+ Position outside of the sequence are not taken into account for computing the loss.
+ **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`).
+ Position outside of the sequence are not taken into account for computing the loss.
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+ Span-start scores (before SoftMax).
+ **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+ Span-end scores (before SoftMax).
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+ start_positions = torch.tensor([1])
+ end_positions = torch.tensor([3])
+ outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+ loss, start_scores, end_scores = outputs[:2]
+
+ """
+ def __init__(self, config):
+ super(BertForQuestionAnswering, self).__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.apply(self.init_weights)
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
+ end_positions=None, position_ids=None, head_mask=None):
+ outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ outputs = (start_logits, end_logits,) + outputs[2:]
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions.clamp_(0, ignored_index)
+ end_positions.clamp_(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+ outputs = (total_loss,) + outputs
+
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
diff --git a/requirements.txt b/requirements.txt
index 4c538d6..85ed469 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
-pytorch-transformers==1.0.0
-torch==1.3.0a0+24ae9b5
+# pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
+torch==1.7.1+cu110
+transformers==3.0.2
wget==3.2
+jsonlines
+tqdm
\ No newline at end of file
diff --git a/run_eval.sh b/run_eval.sh
new file mode 100644
index 0000000..8c80f49
--- /dev/null
+++ b/run_eval.sh
@@ -0,0 +1,8 @@
+#!/bin/sh
+
+MODEL_PATH='./model_2.1/model_best.bin'
+DATASET_DIR='./data/mw_2.1/'
+
+python evaluation.py\
+ --model_ckpt_path $MODEL_PATH\
+ --data_root $DATASET_DIR\
diff --git a/run_train.sh b/run_train.sh
new file mode 100644
index 0000000..41a86ab
--- /dev/null
+++ b/run_train.sh
@@ -0,0 +1,8 @@
+DATASET_DIR='./data/mw_2.1/'
+SAVE_DIR='./model_2.1/'
+
+python train.py\
+ --data_root ${DATASET_DIR}\
+ --save_dir ${SAVE_DIR}\
+ --bert_ckpt_path 'bert-base-uncased-pytorch_model.bin'\
+ --op_code '4'
\ No newline at end of file
diff --git a/train.py b/train.py
index fc0c021..4adc3f1 100644
--- a/train.py
+++ b/train.py
@@ -3,11 +3,24 @@
Copyright (c) 2020-present NAVER Corp.
MIT license
"""
-
+import sys
from model import SomDST
-from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule, BertConfig
+
+# from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule, BertConfig
+from transformers import (
+ BertTokenizer,
+ AdamW,
+ get_linear_schedule_with_warmup,
+ BertConfig,
+)
from utils.data_utils import prepare_dataset, MultiWozDataset
-from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
+from utils.data_utils import (
+ make_slot_meta,
+ domain2id,
+ OP_SET,
+ make_turn_label,
+ postprocessing,
+)
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from utils.ckpt_utils import download_ckpt, convert_ckpt_compatible
from evaluation import model_evaluation
@@ -23,7 +36,7 @@
import time
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def masked_cross_entropy_for_value(logits, target, pad_idx=0):
@@ -64,38 +77,46 @@ def worker_init_fn(worker_id):
print(op2id)
tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)
- train_data_raw = prepare_dataset(data_path=args.train_data_path,
- tokenizer=tokenizer,
- slot_meta=slot_meta,
- n_history=args.n_history,
- max_seq_length=args.max_seq_length,
- op_code=args.op_code)
-
- train_data = MultiWozDataset(train_data_raw,
- tokenizer,
- slot_meta,
- args.max_seq_length,
- rng,
- ontology,
- args.word_dropout,
- args.shuffle_state,
- args.shuffle_p)
+ train_data_raw = prepare_dataset(
+ data_path=args.train_data_path,
+ tokenizer=tokenizer,
+ slot_meta=slot_meta,
+ n_history=args.n_history,
+ max_seq_length=args.max_seq_length,
+ op_code=args.op_code,
+ )
+
+ train_data = MultiWozDataset(
+ train_data_raw,
+ tokenizer,
+ slot_meta,
+ args.max_seq_length,
+ rng,
+ ontology,
+ args.word_dropout,
+ args.shuffle_state,
+ args.shuffle_p,
+ )
print("# train examples %d" % len(train_data_raw))
- dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
- tokenizer=tokenizer,
- slot_meta=slot_meta,
- n_history=args.n_history,
- max_seq_length=args.max_seq_length,
- op_code=args.op_code)
+ dev_data_raw = prepare_dataset(
+ data_path=args.dev_data_path,
+ tokenizer=tokenizer,
+ slot_meta=slot_meta,
+ n_history=args.n_history,
+ max_seq_length=args.max_seq_length,
+ op_code=args.op_code,
+ )
print("# dev examples %d" % len(dev_data_raw))
- test_data_raw = prepare_dataset(data_path=args.test_data_path,
- tokenizer=tokenizer,
- slot_meta=slot_meta,
- n_history=args.n_history,
- max_seq_length=args.max_seq_length,
- op_code=args.op_code)
+ test_data_raw = prepare_dataset(
+ data_path=args.test_data_path,
+ tokenizer=tokenizer,
+ slot_meta=slot_meta,
+ n_history=args.n_history,
+ max_seq_length=args.max_seq_length,
+ op_code=args.op_code,
+ )
print("# test examples %d" % len(test_data_raw))
model_config = BertConfig.from_json_file(args.bert_config_path)
@@ -103,80 +124,149 @@ def worker_init_fn(worker_id):
model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
model_config.hidden_dropout_prob = args.hidden_dropout_prob
- model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)
+ type_vocab_size = 2 # token typy id == 2개
+ model = SomDST(
+ model_config, len(op2id), len(domain2id), op2id["update"], args.exclude_domain
+ )
if not os.path.exists(args.bert_ckpt_path):
- args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')
+ args.bert_ckpt_path = download_ckpt(
+ args.bert_ckpt_path, args.bert_config_path, "assets"
+ )
+
+ state_dict = torch.load(args.bert_ckpt_path, map_location="cpu")
+ _k = "bert.embeddings.token_type_embeddings.weight"
+ print(
+ "config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
+ type_vocab_size, state_dict[_k].shape[0]
+ )
+ )
+
+ keys = list(state_dict.keys())
+ for key in keys:
+ if "LayerNorm" in key:
+ if "gamma" in key:
+ state_dict[key.replace("gamma", "weight")] = state_dict.pop(key)
+ else:
+ state_dict[key.replace("beta", "bias")] = state_dict.pop(key)
+
+ from collections import OrderedDict
- ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
- model.encoder.bert.load_state_dict(ckpt)
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[5:] if k[:5] == "bert." else k # remove `bert.`
+ new_state_dict[name] = v
+
+ # load params
+ model.encoder.bert.load_state_dict(new_state_dict, strict=False)
+
+ print("\n### Done Load BERT")
+ sys.stdout.flush()
# re-initialize added special tokens ([SLOT], [NULL], [EOS])
- model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
- model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
- model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
+ model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(
+ mean=0.0, std=0.02
+ )
+ model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(
+ mean=0.0, std=0.02
+ )
+ model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(
+ mean=0.0, std=0.02
+ )
+
model.to(device)
num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)
- no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
enc_param_optimizer = list(model.encoder.named_parameters())
enc_optimizer_grouped_parameters = [
- {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
- {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
- ]
+ {
+ "params": [
+ p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)
+ ],
+ "weight_decay": 0.01,
+ },
+ {
+ "params": [
+ p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
- enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
- t_total=num_train_steps)
+ enc_scheduler = get_linear_schedule_with_warmup(
+ enc_optimizer,
+ int(num_train_steps * args.enc_warmup),
+ num_training_steps=num_train_steps,
+ )
dec_param_optimizer = list(model.decoder.parameters())
dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
- dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
- t_total=num_train_steps)
+ dec_scheduler = get_linear_schedule_with_warmup(
+ dec_optimizer,
+ int(num_train_steps * args.dec_warmup),
+ num_training_steps=num_train_steps,
+ )
if n_gpu > 1:
model = torch.nn.DataParallel(model)
train_sampler = RandomSampler(train_data)
- train_dataloader = DataLoader(train_data,
- sampler=train_sampler,
- batch_size=args.batch_size,
- collate_fn=train_data.collate_fn,
- num_workers=args.num_workers,
- worker_init_fn=worker_init_fn)
+ train_dataloader = DataLoader(
+ train_data,
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ collate_fn=train_data.collate_fn,
+ num_workers=args.num_workers,
+ worker_init_fn=worker_init_fn,
+ )
loss_fnc = nn.CrossEntropyLoss()
- best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
+ best_score = {"epoch": 0, "joint_acc": 0, "op_acc": 0, "final_slot_f1": 0}
for epoch in range(args.n_epochs):
batch_loss = []
model.train()
for step, batch in enumerate(train_dataloader):
batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
- input_ids, input_mask, segment_ids, state_position_ids, op_ids,\
- domain_ids, gen_ids, max_value, max_update = batch
+ (
+ input_ids,
+ input_mask,
+ segment_ids,
+ state_position_ids,
+ op_ids,
+ domain_ids,
+ gen_ids,
+ max_value,
+ max_update,
+ ) = batch
if rng.random() < args.decoder_teacher_forcing: # teacher forcing
teacher = gen_ids
else:
teacher = None
- domain_scores, state_scores, gen_scores = model(input_ids=input_ids,
- token_type_ids=segment_ids,
- state_positions=state_position_ids,
- attention_mask=input_mask,
- max_value=max_value,
- op_ids=op_ids,
- max_update=max_update,
- teacher=teacher)
+ domain_scores, state_scores, gen_scores = model(
+ input_ids=input_ids,
+ token_type_ids=segment_ids,
+ state_positions=state_position_ids,
+ attention_mask=input_mask,
+ max_value=max_value,
+ op_ids=op_ids,
+ max_update=max_update,
+ teacher=teacher,
+ )
loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
- loss_g = masked_cross_entropy_for_value(gen_scores.contiguous(),
- gen_ids.contiguous(),
- tokenizer.vocab['[PAD]'])
+ loss_g = masked_cross_entropy_for_value(
+ gen_scores.contiguous(), gen_ids.contiguous(), tokenizer.vocab["[PAD]"]
+ )
loss = loss_s + loss_g
if args.exclude_domain is not True:
- loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
+ loss_d = loss_fnc(
+ domain_scores.view(-1, len(domain2id)), domain_ids.view(-1)
+ )
loss = loss + loss_d
batch_loss.append(loss.item())
@@ -189,66 +279,165 @@ def worker_init_fn(worker_id):
if step % 100 == 0:
if args.exclude_domain is not True:
- print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
- % (epoch+1, args.n_epochs, step,
- len(train_dataloader), np.mean(batch_loss),
- loss_s.item(), loss_g.item(), loss_d.item()))
+ print(
+ "[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f"
+ % (
+ epoch + 1,
+ args.n_epochs,
+ step,
+ len(train_dataloader),
+ np.mean(batch_loss),
+ loss_s.item(),
+ loss_g.item(),
+ loss_d.item(),
+ )
+ )
else:
- print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
- % (epoch+1, args.n_epochs, step,
- len(train_dataloader), np.mean(batch_loss),
- loss_s.item(), loss_g.item()))
+ print(
+ "[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f"
+ % (
+ epoch + 1,
+ args.n_epochs,
+ step,
+ len(train_dataloader),
+ np.mean(batch_loss),
+ loss_s.item(),
+ loss_g.item(),
+ )
+ )
batch_loss = []
- if (epoch+1) % args.eval_epoch == 0:
- eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code)
- if eval_res['joint_acc'] > best_score['joint_acc']:
+ if (epoch + 1) % args.eval_epoch == 0:
+ eval_res = model_evaluation(
+ model, dev_data_raw, tokenizer, slot_meta, epoch + 1, args.op_code
+ )
+ if eval_res["joint_acc"] > best_score["joint_acc"]:
best_score = eval_res
- model_to_save = model.module if hasattr(model, 'module') else model
- save_path = os.path.join(args.save_dir, 'model_best.bin')
+ model_to_save = model.module if hasattr(model, "module") else model
+ save_path = os.path.join(args.save_dir, "model_best.bin")
torch.save(model_to_save.state_dict(), save_path)
print("Best Score : ", best_score)
print("\n")
print("Test using best model...")
- best_epoch = best_score['epoch']
- ckpt_path = os.path.join(args.save_dir, 'model_best.bin')
- model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)
- ckpt = torch.load(ckpt_path, map_location='cpu')
+ best_epoch = best_score["epoch"]
+ ckpt_path = os.path.join(args.save_dir, "model_best.bin")
+ model = SomDST(
+ model_config, len(op2id), len(domain2id), op2id["update"], args.exclude_domain
+ )
+ ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt)
model.to(device)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=False, is_gt_p_state=False, is_gt_gen=False)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=False, is_gt_p_state=False, is_gt_gen=True)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=False, is_gt_p_state=True, is_gt_gen=False)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=False, is_gt_p_state=True, is_gt_gen=True)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=True, is_gt_p_state=False, is_gt_gen=False)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=True, is_gt_p_state=True, is_gt_gen=False)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=True, is_gt_p_state=False, is_gt_gen=True)
- model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
- is_gt_op=True, is_gt_p_state=True, is_gt_gen=True)
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=False,
+ is_gt_p_state=False,
+ is_gt_gen=False,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=False,
+ is_gt_p_state=False,
+ is_gt_gen=True,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=False,
+ is_gt_p_state=True,
+ is_gt_gen=False,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=False,
+ is_gt_p_state=True,
+ is_gt_gen=True,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=True,
+ is_gt_p_state=False,
+ is_gt_gen=False,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=True,
+ is_gt_p_state=True,
+ is_gt_gen=False,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=True,
+ is_gt_p_state=False,
+ is_gt_gen=True,
+ )
+ model_evaluation(
+ model,
+ test_data_raw,
+ tokenizer,
+ slot_meta,
+ best_epoch,
+ args.op_code,
+ is_gt_op=True,
+ is_gt_p_state=True,
+ is_gt_gen=True,
+ )
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
- parser.add_argument("--data_root", default='data/mwz2.1', type=str)
- parser.add_argument("--train_data", default='train_dials.json', type=str)
- parser.add_argument("--dev_data", default='dev_dials.json', type=str)
- parser.add_argument("--test_data", default='test_dials.json', type=str)
- parser.add_argument("--ontology_data", default='ontology.json', type=str)
- parser.add_argument("--vocab_path", default='assets/vocab.txt', type=str)
- parser.add_argument("--bert_config_path", default='assets/bert_config_base_uncased.json', type=str)
- parser.add_argument("--bert_ckpt_path", default='assets/bert-base-uncased-pytorch_model.bin', type=str)
- parser.add_argument("--save_dir", default='outputs', type=str)
+ parser.add_argument("--data_root", default="data/mwz2.1", type=str)
+ parser.add_argument("--train_data", default="train_dials.json", type=str)
+ parser.add_argument("--dev_data", default="dev_dials.json", type=str)
+ parser.add_argument("--test_data", default="test_dials.json", type=str)
+ parser.add_argument("--ontology_data", default="ontology.json", type=str)
+ parser.add_argument("--vocab_path", default="assets/vocab.txt", type=str)
+ parser.add_argument(
+ "--bert_config_path", default="./assets/bert_config_base_uncased.json", type=str
+ )
+ parser.add_argument(
+ "--bert_ckpt_path",
+ default="./assets/bert-base-uncased-pytorch_model.bin",
+ type=str,
+ )
+ parser.add_argument("--save_dir", default="outputs", type=str)
parser.add_argument("--random_seed", default=42, type=int)
parser.add_argument("--num_workers", default=4, type=int)
@@ -267,13 +456,13 @@ def worker_init_fn(worker_id):
parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float)
parser.add_argument("--decoder_teacher_forcing", default=0.5, type=float)
parser.add_argument("--word_dropout", default=0.1, type=float)
- parser.add_argument("--not_shuffle_state", default=False, action='store_true')
+ parser.add_argument("--not_shuffle_state", default=False, action="store_true")
parser.add_argument("--shuffle_p", default=0.5, type=float)
parser.add_argument("--n_history", default=1, type=int)
parser.add_argument("--max_seq_length", default=256, type=int)
parser.add_argument("--msg", default=None, type=str)
- parser.add_argument("--exclude_domain", default=False, action='store_true')
+ parser.add_argument("--exclude_domain", default=False, action="store_true")
args = parser.parse_args()
args.train_data_path = os.path.join(args.data_root, args.train_data)
@@ -281,6 +470,6 @@ def worker_init_fn(worker_id):
args.test_data_path = os.path.join(args.data_root, args.test_data)
args.ontology_data = os.path.join(args.data_root, args.ontology_data)
args.shuffle_state = False if args.not_shuffle_state else True
- print('pytorch version: ', torch.__version__)
+ print("pytorch version: ", torch.__version__)
print(args)
main(args)
diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py
index 6902d38..0a4d8b0 100644
--- a/utils/ckpt_utils.py
+++ b/utils/ckpt_utils.py
@@ -1,27 +1,29 @@
import wget
import os
import torch
-from pytorch_transformers import BertForPreTraining, BertConfig
+
+# from pytorch_transformers import BertForPreTraining, BertConfig
+from transformers import BertForPreTraining, BertConfig
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
}
-def download_ckpt(ckpt_path, config_path, target_path='assets'):
+def download_ckpt(ckpt_path, config_path, target_path="assets"):
key = None
- if 'base' in ckpt_path.lower():
- key = 'bert-base-uncased'
- if 'large' in ckpt_path.lower():
- key = 'bert-large-uncased'
+ if "base" in ckpt_path.lower():
+ key = "bert-base-uncased"
+ if "large" in ckpt_path.lower():
+ key = "bert-large-uncased"
assert key in BERT_PRETRAINED_MODEL_ARCHIVE_MAP
url_path = BERT_PRETRAINED_MODEL_ARCHIVE_MAP[key]
- print('start download %s from huggingface' % key)
+ print("start download %s from huggingface" % key)
wget.download(url_path, out=target_path)
- ckpt_path = os.path.join(target_path, key + '-pytorch_model.bin')
+ ckpt_path = os.path.join(target_path, key + "-pytorch_model.bin")
ckpt = convert_ckpt_compatible(ckpt_path, config_path)
torch.save(ckpt, ckpt_path)
@@ -29,14 +31,14 @@ def download_ckpt(ckpt_path, config_path, target_path='assets'):
def convert_ckpt_compatible(ckpt_path, config_path):
- ckpt = torch.load(ckpt_path, map_location='cpu')
+ ckpt = torch.load(ckpt_path, map_location="cpu")
keys = list(ckpt.keys())
for key in keys:
- if 'LayerNorm' in key:
- if 'gamma' in key:
- ckpt[key.replace('gamma', 'weight')] = ckpt.pop(key)
+ if "LayerNorm" in key:
+ if "gamma" in key:
+ ckpt[key.replace("gamma", "weight")] = ckpt.pop(key)
else:
- ckpt[key.replace('beta', 'bias')] = ckpt.pop(key)
+ ckpt[key.replace("beta", "bias")] = ckpt.pop(key)
model_config = BertConfig.from_json_file(config_path)
model = BertForPreTraining(model_config)