diff --git a/package-lock.json b/package-lock.json index ab28a44..d30b3fe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8,6 +8,9 @@ "name": "@gptsafe/promptguard", "version": "0.2.0", "license": "Apache-2.0", + "dependencies": { + "lande": "^1.0.10" + }, "devDependencies": { "@jest/globals": "^29.4.1", "@typescript-eslint/eslint-plugin": "^5.50.0", @@ -3521,6 +3524,14 @@ "node": ">=6" } }, + "node_modules/lande": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/lande/-/lande-1.0.10.tgz", + "integrity": "sha512-yT52DQh+UV2pEp08jOYrA4drDv0DbjpiRyZYgl25ak9G2cVR2AimzrqkYQWrD9a7Ud+qkAcaiDDoNH9DXfHPmw==", + "dependencies": { + "toygrad": "^2.6.0" + } + }, "node_modules/leven": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/leven/-/leven-3.1.0.tgz", @@ -4421,6 +4432,11 @@ "node": ">=8.0" } }, + "node_modules/toygrad": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/toygrad/-/toygrad-2.6.0.tgz", + "integrity": "sha512-g4zBmlSbvzOE5FOILxYkAybTSxijKLkj1WoNqVGnbMcWDyj4wWQ+eYSr3ik7XOpIgMq/7eBcPRTJX3DM2E0YMg==" + }, "node_modules/ts-jest": { "version": "29.0.5", "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-29.0.5.tgz", diff --git a/package.json b/package.json index 8c3ffe8..33d34d8 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,7 @@ "version": "0.2.0", "description": "Prevent GPT prompt attacks for Node.js & TypeScript", "main": "./dist/index.js", - "scripts": { + "scripts": { "test": "jest", "build": "tsc --build", "clean": "tsc --build --clean", @@ -43,5 +43,8 @@ "jest": "^29.4.1", "ts-jest": "^29.0.5", "typescript": "^4.9.5" + }, + "dependencies": { + "lande": "^1.0.10" } } diff --git a/src/index.ts b/src/index.ts index 28e806a..7f78d92 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,49 +1,60 @@ #!/usr/bin/env ts-node import { - promptContainsDenyListItems, countPromptTokens, - encodePromptOutput, - promptContainsKnownAttack + promptContainsKnownAttack, + promptContainsLanguages, + promptContainsDenyListItems, + encodePromptOutput } from './utils'; enum FAILURE_REASON { DENY_LIST = 'CONTAINS_DENY_LIST_ITEM', MAX_TOKEN_THRESHOLD = 'EXCEEDS_MAX_TOKEN_THRESHOLD', - KNOWN_ATTACK = 'CONTAINS_KNOWN_ATTACK' + KNOWN_ATTACK = 'CONTAINS_KNOWN_ATTACK', + LANGUAGE_VALIDATION = 'FAILED_LANGUAGE_VALIDATION' } type UserPolicyOptions = { maxTokens?: number; denyList?: string[]; ignoreDefaultDenyList?: boolean; + allowedLanguages?: string[]; + deniedLanguages?: string[]; encodeOutput?: boolean; }; interface PromptGuardPolicy { maxTokens: number; // 1 token is ~4 characters in english - denyList: string[]; // this should be a fuzzy match + denyList: string[]; // this should use a fuzzy match but doesn't currently disableAttackMitigation: boolean; + allowedLanguages: string[]; + deniedLanguages: string[]; encodeOutput: boolean; // uses byte pair encoding to turn text into a series of integers } type PromptOutput = { - pass: boolean; // false if processing fails validation rules (max tokens, deny list, allow list) + pass: boolean; // false if processing fails validation rules output: string | number[]; // provide the processed prompt or failure reason }; export class PromptGuard { - promptGuardPolicy: PromptGuardPolicy; + policy: PromptGuardPolicy; constructor(userPolicyOptions: UserPolicyOptions = {}) { const defaultPromptGuardPolicy: PromptGuardPolicy = { maxTokens: 4096, denyList: [''], disableAttackMitigation: false, + allowedLanguages: [''], + deniedLanguages: [''], encodeOutput: false }; + // TODO validate the languages against the list of ISO 639-3 supported languages + // TODO validate that the allowed and denied language lists don't contain the same languages + // merge the user policy with the default policy to create the policy - this.promptGuardPolicy = { + this.policy = { ...defaultPromptGuardPolicy, ...userPolicyOptions }; @@ -51,26 +62,45 @@ export class PromptGuard { async process(prompt: string): Promise { // processing order - // normalize -> quote -> escape -> check tokens -> check cache -> check for known attacks -> check allow list -> check deny list -> encode output + // check tokens -> check allowed languages -> check denied languages -> + // check for known attacks -> check deny list -> encode output // check the prompt token count - if (countPromptTokens(prompt) > this.promptGuardPolicy.maxTokens) + if (countPromptTokens(prompt) > this.policy.maxTokens) return { pass: false, output: FAILURE_REASON.MAX_TOKEN_THRESHOLD }; + // check for the presence of allowed languages + // the prompt must be at least 10 characters long to reasonably expect to detect the language + if (prompt.length > 10) { + const allowedLanguages = this.policy.allowedLanguages; + const deniedLanguages = this.policy.deniedLanguages; + + if (allowedLanguages[0] !== '') { + if (await !promptContainsLanguages(prompt, allowedLanguages)) + return { pass: false, output: FAILURE_REASON.LANGUAGE_VALIDATION }; + } + if (deniedLanguages[0] !== '') { + if (await promptContainsLanguages(prompt, deniedLanguages)) + return { pass: false, output: FAILURE_REASON.LANGUAGE_VALIDATION }; + } + } + + // check for the presence of denied languages + // check prompt against known prompt attacks - if (!this.promptGuardPolicy.disableAttackMitigation) { + if (!this.policy.disableAttackMitigation) { if (await promptContainsKnownAttack(prompt)) return { pass: false, output: FAILURE_REASON.KNOWN_ATTACK }; } // check prompt again the user defined deny list if ( - await promptContainsDenyListItems(prompt, this.promptGuardPolicy.denyList) + await promptContainsDenyListItems(prompt, this.policy.denyList) ) return { pass: false, output: FAILURE_REASON.DENY_LIST }; // encode the prompt output if encodeOutput is set by the user - if (this.promptGuardPolicy.encodeOutput) + if (this.policy.encodeOutput) prompt = encodePromptOutput(prompt); return { pass: true, output: prompt }; diff --git a/src/tests/index.test.ts b/src/tests/index.test.ts index e8020d8..462f132 100644 --- a/src/tests/index.test.ts +++ b/src/tests/index.test.ts @@ -64,3 +64,23 @@ test("Checks that the prompt output is encoded when encodeOutput is true", async output: [5661, 318, 616, 7427, 6152, 0], }); }); + +test("Checks the prompt language allow list", async () => { + const promptGuard = new PromptGuard({ allowedLanguages: ['eng']}); + const prompt = "this is my awesome prompt!"; + const output = await promptGuard.process(prompt); + expect(output).toStrictEqual({ + pass: true, + output: prompt, + }); +}); + +test("Checks the prompt language denied list", async () => { + const promptGuard = new PromptGuard({ deniedLanguages: ['eng']}); + const prompt = "this is my awesome prompt!"; + const output = await promptGuard.process(prompt); + expect(output).toStrictEqual({ + pass: false, + output: 'FAILED_LANGUAGE_VALIDATION', + }); +}); diff --git a/src/utils/attackmitigation/index.ts b/src/utils/attack_mitigation/index.ts similarity index 96% rename from src/utils/attackmitigation/index.ts rename to src/utils/attack_mitigation/index.ts index 1c31ff6..7643a23 100644 --- a/src/utils/attackmitigation/index.ts +++ b/src/utils/attack_mitigation/index.ts @@ -2,7 +2,7 @@ type DenyList = string[]; const attackDenyList: DenyList = [ "ignore above", - "gnore the above", + "ignore the above", "ignore previous instructions", "ignore the previous instructions", "ignore above instructions", diff --git a/src/utils/denylist/index.ts b/src/utils/deny_list/index.ts similarity index 100% rename from src/utils/denylist/index.ts rename to src/utils/deny_list/index.ts diff --git a/src/utils/encoder/LICENSE b/src/utils/gpt_encoder/LICENSE similarity index 100% rename from src/utils/encoder/LICENSE rename to src/utils/gpt_encoder/LICENSE diff --git a/src/utils/encoder/encoder.json b/src/utils/gpt_encoder/encoder.json similarity index 100% rename from src/utils/encoder/encoder.json rename to src/utils/gpt_encoder/encoder.json diff --git a/src/utils/encoder/index.js b/src/utils/gpt_encoder/index.js similarity index 100% rename from src/utils/encoder/index.js rename to src/utils/gpt_encoder/index.js diff --git a/src/utils/encoder/vocab.bpe b/src/utils/gpt_encoder/vocab.bpe similarity index 100% rename from src/utils/encoder/vocab.bpe rename to src/utils/gpt_encoder/vocab.bpe diff --git a/src/utils/index.ts b/src/utils/index.ts index 374678d..b77bbdc 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,9 +1,12 @@ -import { containsDenyListItems } from "./denylist"; +import { containsDenyListItems } from './deny_list'; export const promptContainsDenyListItems = containsDenyListItems; -import { containsKnownAttack } from "./attackmitigation"; +import { containsKnownAttack } from './attack_mitigation'; export const promptContainsKnownAttack = containsKnownAttack; -const encoder = require("./encoder"); +import { containsLanguages } from './language_detection'; +export const promptContainsLanguages = containsLanguages; + +const encoder = require('./gpt_encoder'); export const countPromptTokens = encoder.countTokens; export const encodePromptOutput = encoder.encode; diff --git a/src/utils/language_detection/index.ts b/src/utils/language_detection/index.ts new file mode 100644 index 0000000..27090dc --- /dev/null +++ b/src/utils/language_detection/index.ts @@ -0,0 +1,32 @@ +import lande from "lande"; + +type LandeOuput = Array<[string, number]>; +type detectLanguageOutput = string[]; + +export async function containsLanguages( + prompt: string, + languages: string[], +): Promise { + const detectedLanguages: detectLanguageOutput = []; + + // lande returns a sorted list of detected languages and their probabilities. + // for now, we're selecting all languages with a probability greater than 80% + // this may need to be tuned later + const landeOuput: LandeOuput = lande(prompt); + + for (const lang of landeOuput) { + if (lang[1] > 0.8) detectedLanguages.push(lang[0]); + else break; + } + + for (const lang of detectedLanguages) { + if (languages.includes(lang)) return true; + } + + return false; +} + +// export async function validateLanguageList(list: string[]): Promise { +// //foo +// return true; +// }