-
Notifications
You must be signed in to change notification settings - Fork 26
Open
Description
您好，我在运行 BERT 部分的代码时，将 mode 参数设为 oracle 后，这个函数会报错：由于只有在 cal_lead 为 True 时才会给 selected_ids 赋值，而此时 cal_lead 为 False，selected_ids 未被初始化。请问 lead 和 oracle 分别对应您论文中的哪一部分？方便的话请告知，非常感谢！
` def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
# Set model in validating mode.
def _get_ngrams(n, text):
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def _block_tri(c, p):
    """Return True if sentence ``c`` shares a word trigram with any sentence in ``p``.

    Both ``c`` and every element of ``p`` are tokenized with ``str.split()``.
    The trigrams of ``c`` are computed once up front. ``p`` is then scanned in
    order, stopping at the first sentence with an overlapping trigram.
    """
    candidate_trigrams = _get_ngrams(3, c.split())
    return any(
        not candidate_trigrams.isdisjoint(_get_ngrams(3, chosen.split()))
        for chosen in p
    )
if (not cal_lead and not cal_oracle):
self.model.eval()
stats = Statistics()
can_path = '%s_step%d.candidate'%(self.args.result_path,step)
gold_path = '%s_step%d.gold' % (self.args.result_path, step)
with open(can_path, 'w') as save_pred:
with open(gold_path, 'w') as save_gold:
with torch.no_grad():
for batch in test_iter:
gold = []
pred = []
if (cal_lead):
selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
for i, idx in enumerate(selected_ids):
_pred = []
if(len(batch.src_str[i])==0):
continue
for j in selected_ids[i][:len(batch.src_str[i])]:
if(j>=len( batch.src_str[i])):
continue
candidate = batch.src_str[i][j].strip()
_pred.append(candidate)
if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
break`
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels