From bb1c3e4dcec9a974a6be9e5b199b6d37e71e5bdd Mon Sep 17 00:00:00 2001
From: fellowtraveler
Date: Sun, 12 May 2024 01:19:10 -0500
Subject: [PATCH] fixed so it correctly finds config.yaml based on NSHELL_CONFIG
 env var

---
 config.yaml.sample                |   5 ++
 natural_shell/__init__.py         |   8 +--
 natural_shell/cli.py              |  68 +++++++++---------
 natural_shell/command_executor.py |  90 +++++++++++++-----------
 natural_shell/language_model.py   | 113 ++++++++++++++++--------------
 natural_shell/parser.py           |  14 ++--
 6 files changed, 158 insertions(+), 140 deletions(-)
 create mode 100644 config.yaml.sample

diff --git a/config.yaml.sample b/config.yaml.sample
new file mode 100644
index 0000000..cb86b62
--- /dev/null
+++ b/config.yaml.sample
@@ -0,0 +1,5 @@
+api_key: KEY_GOES_HERE
+model: google-genai
+base_url: https://api.example.com/v1
+temperature: 0.2
+
diff --git a/natural_shell/__init__.py b/natural_shell/__init__.py
index 17d76a3..b8299b9 100644
--- a/natural_shell/__init__.py
+++ b/natural_shell/__init__.py
@@ -1,4 +1,4 @@
-from .command_executor import execute_command, fetch_command
-from .language_model import LanguageModel
-
-__all__ = ["execute_command", "fetch_command", "LanguageModel"]
+from .command_executor import execute_command, fetch_command
+from .language_model import LanguageModel
+
+__all__ = ["execute_command", "fetch_command", "LanguageModel"]
\ No newline at end of file
diff --git a/natural_shell/cli.py b/natural_shell/cli.py
index 2b1c62d..e413c3c 100644
--- a/natural_shell/cli.py
+++ b/natural_shell/cli.py
@@ -1,34 +1,34 @@
-import argparse
-from natural_shell.command_executor import execute_command, fetch_command
-from natural_shell.language_model import llm
-
-def main():
-    parser = argparse.ArgumentParser(description="Convert natural language input to shell commands")
-    parser.add_argument("input", help="The natural language input")
-    subparsers = parser.add_subparsers(dest="command", help="Subcommand to execute")
-
-    print_parser = subparsers.add_parser("print", help="Print the shell command")
-    print_parser.set_defaults(func=print_command)
-
-    exec_parser = subparsers.add_parser("exec", help="Execute the shell command")
-    exec_parser.set_defaults(func=exec_command)
-
-    args = parser.parse_args()
-
-    if not hasattr(args, "func"):
-        parser.print_help()
-        return
-
-    args.func(args.input, llm.get_language_model())
-
-def print_command(input, language_model):
-    command = fetch_command(input, language_model)
-    print(command)
-
-def exec_command(input, language_model):
-    command, output = execute_command(input, language_model)
-    print(f"[+] executing: {command}")
-    print(output)
-
-if __name__ == "__main__":
-    main()
+import argparse
+from natural_shell.command_executor import execute_command, fetch_command
+from natural_shell.language_model import llm
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert natural language input to shell commands")
+    parser.add_argument("input", help="The natural language input")
+    subparsers = parser.add_subparsers(dest="command", help="Subcommand to execute")
+
+    print_parser = subparsers.add_parser("print", help="Print the shell command")
+    print_parser.set_defaults(func=print_command)
+
+    exec_parser = subparsers.add_parser("exec", help="Execute the shell command")
+    exec_parser.set_defaults(func=exec_command)
+
+    args = parser.parse_args()
+
+    if not hasattr(args, "func"):
+        parser.print_help()
+        return
+
+    args.func(args.input, llm.get_language_model())
+
+def print_command(input, language_model):
+    command = fetch_command(input, language_model)
+    print(command)
+
+def exec_command(input, language_model):
+    command, output = execute_command(input, language_model)
+    print(f"[+] executing: {command}")
+    print(output)
+
+if __name__ == "__main__":
+    main()
diff --git a/natural_shell/command_executor.py b/natural_shell/command_executor.py
index de88a96..d3bfd11 100644
--- a/natural_shell/command_executor.py
+++ b/natural_shell/command_executor.py
@@ -1,42 +1,48 @@
-import platform
-from langchain.prompts import PromptTemplate
-from natural_shell.language_model import llm
-from natural_shell.parser import parser
-
-# Define the prompt template
-prompt_template = """
-You are an AI assistant that converts natural language input into a real command.
-The output should be a valid JSON with a key "command" containing the real command.
-The command should be compatible with the {platform} platform.
-{format_instructions}
-Input: {query}
-"""
-
-def fetch_command(user_input, language_model=None):
-    if language_model is None:
-        language_model = llm.get_language_model()
-
-    # Create the chain
-    prompt = PromptTemplate(template=prompt_template, input_variables=["query"], partial_variables={"format_instructions": parser.get_format_instructions(), "platform": platform.system()})
-    chain = prompt | language_model | parser
-
-    retry = 0
-    while retry <= 5:
-        try:
-            output = chain.invoke({"query": user_input})
-            cmd = output['command']
-            return cmd
-        except Exception as e:
-            print(f"Error: {e}")
-            retry += 1
-            if retry <= 5:
-                return f"Error: {e}"
-
-def execute_command(user_input, language_model=None):
-    cmd = fetch_command(user_input, language_model)
-    import os
-    output = os.popen(cmd).read()
-    return cmd, output
-
-__all__ = ["fetch_command", "execute_command"]
-
+import platform
+from langchain.prompts import PromptTemplate
+from natural_shell.language_model import llm
+from natural_shell.parser import parser
+
+# Define the prompt template
+prompt_template = """
+You are an AI assistant that converts natural language input into a real CLI command based on the user's intent.
+The output should be a valid JSON with the key "command" that has a value of type string containing the user's actual intended CLI command.
+The command should be compatible with the {platform} platform.
+
+### Formatting instructions:
+{format_instructions}
+
+### Input (contains user intent as a string):
+{query}
+
+### Output (contains valid JSON):
+"""
+
+def fetch_command(user_input, language_model=None):
+    if language_model is None:
+        language_model = llm.get_language_model()
+
+    # Create the chain
+    prompt = PromptTemplate(template=prompt_template, input_variables=["query"], partial_variables={"format_instructions": parser.get_format_instructions(), "platform": platform.system()})
+    chain = prompt | language_model | parser
+
+    retry = 0
+    while retry <= 5:
+        try:
+            output = chain.invoke({"query": user_input})
+            cmd = output['command']
+            return cmd
+        except Exception as e:
+            print(f"Error: {e}")
+            retry += 1
+            if retry <= 5:
+                return f"Error: {e}"
+
+def execute_command(user_input, language_model=None):
+    cmd = fetch_command(user_input, language_model)
+    import os
+    output = os.popen(cmd).read()
+    return cmd, output
+
+__all__ = ["fetch_command", "execute_command"]
+
diff --git a/natural_shell/language_model.py b/natural_shell/language_model.py
index 954f91a..056299a 100644
--- a/natural_shell/language_model.py
+++ b/natural_shell/language_model.py
@@ -1,53 +1,60 @@
-import os
-import yaml
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_openai import OpenAI
-from langchain_anthropic import Anthropic
-
-class LanguageModel:
-    def __init__(self, config_file=None):
-        self.load_config(config_file)
-
-    def load_config(self, config_file=None):
-        if config_file is None:
-            # If no config file is provided, use default values
-            self.api_key = None
-            self.model = 'google-genai'
-            self.base_url = None
-            return
-
-        if not os.path.exists(config_file):
-            raise FileNotFoundError(f"Config file '{config_file}' not found.")
-
-        with open(config_file, 'r') as f:
-            config = yaml.safe_load(f)
-
-        self.api_key = config.get('api_key')
-        self.model = config.get('model', 'google-genai')
-        self.base_url = config.get('base_url')
-
-    def get_language_model(self):
-        if self.model == "google-genai":
-            return ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=self.api_key, temperature=1)
-        elif self.model == "openai":
-            return OpenAI(api_key=self.api_key, temperature=1, base_url=self.base_url)
-        elif self.model == "anthropic":
-            return Anthropic(api_key=self.api_key, temperature=1)
-        else:
-            raise ValueError(f"Invalid model name: {self.model}")
-
-
-# Create a default instance
-llm = LanguageModel()
-
-# def get_language_model(llm_instance):
-#     if llm_instance.model == "google-genai":
-#         return ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=llm_instance.api_key, temperature=1)
-#     elif llm_instance.model == "openai":
-#         return OpenAI(api_key=llm_instance.api_key, temperature=1, base_url=llm_instance.base_url)
-#     elif llm_instance.model == "anthropic":
-#         return Anthropic(api_key=llm_instance.api_key, temperature=1)
-#     else:
-#         raise ValueError(f"Invalid model name: {llm_instance.model}")
-
-# __all__ = ["llm", "get_language_model","load_config"]
+import os
+import sys
+import yaml
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import OpenAI
+from langchain_anthropic import Anthropic
+
+class LanguageModel:
+    def __init__(self, config_file=None):
+        self.load_config(config_file)
+
+    def load_config(self, config_file=None):
+        if config_file is None:
+            # If no config file is provided, use default values
+            self.api_key = None
+            self.model = 'google-genai'
+            self.base_url = None
+            return
+
+        if not os.path.exists(config_file):
+            raise FileNotFoundError(f"Config file '{config_file}' not found.")
+
+        with open(config_file, 'r') as f:
+            config = yaml.safe_load(f)
+
+        self.api_key = config.get('api_key')
+        self.model = config.get('model', 'google-genai')
+        self.base_url = config.get('base_url')
+        self.temperature = config.get('temperature', 1)
+
+    def get_language_model(self):
+        if self.model == "google-genai":
+            return ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=self.api_key, temperature=self.temperature)
+        elif self.model == "openai":
+            return OpenAI(api_key=self.api_key, temperature=self.temperature, base_url=self.base_url)
+        elif self.model == "anthropic":
+            return Anthropic(api_key=self.api_key, temperature=self.temperature)
+        else:
+            raise ValueError(f"Invalid model name: {self.model}")
+
+nshell_config_filename = os.getenv('NSHELL_CONFIG')
+
+if nshell_config_filename is None:
+    print("NSHELL_CONFIG environment variable is not set. (Should contain path to config.yaml). Exiting...")
+    sys.exit(1) # Exiting with a non-zero status code to indicate an error
+
+# Create a default instance
+llm = LanguageModel(nshell_config_filename)
+
+# def get_language_model(llm_instance):
+#     if llm_instance.model == "google-genai":
+#         return ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=llm_instance.api_key, temperature=1)
+#     elif llm_instance.model == "openai":
+#         return OpenAI(api_key=llm_instance.api_key, temperature=1, base_url=llm_instance.base_url)
+#     elif llm_instance.model == "anthropic":
+#         return Anthropic(api_key=llm_instance.api_key, temperature=1)
+#     else:
+#         raise ValueError(f"Invalid model name: {llm_instance.model}")
+
+# __all__ = ["llm", "get_language_model","load_config"]
diff --git a/natural_shell/parser.py b/natural_shell/parser.py
index 7f02749..7881a80 100644
--- a/natural_shell/parser.py
+++ b/natural_shell/parser.py
@@ -1,7 +1,7 @@
-from langchain_core.output_parsers import JsonOutputParser
-from langchain_core.pydantic_v1 import BaseModel, Field
-
-class Command(BaseModel):
-    command: str = Field(None, description="The real command based on the input")
-
-parser = JsonOutputParser(pydantic_object=Command)
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+class Command(BaseModel):
+    command: str = Field(None, description="The intended CLI command, based on the input")
+
+parser = JsonOutputParser(pydantic_object=Command)