diff --git a/jupyter_notebooks/rspace_openai_chat.ipynb b/jupyter_notebooks/rspace_openai_chat.ipynb new file mode 100644 index 0000000..333dc39 --- /dev/null +++ b/jupyter_notebooks/rspace_openai_chat.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f4d390c7-ec57-49c3-b59a-1fc7e0826dd5", + "metadata": {}, + "source": [ + "## Generating Lucene Query Language queries from free text input\n", + "\n", + "This notebook takes text input, generates Lucene Query Language statements and queries RSpace ELN API to retrieve documents. It uses OpenAI's GPT-4 model and 'Function' API.\n", + "\n", + "The use case is that to generate complex search queries using Lucene syntax, it's tricky to remember the syntax to create a valid query.\n", + "\n", + "This notebook allows queries made from 'free text' .\n", + "\n", + "For example:\n", + "\n", + " Find docs tagged with X and Y but not Z, oldest first\n", + " Find \"some phrase\" in full text, and tagged with Y, alphabetical order\n", + " Find docs where name starts with XXXX, newest first\n", + "\n", + "Note that your RSpace data **is not** sent to OpenAI's servers, only your query.\n", + "The actual search of RSpace is performed by this script.\n", + "\n", + "To run this you'll need:\n", + "\n", + "* An OpenAI API key\n", + "* An account on https://community.researchspace.com, and an RSpace API key. (It's free to set up).\n", + "* Python RSpace client, OpenAI API and various dependencies:\n", + " - `pip install rspace_client notebook requests openai scipy tenacity tiktoken termcolor`" + ] + }, + { + "cell_type": "markdown", + "id": "a3769571-abd1-47a7-9ed4-1b68b85b5118", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Here we import everything we need and check RSpace API connection:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "de2dca6f-4671-47a3-a292-db7878ac7dbc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'message': 'OK', 'rspaceVersion': '1.91.1'}\n" + ] + } + ], + "source": [ + "import os\n", + "from rspace_client.eln import eln\n", + "import json\n", + "import openai\n", + "import requests\n", + "from termcolor import colored\n", + "from tenacity import retry, wait_random_exponential, stop_after_attempt\n", + "\n", + "eln_cli = eln.ELNClient(os.getenv(\"RSPACE_URL\"), os.getenv(\"RSPACE_API_KEY\"))\n", + "print(eln_cli.get_status())\n", + "## Open AI API key should be an environment variable: OPENAI_API_KEY=\n", + "GPT_MODEL = \"gpt-4\" ##gpt-3.5-turbo is alternative option. " + ] + }, + { + "cell_type": "markdown", + "id": "77a2e1e7-39c3-4af1-b82c-47df49af3314", + "metadata": {}, + "source": [ + "### Boilerplate\n", + "\n", + "- sending request to OpenAI API\n", + "- pretty-printing results" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b262d21e", + "metadata": {}, + "outputs": [], + "source": [ + "@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n", + "def chat_completion_request(\n", + " messages, functions=None, function_call=None, model=GPT_MODEL\n", + "):\n", + " headers = {\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": \"Bearer \" + openai.api_key,\n", + " }\n", + " json_data = {\"model\": model, \"messages\": messages}\n", + "\n", + " if functions is not None:\n", + " json_data.update({\"functions\": functions})\n", + " if function_call is not None:\n", + " json_data.update({\"function_call\": function_call})\n", + " try:\n", + " response = requests.post(\n", + " \"https://api.openai.com/v1/chat/completions\",\n", + " headers=headers,\n", + " json=json_data,\n", + " )\n", + " return response\n", + " except Exception as e:\n", + " print(\"Unable to generate ChatCompletion response\")\n", + " print(f\"Exception: {e}\")\n", + " return e" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "442ed481", + "metadata": {}, + "outputs": [], + "source": [ + "role_to_color = {\n", + " \"system\": \"red\",\n", + " \"user\": \"green\",\n", + " \"assistant\": \"blue\",\n", + " \"function\": \"magenta\",\n", + "}\n", + "\n", + "\n", + "def pretty_print_conversation(messages):\n", + " \"\"\"\n", + " Prints color-coded messages according to their role\n", + " \"\"\"\n", + "\n", + " for message in messages:\n", + " if message[\"role\"] == \"system\":\n", + " print(\n", + " colored(\n", + " f\"system: {message['content']}\\n\", role_to_color[message[\"role\"]]\n", + " )\n", + " )\n", + " elif message[\"role\"] == \"user\":\n", + " print(\n", + " colored(f\"user: {message['content']}\\n\", role_to_color[message[\"role\"]])\n", + " )\n", + " elif message[\"role\"] == \"assistant\" and message.get(\"function_call\"):\n", + " print(\n", + " colored(\n", + " f\"assistant: {message['function_call']}\\n\",\n", + " role_to_color[message[\"role\"]],\n", + " )\n", + " )\n", + " elif message[\"role\"] == \"assistant\" and not message.get(\"function_call\"):\n", + " print(\n", + " colored(\n", + " f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]\n", + " )\n", + " )\n", + " elif message[\"role\"] == \"function\":\n", + " print(\n", + " colored(\n", + " f\"function ({message['name']}): {message['content']}\\n\",\n", + " role_to_color[message[\"role\"]],\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8ed312e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "5e3b8fbe-a582-4853-b932-f19ace7c7743", + "metadata": {}, + "source": [ + "Performs the conversation. Sends messages and functions to OpenAI's Chat completion API.\n", + "\n", + "IF a function is returned, it's invoked. The conversation and the results are returned as a tuple." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7e574166-0724-45ae-958e-14161457bc7f", + "metadata": {}, + "outputs": [], + "source": [ + "def do_conversation(messages, functions):\n", + "\n", + " resp = chat_completion_request(messages, functions, {'name':'lucene'})\n", + " active_messages = messages.copy()\n", + " response_message = resp.json()['choices'][0]['message']\n", + " active_messages.append(response_message)\n", + "\n", + " if response_message['function_call'] is not None:\n", + " f_name = response_message['function_call']['name']\n", + " f_args = json.loads(response_message['function_call']['arguments'])\n", + " rspace_search_result = available_functions[f_name](**f_args)\n", + " return (active_messages, rspace_search_result)" + ] + }, + { + "cell_type": "markdown", + "id": "81a44ff3-8f56-458b-a492-51abd972f52c", + "metadata": {}, + "source": [ + "### Function definitions\n", + "\n", + "The function we want to call, and its description in JSON Schema" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "58d10244-d0c8-4eda-9657-208ceabc1b1a", + "metadata": {}, + "outputs": [], + "source": [ + "## This is the function that will be invoked with arguments generated by AI.\n", + "## It will make calls to RSpace's search API.\n", + "## Note that date-range basedd syntax is not supprted by RSpace. Also there seems to be a limit of ~125 characters\n", + "## for the Lucene search query length.\n", + "\n", + "def search_rspace_eln(luceneQuery, sort_order=\"lastModified desc\"):\n", + " q = \"l: \" + luceneQuery\n", + " docs = eln_cli.get_documents(query=q, order_by=sort_order)['documents']\n", + " wanted_keys = ['globalId','name', 'tags', 'created'] # The keys we want\n", + " summarised = list(map(lambda d: dict((k, d[k]) for k in wanted_keys if k in d), docs))\n", + " return summarised" + ] + }, + { + "cell_type": "markdown", + "id": "5648c5ce-38b7-49d0-af1a-19be6c5142b6", + "metadata": {}, + "source": [ + "Below we define the data structure that we want GPT-4 model to return." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9b2a7342-ed8e-4132-93a9-95dd72ef33e3", + "metadata": {}, + "outputs": [], + "source": [ + "available_functions = {\n", + " \"lucene\":search_rspace_eln\n", + "}\n", + "\n", + "functions = [\n", + " {\n", + " \"name\": \"lucene\",\n", + " \"description\": \"\"\"\n", + " A valid Lucene Query Language string generated from user input.\n", + " Document fields are name, docTag, fields.fieldData, and username.\n", + " Don't use wildcards. Don't state your reasoning.\n", + " \"\"\",\n", + " \"parameters\": {\n", + " \"type\":\"object\",\n", + " \"properties\": {\n", + " \"luceneQuery\": {\n", + " \"type\":\"string\",\n", + " \"description\":\"Valid Lucene Query Language as plain text\"\n", + " },\n", + " \"sort_order\": {\n", + " \"type\":\"string\",\n", + " \"description\":\"How results should be sorted\",\n", + " \"enum\":[\"name asc\", \"name desc\", \"created asc\", \"created desc\"]\n", + " },\n", + " \n", + " }\n", + " }\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "d46a46c7-101d-4c28-9d15-ec85d2c2d86d", + "metadata": {}, + "source": [ + "### Executing the conversation\n", + "\n", + "Change the content of the second message to be something relevant for your work.\n", + "\n", + "Use as precise language as you can. Try seeing how little you need to type. \n", + "\n", + "Preamble such as `I want to search for documents....` seems mostly unnecessary.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9b6122d1-43a1-4cd0-a466-37009ca67d53", + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " \"role\" : \"system\",\n", + " \"content\": \"Generate function arguments from user input. Don't show reasoning.\"\n", + " },\n", + " {\n", + " \"role\" : \"user\",\n", + " \"content\": \"\"\"\n", + " I want to search for documents that are tagged with PCR but not ECL, \n", + " containing the phrase “DNA replication” but not \"RNA\".\n", + " Reverse alphabetical order\n", + " \"\"\"\n", + " } \n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "edc9555f-99f6-4a3d-80e0-cb61aaf7a32e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31msystem: Generate function arguments from user input. Don't show reasoning.\n", + "\u001b[0m\n", + "\u001b[32muser: \n", + " I want to search for documents that are tagged with PCR but not ECL, \n", + " containing the phrase “DNA replication” but not \"RNA\"\n", + " List results in reverse alphabetical order\n", + " \n", + "\u001b[0m\n", + "\u001b[34massistant: {'name': 'lucene', 'arguments': '\\n{\\n \"luceneQuery\": \"docTag:PCR NOT docTag:ECL AND fields.fieldData:\\\\\"DNA replication\\\\\" NOT fields.fieldData:\\\\\"RNA\\\\\"\",\\n \"sort_order\": \"name desc\"\\n}'}\n", + "\u001b[0m\n", + "Search results from RSpace\n", + "--------------------------\n", + "[\n", + " {\n", + " \"globalId\": \"SD1924558\",\n", + " \"name\": \"aurora expression analysis on cell growth: second attempt\",\n", + " \"tags\": \"polyclonal,PCR\",\n", + " \"created\": \"2023-09-17T18:23:38.029Z\"\n", + " },\n", + " {\n", + " \"globalId\": \"SD1924288\",\n", + " \"name\": \"aurora expression analysis on cell growth\",\n", + " \"tags\": \"polyclonal,PCR\",\n", + " \"created\": \"2023-09-15T23:47:26.853Z\"\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "(conversation, results) = do_conversation(messages,functions)\n", + "pretty_print_conversation(conversation)\n", + "print(\"Search results from RSpace\\n--------------------------\")\n", + "print(json.dumps(results, indent=2))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}