From 2aab074b6033213cf19f8e1b3be2ed5e70f911c5 Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 16:31:47 +0200 Subject: [PATCH 1/7] [#Add : transformers and sentencepiece] --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7364288..7c04b49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ numpy pandas scikit-learn -heapq tqdm lime +transformers==4.11.2 +sentencepiece==0.1.96 \ No newline at end of file From 9ca3e9dca674e173a7979cff610e799ec0239f5a Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 16:33:49 +0200 Subject: [PATCH 2/7] [#Add : Implement code to generate a textual representation with T5 conditional generation] --- discern/discern_tabular.py | 49 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/discern/discern_tabular.py b/discern/discern_tabular.py index 3d515ad..38b3477 100755 --- a/discern/discern_tabular.py +++ b/discern/discern_tabular.py @@ -4,6 +4,8 @@ from discern import util from discern.discern_base import DisCERN from sklearn.preprocessing import MinMaxScaler +import torch +from transformers import T5Tokenizer, T5ForConditionalGeneration class DisCERNTabular(DisCERN): """ @@ -120,5 +122,48 @@ def find_cf(self, test_instance, test_label, desired_class='opposite', **kwargs) break return x_adapted, sparsity, proximity - def show_cf(self, test_instance, test_label, cf, cf_label, **kwargs): - None \ No newline at end of file + def show_cf(self, test_instance, cf, **kwargs): + + PATH = "../discern/NLG_model/model.pt" + if torch.cuda.is_available(): + dev = torch.device("cuda:0") + else: + dev = torch.device("cpu") + + tokenizer = T5Tokenizer.from_pretrained('t5-base') + model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True) + model.eval() + model.to(dev) + model.load_state_dict(torch.load(PATH)) + + l_test = [] + for i in test_instance: + l_test.append(str(i)) + l_test.append(str(test_instance[i])) + test_instance = '|'.join(l_test) + input_ids = tokenizer.encode(test_instance, return_tensors="pt") + input_ids=input_ids.to(dev) + outputs = model.generate(input_ids, + do_sample=True, + max_length=50, + top_k=50, + top_p=0.95) + out_test = tokenizer.decode(outputs[0]) + + l_cf = [] + for i in cf: + l_cf.append(str(i)) + l_cf.append(str(cf[i])) + cf = '|'.join(l_cf) + input_ids = tokenizer.encode(cf, return_tensors="pt") + input_ids=input_ids.to(dev) + outputs = model.generate(input_ids, + do_sample=True, + max_length=50, + top_k=50, + top_p=0.95) + out_cf = tokenizer.decode(outputs[0]) + + print(f"Instance: {out_test}") + print(f"Counterfactual: {out_cf}") + \ No newline at end of file From a3b28e4684720575a6055f4902dd8b08087ad81a Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 16:34:20 +0200 Subject: [PATCH 3/7] [#Add : Implement tests to demonstrate show_cf functionality] --- tests/test_show_cf.py | 50 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/test_show_cf.py diff --git a/tests/test_show_cf.py b/tests/test_show_cf.py new file mode 100644 index 0000000..4c9af6a --- /dev/null +++ b/tests/test_show_cf.py @@ -0,0 +1,50 @@ +from discern.discern_tabular import DisCERNTabular +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import SVC +from sklearn.metrics import accuracy_score + +data_df = pd.read_csv('adult_income.csv') +df = data_df.copy() +print("Reading data complete!") +df = df.drop("Unnamed: 0", axis=1) +x = df.loc[:, df.columns != 'salary'].values +y = df['salary'].values +x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) +print("Train test split complete!") + + +scalar = MinMaxScaler() +x_train= scalar.fit_transform(x_train) +x_test = scalar.transform(x_test) +print("Data transform complete!") + +svm = SVC(probability=True) +svm.fit(x_train, y_train) +print("Training classifier complete!") +print(accuracy_score(y_test, svm.predict(x_test))) + +x_test = x_test[0] +y_test = svm.predict([x_test]) + +discern = DisCERNTabular(svm, 'LIME', 'Q') +discern.init_data(x_train, y_train, [c for c in df.columns if c!='salary'], ['<=50K', '>50K'], cat_feature_indices=[]) + + +cf, _, _ = discern.find_cf(x_test, y_test) + +x = scalar.inverse_transform([x_test]) +c = scalar.inverse_transform([cf]) + +cls = list(df.columns) +x=pd.DataFrame(x,columns=[c for c in df.columns if c!='salary']) +c=pd.DataFrame(c,columns=[c for c in df.columns if c!='salary']) +x['salary'] = y_test +c['salary'] = 1 # Temporary + +x=x.to_dict(orient="index")[0] +c=c.to_dict(orient="index")[0] + +discern.show_cf(x, c) \ No newline at end of file From 1bde75c7949d5a3136de2b46bbc790aa675e9ba2 Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 17:55:26 +0200 Subject: [PATCH 4/7] [#Add : conditions for cf discovery] --- discern/discern_tabular.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/discern/discern_tabular.py b/discern/discern_tabular.py index 38b3477..1847e81 100755 --- a/discern/discern_tabular.py +++ b/discern/discern_tabular.py @@ -104,9 +104,19 @@ def find_cf(self, test_instance, test_label, desired_class='opposite', **kwargs) if abs(val_x - val_nun) <= self.threshold: None else: - x_adapted[indices[now_index]] = nun_data[indices[now_index]] - changes +=1 - amounts += abs(val_x - val_nun) + if self.feature_names[indices[now_index]]=="age" and nun_data[indices[now_index]] > x_adapted[indices[now_index]]: + None + elif self.feature_names[indices[now_index]] in [' Black', ' Other', ' White']: + None + elif self.feature_names[indices[now_index]]==' Self-emp-inc' and nun_data[indices[now_index]]==" Local-gov": + None + elif self.feature_names[indices[now_index]]==' Self-emp-inc' and nun_data[indices[now_index]]==" State-gov": + None + else: + x_adapted[indices[now_index]] = nun_data[indices[now_index]] + changes +=1 + amounts += abs(val_x - val_nun) + new_class = self.model.predict([x_adapted])[0] # print('new_class: '+str(new_class)) now_index += 1 From 8f15d9692042e8f84c4201967e70da3fa67b2e82 Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 17:56:49 +0200 Subject: [PATCH 5/7] [#Add : implementation of test file in order to test applied conditions] --- tests/CF_conditions_test.py | 78 +++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/CF_conditions_test.py diff --git a/tests/CF_conditions_test.py b/tests/CF_conditions_test.py new file mode 100644 index 0000000..8c77397 --- /dev/null +++ b/tests/CF_conditions_test.py @@ -0,0 +1,78 @@ +from discern.discern_tabular import DisCERNTabular +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import SVC +from sklearn.metrics import accuracy_score + +data_df = pd.read_csv('adult_income.csv') +df = data_df.copy() +print("Reading data complete!") +df = df.drop("Unnamed: 0", axis=1) +x = df.loc[:, df.columns != 'salary'].values +y = df['salary'].values +x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) +print("Train test split complete!") + + +scalar = MinMaxScaler() +x_train= scalar.fit_transform(x_train) +x_test = scalar.transform(x_test) +print("Data transform complete!") + +svm = SVC(probability=True) +svm.fit(x_train, y_train) +print("Training classifier complete!") +print(accuracy_score(y_test, svm.predict(x_test))) + +x_test = x_test[0] +y_test = svm.predict([x_test]) +""" + changed race to black in order to check is + wether it meet the changed condition or not +""" +x_test[-8] = 1.0 + +discern = DisCERNTabular(svm, 'LIME', 'Q') +discern.init_data(x_train, y_train, [c for c in df.columns if c!='salary'], ['<=50K', '>50K'], cat_feature_indices=[]) + + +cf, _, _ = discern.find_cf(x_test, y_test) + +# data_df = pd.read_csv('adult_income.csv') +# df = data_df.copy() +# print("Reading data complete!") + +# x = df.loc[:, df.columns != 'salary'].values +# y = df['salary'].values +# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) +# print("Train test split complete!") + +# scalar = MinMaxScaler() +# x_train= scalar.fit_transform(x_train) +# x_test = scalar.transform(x_test) +# print("Data transform complete!") + +# svm = SVC(probability=True) +# svm.fit(x_train, y_train) +# print("Training classifier complete!") +# print(accuracy_score(y_test, svm.predict(x_test))) + +# x_test = x_test[:2] +# y_test = svm.predict(x_test[:2]) + +# sparsity = [] +# proximity = [] +# discern = DisCERNTabular(svm, 'LIME', 'Q') +# discern.init_data(x_train, y_train, [c for c in df.columns if c!='salary'], ['<=50K', '>50K'], cat_feature_indices=[]) + +# for idx in range(len(x_test)): +# cf, s, p = discern.find_cf(x_test[idx], y_test[idx]) +# sparsity.append(s) +# proximity.append(p) + +# _sparsity = sum(sparsity)/len(sparsity) +# _proximity = sum(proximity)/(len(proximity)*_sparsity) +# print(_sparsity) +# print(_proximity) \ No newline at end of file From 418f518dfe751559933e3d8c208f1a15595ed56c Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Wed, 6 Oct 2021 17:58:16 +0200 Subject: [PATCH 6/7] [#Add : implementation of test file in order to test applied conditions] --- tests/CF_conditions_test.py | 39 +------------------------------------ 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/tests/CF_conditions_test.py b/tests/CF_conditions_test.py index 8c77397..59777d7 100644 --- a/tests/CF_conditions_test.py +++ b/tests/CF_conditions_test.py @@ -38,41 +38,4 @@ discern.init_data(x_train, y_train, [c for c in df.columns if c!='salary'], ['<=50K', '>50K'], cat_feature_indices=[]) -cf, _, _ = discern.find_cf(x_test, y_test) - -# data_df = pd.read_csv('adult_income.csv') -# df = data_df.copy() -# print("Reading data complete!") - -# x = df.loc[:, df.columns != 'salary'].values -# y = df['salary'].values -# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) -# print("Train test split complete!") - -# scalar = MinMaxScaler() -# x_train= scalar.fit_transform(x_train) -# x_test = scalar.transform(x_test) -# print("Data transform complete!") - -# svm = SVC(probability=True) -# svm.fit(x_train, y_train) -# print("Training classifier complete!") -# print(accuracy_score(y_test, svm.predict(x_test))) - -# x_test = x_test[:2] -# y_test = svm.predict(x_test[:2]) - -# sparsity = [] -# proximity = [] -# discern = DisCERNTabular(svm, 'LIME', 'Q') -# discern.init_data(x_train, y_train, [c for c in df.columns if c!='salary'], ['<=50K', '>50K'], cat_feature_indices=[]) - -# for idx in range(len(x_test)): -# cf, s, p = discern.find_cf(x_test[idx], y_test[idx]) -# sparsity.append(s) -# proximity.append(p) - -# _sparsity = sum(sparsity)/len(sparsity) -# _proximity = sum(proximity)/(len(proximity)*_sparsity) -# print(_sparsity) -# print(_proximity) \ No newline at end of file +cf, _, _ = discern.find_cf(x_test, y_test) \ No newline at end of file From dfd68d646fd614b7109621b9114f6fc8f6896e5b Mon Sep 17 00:00:00 2001 From: Pedram Salimi Date: Thu, 7 Oct 2021 12:50:22 +0200 Subject: [PATCH 7/7] [#Add : Jupyter notebook in order to test assignment 3] --- tests/Assignment_3.ipynb | 482 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 482 insertions(+) create mode 100644 tests/Assignment_3.ipynb diff --git a/tests/Assignment_3.ipynb b/tests/Assignment_3.ipynb new file mode 100644 index 0000000..83acb9f --- /dev/null +++ b/tests/Assignment_3.ipynb @@ -0,0 +1,482 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Assignment #3.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HqMISQSEIw-a", + "outputId": "364f7e14-b635-45df-e151-7061f828be2f" + }, + "source": [ + "# !git clone https://github.com/RGU-Computing/DisCERN-XAI\n", + "!git clone https://github.com/pedramsalimi/DisCERN-XAI\n", + "!pip install lime\n", + "!pip install shap" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'DisCERN-XAI'...\n", + "remote: Enumerating objects: 108, done.\u001b[K\n", + "remote: Counting objects: 0% (1/108)\u001b[K\rremote: Counting objects: 1% (2/108)\u001b[K\rremote: Counting objects: 2% (3/108)\u001b[K\rremote: Counting objects: 3% (4/108)\u001b[K\rremote: Counting objects: 4% (5/108)\u001b[K\rremote: Counting objects: 5% (6/108)\u001b[K\rremote: Counting objects: 6% (7/108)\u001b[K\rremote: Counting objects: 7% (8/108)\u001b[K\rremote: Counting objects: 8% (9/108)\u001b[K\rremote: Counting objects: 9% (10/108)\u001b[K\rremote: Counting objects: 10% (11/108)\u001b[K\rremote: Counting objects: 11% (12/108)\u001b[K\rremote: Counting objects: 12% (13/108)\u001b[K\rremote: Counting objects: 13% (15/108)\u001b[K\rremote: Counting objects: 14% (16/108)\u001b[K\rremote: Counting objects: 15% (17/108)\u001b[K\rremote: Counting objects: 16% (18/108)\u001b[K\rremote: Counting objects: 17% (19/108)\u001b[K\rremote: Counting objects: 18% (20/108)\u001b[K\rremote: Counting objects: 19% (21/108)\u001b[K\rremote: Counting objects: 20% (22/108)\u001b[K\rremote: Counting objects: 21% (23/108)\u001b[K\rremote: Counting objects: 22% (24/108)\u001b[K\rremote: Counting objects: 23% (25/108)\u001b[K\rremote: Counting objects: 24% (26/108)\u001b[K\rremote: Counting objects: 25% (27/108)\u001b[K\rremote: Counting objects: 26% (29/108)\u001b[K\rremote: Counting objects: 27% (30/108)\u001b[K\rremote: Counting objects: 28% (31/108)\u001b[K\rremote: Counting objects: 29% (32/108)\u001b[K\rremote: Counting objects: 30% (33/108)\u001b[K\rremote: Counting objects: 31% (34/108)\u001b[K\rremote: Counting objects: 32% (35/108)\u001b[K\rremote: Counting objects: 33% (36/108)\u001b[K\rremote: Counting objects: 34% (37/108)\u001b[K\rremote: Counting objects: 35% (38/108)\u001b[K\rremote: Counting objects: 36% (39/108)\u001b[K\rremote: Counting objects: 37% (40/108)\u001b[K\rremote: Counting objects: 38% (42/108)\u001b[K\rremote: Counting objects: 39% (43/108)\u001b[K\rremote: Counting objects: 40% (44/108)\u001b[K\rremote: Counting objects: 41% (45/108)\u001b[K\rremote: Counting objects: 42% (46/108)\u001b[K\rremote: Counting objects: 43% (47/108)\u001b[K\rremote: Counting objects: 44% (48/108)\u001b[K\rremote: Counting objects: 45% (49/108)\u001b[K\rremote: Counting objects: 46% (50/108)\u001b[K\rremote: Counting objects: 47% (51/108)\u001b[K\rremote: Counting objects: 48% (52/108)\u001b[K\rremote: Counting objects: 49% (53/108)\u001b[K\rremote: Counting objects: 50% (54/108)\u001b[K\rremote: Counting objects: 51% (56/108)\u001b[K\rremote: Counting objects: 52% (57/108)\u001b[K\rremote: Counting objects: 53% (58/108)\u001b[K\rremote: Counting objects: 54% (59/108)\u001b[K\rremote: Counting objects: 55% (60/108)\u001b[K\rremote: Counting objects: 56% (61/108)\u001b[K\rremote: Counting objects: 57% (62/108)\u001b[K\rremote: Counting objects: 58% (63/108)\u001b[K\rremote: Counting objects: 59% (64/108)\u001b[K\rremote: Counting objects: 60% (65/108)\u001b[K\rremote: Counting objects: 61% (66/108)\u001b[K\rremote: Counting objects: 62% (67/108)\u001b[K\rremote: Counting objects: 63% (69/108)\u001b[K\rremote: Counting objects: 64% (70/108)\u001b[K\rremote: Counting objects: 65% (71/108)\u001b[K\rremote: Counting objects: 66% (72/108)\u001b[K\rremote: Counting objects: 67% (73/108)\u001b[K\rremote: Counting objects: 68% (74/108)\u001b[K\rremote: Counting objects: 69% (75/108)\u001b[K\rremote: Counting objects: 70% (76/108)\u001b[K\rremote: Counting objects: 71% (77/108)\u001b[K\rremote: Counting objects: 72% (78/108)\u001b[K\rremote: Counting objects: 73% (79/108)\u001b[K\rremote: Counting objects: 74% (80/108)\u001b[K\rremote: Counting objects: 75% (81/108)\u001b[K\rremote: Counting objects: 76% (83/108)\u001b[K\rremote: Counting objects: 77% (84/108)\u001b[K\rremote: Counting objects: 78% (85/108)\u001b[K\rremote: Counting objects: 79% (86/108)\u001b[K\rremote: Counting objects: 80% (87/108)\u001b[K\rremote: Counting objects: 81% (88/108)\u001b[K\rremote: Counting objects: 82% (89/108)\u001b[K\rremote: Counting objects: 83% (90/108)\u001b[K\rremote: Counting objects: 84% (91/108)\u001b[K\rremote: Counting objects: 85% (92/108)\u001b[K\rremote: Counting objects: 86% (93/108)\u001b[K\rremote: Counting objects: 87% (94/108)\u001b[K\rremote: Counting objects: 88% (96/108)\u001b[K\rremote: Counting objects: 89% (97/108)\u001b[K\rremote: Counting objects: 90% (98/108)\u001b[K\rremote: Counting objects: 91% (99/108)\u001b[K\rremote: Counting objects: 92% (100/108)\u001b[K\rremote: Counting objects: 93% (101/108)\u001b[K\rremote: Counting objects: 94% (102/108)\u001b[K\rremote: Counting objects: 95% (103/108)\u001b[K\rremote: Counting objects: 96% (104/108)\u001b[K\rremote: Counting objects: 97% (105/108)\u001b[K\rremote: Counting objects: 98% (106/108)\u001b[K\rremote: Counting objects: 99% (107/108)\u001b[K\rremote: Counting objects: 100% (108/108)\u001b[K\rremote: Counting objects: 100% (108/108), done.\u001b[K\n", + "remote: Compressing objects: 100% (77/77), done.\u001b[K\n", + "remote: Total 108 (delta 50), reused 62 (delta 22), pack-reused 0\u001b[K\n", + "Receiving objects: 100% (108/108), 1.54 MiB | 12.66 MiB/s, done.\n", + "Resolving deltas: 100% (50/50), done.\n", + "Collecting lime\n", + " Downloading lime-0.2.0.1.tar.gz (275 kB)\n", + "\u001b[K |████████████████████████████████| 275 kB 28.6 MB/s \n", + "\u001b[?25hRequirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from lime) (3.2.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from lime) (1.19.5)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from lime) (1.4.1)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from lime) (4.62.3)\n", + "Requirement already satisfied: scikit-learn>=0.18 in /usr/local/lib/python3.7/dist-packages (from lime) (0.22.2.post1)\n", + "Requirement already satisfied: scikit-image>=0.12 in /usr/local/lib/python3.7/dist-packages (from lime) (0.16.2)\n", + "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime) (2.6.3)\n", + "Requirement already satisfied: pillow>=4.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime) (7.1.2)\n", + "Requirement already satisfied: imageio>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime) (2.4.1)\n", + "Requirement already satisfied: PyWavelets>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime) (1.1.1)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime) (2.8.2)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime) (2.4.7)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime) (1.3.2)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime) (0.10.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->lime) (1.15.0)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.18->lime) (1.0.1)\n", + "Building wheels for collected packages: lime\n", + " Building wheel for lime (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283857 sha256=59ee5649c2c1d0a7f49fb949bb646a526d441be0b89f014349f2ba55097787bc\n", + " Stored in directory: /root/.cache/pip/wheels/ca/cb/e5/ac701e12d365a08917bf4c6171c0961bc880a8181359c66aa7\n", + "Successfully built lime\n", + "Installing collected packages: lime\n", + "Successfully installed lime-0.2.0.1\n", + "Collecting shap\n", + " Downloading shap-0.39.0.tar.gz (356 kB)\n", + "\u001b[K |████████████████████████████████| 356 kB 33.3 MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from shap) (1.19.5)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from shap) (1.4.1)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from shap) (0.22.2.post1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from shap) (1.1.5)\n", + "Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap) (4.62.3)\n", + "Collecting slicer==0.0.7\n", + " Downloading slicer-0.0.7-py3-none-any.whl (14 kB)\n", + "Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap) (0.51.2)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap) (1.3.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba->shap) (57.4.0)\n", + "Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap) (0.34.0)\n", + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2018.9)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2.8.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->shap) (1.15.0)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (1.0.1)\n", + "Building wheels for collected packages: shap\n", + " Building wheel for shap (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for shap: filename=shap-0.39.0-cp37-cp37m-linux_x86_64.whl size=491649 sha256=36f26b15cee212cd5f5a09c09c46c660b24610a1644dde70068218dcb620cfc3\n", + " Stored in directory: /root/.cache/pip/wheels/ca/25/8f/6ae5df62c32651cd719e972e738a8aaa4a87414c4d2b14c9c0\n", + "Successfully built shap\n", + "Installing collected packages: slicer, shap\n", + "Successfully installed shap-0.39.0 slicer-0.0.7\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Nn6RVUk8EX7q", + "outputId": "b26ebd5e-a717-422a-c0bd-dcf6f7731d7f" + }, + "source": [ + "cd DisCERN-XAI/" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/DisCERN-XAI\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2ObSvGy5-rXV", + "outputId": "9a080f51-4a4c-418e-becc-2510ab3b20c8" + }, + "source": [ + "!pip install -r requirements.txt\n", + "!pip install -e ." + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (1.19.5)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 2)) (1.1.5)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (0.22.2.post1)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 4)) (4.62.3)\n", + "Requirement already satisfied: lime in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (0.2.0.1)\n", + "Collecting transformers==4.11.2\n", + " Downloading transformers-4.11.2-py3-none-any.whl (2.9 MB)\n", + "\u001b[K |████████████████████████████████| 2.9 MB 23.2 MB/s \n", + "\u001b[?25hCollecting sentencepiece==0.1.96\n", + " Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n", + "\u001b[K |████████████████████████████████| 1.2 MB 54.1 MB/s \n", + "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.11.2->-r requirements.txt (line 6)) (2019.12.20)\n", + "Collecting huggingface-hub>=0.0.17\n", + " Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)\n", + "\u001b[K |████████████████████████████████| 56 kB 5.0 MB/s \n", + "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers==4.11.2->-r requirements.txt (line 6)) (21.0)\n", + "Collecting sacremoses\n", + " Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)\n", + "\u001b[K |████████████████████████████████| 895 kB 39.7 MB/s \n", + "\u001b[?25hCollecting tokenizers<0.11,>=0.10.1\n", + " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n", + "\u001b[K |████████████████████████████████| 3.3 MB 29.9 MB/s \n", + "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.11.2->-r requirements.txt (line 6)) (2.23.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.11.2->-r requirements.txt (line 6)) (3.2.0)\n", + "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.11.2->-r requirements.txt (line 6)) (4.8.1)\n", + "Collecting pyyaml>=5.1\n", + " Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)\n", + "\u001b[K |████████████████████████████████| 636 kB 45.6 MB/s \n", + "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.17->transformers==4.11.2->-r requirements.txt (line 6)) (3.7.4.3)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers==4.11.2->-r requirements.txt (line 6)) (2.4.7)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->-r requirements.txt (line 2)) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->-r requirements.txt (line 2)) (2018.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->-r requirements.txt (line 2)) (1.15.0)\n", + "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->-r requirements.txt (line 3)) (1.4.1)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->-r requirements.txt (line 3)) (1.0.1)\n", + "Requirement already satisfied: scikit-image>=0.12 in /usr/local/lib/python3.7/dist-packages (from lime->-r requirements.txt (line 5)) (0.16.2)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from lime->-r requirements.txt (line 5)) (3.2.2)\n", + "Requirement already satisfied: imageio>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime->-r requirements.txt (line 5)) (2.4.1)\n", + "Requirement already satisfied: pillow>=4.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime->-r requirements.txt (line 5)) (7.1.2)\n", + "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime->-r requirements.txt (line 5)) (2.6.3)\n", + "Requirement already satisfied: PyWavelets>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.12->lime->-r requirements.txt (line 5)) (1.1.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime->-r requirements.txt (line 5)) (1.3.2)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->lime->-r requirements.txt (line 5)) (0.10.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.11.2->-r requirements.txt (line 6)) (3.6.0)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.11.2->-r requirements.txt (line 6)) (3.0.4)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.11.2->-r requirements.txt (line 6)) (2.10)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.11.2->-r requirements.txt (line 6)) (1.24.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.11.2->-r requirements.txt (line 6)) (2021.5.30)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.11.2->-r requirements.txt (line 6)) (7.1.2)\n", + "Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers, sentencepiece\n", + " Attempting uninstall: pyyaml\n", + " Found existing installation: PyYAML 3.13\n", + " Uninstalling PyYAML-3.13:\n", + " Successfully uninstalled PyYAML-3.13\n", + "Successfully installed huggingface-hub-0.0.19 pyyaml-5.4.1 sacremoses-0.0.46 sentencepiece-0.1.96 tokenizers-0.10.3 transformers-4.11.2\n", + "Obtaining file:///content/DisCERN-XAI\n", + "Installing collected packages: discern-xai\n", + " Running setup.py develop for discern-xai\n", + "Successfully installed discern-xai-0.0.23\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lKMiRtTN_NbA" + }, + "source": [ + "!mkdir discern/NLG_model" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "noESjW6J_l2d", + "outputId": "875839f4-dbc8-4508-845c-115fc7579083" + }, + "source": [ + "!wget https://www.dropbox.com/s/15uwbwv00oubrf7/model.pt?dl=0" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2021-10-07 10:31:07-- https://www.dropbox.com/s/15uwbwv00oubrf7/model.pt?dl=0\n", + "Resolving www.dropbox.com (www.dropbox.com)... 162.125.4.18, 2620:100:601c:18::a27d:612\n", + "Connecting to www.dropbox.com (www.dropbox.com)|162.125.4.18|:443... connected.\n", + "HTTP request sent, awaiting response... 301 Moved Permanently\n", + "Location: /s/raw/15uwbwv00oubrf7/model.pt [following]\n", + "--2021-10-07 10:31:07-- https://www.dropbox.com/s/raw/15uwbwv00oubrf7/model.pt\n", + "Reusing existing connection to www.dropbox.com:443.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com/cd/0/inline/BXnlL1-3JKI3rTyDPryw2WmcOvEvG4PUAEwq2j5BOVA4u0SEby9QECPsTVQWdUERJywNjdMEk3hd0zUGy5YOAvHS0TAFOBCld4dfJvH55h-pm-zaBKAljLp1H40DE81ANmCLEFusW_ktQyg9-fZCk1xV/file# [following]\n", + "--2021-10-07 10:31:07-- https://uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com/cd/0/inline/BXnlL1-3JKI3rTyDPryw2WmcOvEvG4PUAEwq2j5BOVA4u0SEby9QECPsTVQWdUERJywNjdMEk3hd0zUGy5YOAvHS0TAFOBCld4dfJvH55h-pm-zaBKAljLp1H40DE81ANmCLEFusW_ktQyg9-fZCk1xV/file\n", + "Resolving uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com (uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com)... 162.125.6.15, 2620:100:601c:15::a27d:60f\n", + "Connecting to uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com (uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com)|162.125.6.15|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: /cd/0/inline2/BXnx1X9ESJXFZ8tro6esj-yrFUzME17T361WY99D1yOpFYn2x6UwMcPPDBnJp0MuXDpagP-9TzLX6BGtUSxOt0efxbjN3e7_9esvaL-aDI3BX4LoyEpHtzzuSsdmManir4nWXJOOcgBCIaRqT6zbPhqD25Jc_yd-bgHmUwzWp16VIieSQi2OryHCt_aNOtaXkLJBO35OttYNhTpr5zy5ku-soOIHD5mXzbgaEUULQgRyagVl9cYvOkYeZMCIyaLOg7RGL01vp7vslB0brqjsqj9hiLB_xWtXQhuYmWuAk8wbf0DkJ8cz0RjX8wIEA5EYyttbb871DhfOhNSiDDgX4oJfS-nHkUJ619O6tDyfavSOvXpH0CKg-_M-kt5CXNRrhsw/file [following]\n", + "--2021-10-07 10:31:08-- https://uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com/cd/0/inline2/BXnx1X9ESJXFZ8tro6esj-yrFUzME17T361WY99D1yOpFYn2x6UwMcPPDBnJp0MuXDpagP-9TzLX6BGtUSxOt0efxbjN3e7_9esvaL-aDI3BX4LoyEpHtzzuSsdmManir4nWXJOOcgBCIaRqT6zbPhqD25Jc_yd-bgHmUwzWp16VIieSQi2OryHCt_aNOtaXkLJBO35OttYNhTpr5zy5ku-soOIHD5mXzbgaEUULQgRyagVl9cYvOkYeZMCIyaLOg7RGL01vp7vslB0brqjsqj9hiLB_xWtXQhuYmWuAk8wbf0DkJ8cz0RjX8wIEA5EYyttbb871DhfOhNSiDDgX4oJfS-nHkUJ619O6tDyfavSOvXpH0CKg-_M-kt5CXNRrhsw/file\n", + "Reusing existing connection to uc457f63436c52d9e371046e4c9c.dl.dropboxusercontent.com:443.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 891730879 (850M) [application/octet-stream]\n", + "Saving to: ‘model.pt?dl=0’\n", + "\n", + "model.pt?dl=0 100%[===================>] 850.42M 114MB/s in 8.3s \n", + "\n", + "2021-10-07 10:31:16 (102 MB/s) - ‘model.pt?dl=0’ saved [891730879/891730879]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "09ou9DNnARtX" + }, + "source": [ + "mv /content/DisCERN-XAI/model.pt?dl=0 /content/DisCERN-XAI/discern/NLG_model/model.pt" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "9jRAVLtgkN6o" + }, + "source": [ + "from discern.discern_tabular import DisCERNTabular\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import MinMaxScaler\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.svm import SVC\n", + "svm = SVC(probability=True)\n", + "def test_cancer_risk(q,label,d_label):\n", + " data_df = pd.read_csv('/content/DisCERN-XAI/tests/lung_cancer.csv')\n", + " data_df = data_df.replace({'Level': {'Low': 0, 'Medium': 1, 'High': 2}})\n", + " data_df = data_df.replace({'Gender': {2: 0}})\n", + " data_df = data_df.replace({'Alcohol use': {2: 0}})\n", + " data_df = data_df.replace({'Dust Allergy': {2: 0}})\n", + " data_df = data_df.replace({'Smoking': {2: 0}})\n", + " data_df = data_df.replace({'Chest Pain': {2: 0}})\n", + " data_df = data_df.replace({'Fatigue': {2: 0}})\n", + " data_df = data_df.replace({'Shortness of Breath': {2: 0}})\n", + " data_df = data_df.replace({'Wheezing': {2: 0}})\n", + " data_df = data_df.replace({'Swallowing Difficulty': {2: 0}})\n", + " data_df = data_df.replace({'Cough': {2: 0}})\n", + " data_df = data_df.replace({'chronic Lung Disease': {2: 0}})\n", + " # print(\"Reading data complete!\",data_df.columns)\n", + " discern = DisCERNTabular(svm, 'LIME', 'Q')\n", + "\n", + " df = data_df.copy()\n", + " x = df.loc[:, df.columns != 'Level'].values\n", + " y = df['Level'].values\n", + " x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=1)\n", + " scaler = MinMaxScaler()\n", + " x_train= scaler.fit_transform(x_train)\n", + " x_test = scaler.transform(x_test)\n", + " svm.fit(x_train, y_train)\n", + " x_test = scaler.transform([q])\n", + " y_test = svm.predict(x_test)\n", + "\n", + " cat_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n", + "\n", + " sparsity = []\n", + " proximity = []\n", + " discern = DisCERNTabular(svm, 'LIME', 'Q')\n", + " discern.init_data(x_train, y_train, [c for c in df.columns if c!='Level'], ['Low', 'Medium', 'High'], cat_feature_indices=cat_indices)\n", + " cf, s, p = discern.find_cf(x_test[0], {'Low': 0, 'Medium': 1, 'High': 2}[label], desired_class=d_label)\n", + " sparsity.append(s)\n", + " proximity.append(p)\n", + " _sparsity = sum(sparsity)/len(sparsity)\n", + " _proximity = sum(proximity)/(len(proximity)*_sparsity)\n", + " \n", + " x = scaler.inverse_transform(x_test)\n", + " c = scaler.inverse_transform([cf])\n", + " for i in range(len(data_df.columns)-1):\n", + " if c[0][i] == x[0][i]:\n", + " pass\n", + " else:\n", + " print(\"{} should be changed from {} to {}\".format(data_df.columns[i],x[0][i],c[0][i]))\n", + " cls = list(df.columns)\n", + " x=pd.DataFrame(x,columns=[c for c in df.columns if c!='Level'])\n", + " c=pd.DataFrame(c,columns=[c for c in df.columns if c!='Level'])\n", + " x['Level'] = y_test\n", + " c['Level'] = 1 # Temporary\n", + "\n", + " x=x.to_dict(orient=\"index\")[0]\n", + " c=c.to_dict(orient=\"index\")[0]\n", + " discern.show_cf(x, c)\n", + " break" + ], + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sL0HYMKWkl9j", + "outputId": "56ff6f20-2f6a-49bf-e4cb-8767f2793bb4" + }, + "source": [ + "data_df = pd.read_csv('/content/DisCERN-XAI/tests/lung_cancer.csv')\n", + "data_df.columns" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Index(['Age', 'Gender', 'Alcohol use', 'Dust Allergy', 'chronic Lung Disease',\n", + " 'Smoking', 'Chest Pain', 'Fatigue', 'Shortness of Breath', 'Wheezing',\n", + " 'Swallowing Difficulty', 'Cough', 'Level'],\n", + " dtype='object')" + ] + }, + "metadata": {}, + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xmiL55uCqA_5" + }, + "source": [ + "features = ['Age', 'Gender', 'Alcohol use', 'Dust Allergy', 'chronic Lung Disease',\n", + " 'Smoking', 'Chest Pain', 'Fatigue', 'Shortness of Breath', 'Wheezing',\n", + " 'Swallowing Difficulty', 'Cough', 'Level','Target Level']\n", + "data_dic = {}\n", + "for i in features:\n", + " if i==\"Level\" or i==\"Target Level\":\n", + " data_dic[i] = list(set(data_df.Level))\n", + " else:\n", + " data_dic[i] = [data_df.describe()[i][\"min\"], data_df.describe()[i][\"max\"]]" + ], + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e8h5B2i_tQMz", + "outputId": "9e56de30-d774-4874-f8b6-5105b786bfe9" + }, + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "import torch\n", + "import numpy as np\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-medium\")\n", + "model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-medium\")\n", + "\n", + "features = ['Age', 'Gender', 'Alcohol use', 'Dust Allergy', 'chronic Lung Disease',\n", + " 'Smoking', 'Chest Pain', 'Fatigue', 'Shortness of Breath', 'Wheezing',\n", + " 'Swallowing Difficulty', 'Cough', 'Level', 'Target Level']\n", + "exit = [\"Thanks\",\"Goodbye\",\"goodbye\",\"Bye\",\"bye\"]\n", + "\n", + "for step in range(5):\n", + " new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n", + " print(tokenizer.decode(new_user_input_ids[0], skip_special_tokens=True))\n", + " bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n", + " chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)\n", + " print(\"DialoGPT: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))\n", + " if \"help\" in tokenizer.decode(new_user_input_ids[0], skip_special_tokens=True):\n", + " info = []\n", + " lvl = \"\"\n", + " trg = \"\"\n", + " for i in range(len(data_dic)):\n", + " if features[i]==\"Level\":\n", + " lvl = input(f\"Please give info about your {features[i]} status in a following range: {data_dic[features[i]]}\")\n", + " elif features[i]==\"Target Level\":\n", + " trg = input(f\"Please give info about your {features[i]} status in a following range: {data_dic[features[i]]}\")\n", + " else:\n", + " info.append(float(input(f\"Please give info about your {features[i]} status in a following range: {data_dic[features[i]]}\")))\n", + " test_cancer_risk(np.array(info), lvl, trg)\n", + " if tokenizer.decode(new_user_input_ids[0], skip_special_tokens=True) in exit:\n", + " print(\"Have a good day, bye.\")\n", + " break" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + ">> User:Hello\n", + "Hello\n", + "DialoGPT: Hello! :D\n", + ">> User:how are you?\n", + "how are you?\n", + "DialoGPT: I'm good, how are you?\n", + ">> User:Thanks, can you help me in some sort of decision making?\n", + "Thanks, can you help me in some sort of decision making?\n", + "DialoGPT: I can try, but I'm not sure if I can help you.\n", + "Please give info about your Age status in a following range: [14.0, 87.0]15\n", + "Please give info about your Gender status in a following range: [1.0, 2.0]1\n", + "Please give info about your Alcohol use status in a following range: [1.0, 2.0]2\n", + "Please give info about your Dust Allergy status in a following range: [1.0, 2.0]1\n", + "Please give info about your chronic Lung Disease status in a following range: [1.0, 2.0]1\n", + "Please give info about your Smoking status in a following range: [1.0, 2.0]1\n", + "Please give info about your Chest Pain status in a following range: [1.0, 2.0]1\n", + "Please give info about your Fatigue status in a following range: [1.0, 2.0]1\n", + "Please give info about your Shortness of Breath status in a following range: [1.0, 2.0]1\n", + "Please give info about your Wheezing status in a following range: [1.0, 2.0]2\n", + "Please give info about your Swallowing Difficulty status in a following range: [1.0, 2.0]1\n", + "Please give info about your Cough status in a following range: [1.0, 2.0]1\n", + "Please give info about your Level status in a following range: ['High', 'Medium', 'Low']Low\n", + "Please give info about your Target Level status in a following range: ['High', 'Medium', 'Low']High\n", + "Fatigue should be changed from 1.0 to 0.0\n", + "Instance: A 1.0 smoking, 1.0 a prolonged period of cough, 1.0 x 1.0 cough level or 0.33, 1.0 was a chest pain.\n", + "Counterfactual: The one litre of alcohol used in the 1.0 tas, or the 1.0 wheezing level is 1 level of noise, the cough and the tobacco use (1.5).\n", + ">> User:Thanks\n", + "Thanks\n", + "DialoGPT: No problem, I'll try to help you.\n", + "Have a good day, bye.\n" + ] + } + ] + } + ] +} \ No newline at end of file