Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions discern/discern_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from discern import util
from discern.discern_base import DisCERN
from sklearn.preprocessing import MinMaxScaler
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

class DisCERNTabular(DisCERN):
"""
Expand Down Expand Up @@ -102,9 +104,19 @@ def find_cf(self, test_instance, test_label, desired_class='opposite', **kwargs)
if abs(val_x - val_nun) <= self.threshold:
None
else:
x_adapted[indices[now_index]] = nun_data[indices[now_index]]
changes +=1
amounts += abs(val_x - val_nun)
if self.feature_names[indices[now_index]]=="age" and nun_data[indices[now_index]] > x_adapted[indices[now_index]]:
None
elif self.feature_names[indices[now_index]] in [' Black', ' Other', ' White']:
None
elif self.feature_names[indices[now_index]]==' Self-emp-inc' and nun_data[indices[now_index]]==" Local-gov":
None
elif self.feature_names[indices[now_index]]==' Self-emp-inc' and nun_data[indices[now_index]]==" State-gov":
None
else:
x_adapted[indices[now_index]] = nun_data[indices[now_index]]
changes +=1
amounts += abs(val_x - val_nun)

new_class = self.model.predict([x_adapted])[0]
# print('new_class: '+str(new_class))
now_index += 1
Expand All @@ -120,5 +132,48 @@ def find_cf(self, test_instance, test_label, desired_class='opposite', **kwargs)
break
return x_adapted, sparsity, proximity

def show_cf(self, test_instance, test_label, cf, cf_label, **kwargs):
None
def show_cf(self, test_instance, cf, **kwargs):

PATH = "../discern/NLG_model/model.pt"
if torch.cuda.is_available():
dev = torch.device("cuda:0")
else:
dev = torch.device("cpu")

tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
model.eval()
model.to(dev)
model.load_state_dict(torch.load(PATH))

l_test = []
for i in test_instance:
l_test.append(str(i))
l_test.append(str(test_instance[i]))
test_instance = '|'.join(l_test)
input_ids = tokenizer.encode(test_instance, return_tensors="pt")
input_ids=input_ids.to(dev)
outputs = model.generate(input_ids,
do_sample=True,
max_length=50,
top_k=50,
top_p=0.95)
out_test = tokenizer.decode(outputs[0])

l_cf = []
for i in cf:
l_cf.append(str(i))
l_cf.append(str(cf[i]))
cf = '|'.join(l_cf)
input_ids = tokenizer.encode(cf, return_tensors="pt")
input_ids=input_ids.to(dev)
outputs = model.generate(input_ids,
do_sample=True,
max_length=50,
top_k=50,
top_p=0.95)
out_cf = tokenizer.decode(outputs[0])

print(f"Instance: {out_test}")
print(f"Counterfactual: {out_cf}")

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
numpy
pandas
scikit-learn
heapq
tqdm
lime
transformers==4.11.2
sentencepiece==0.1.96
Loading