-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcontroller.py
More file actions
128 lines (115 loc) · 5.24 KB
/
controller.py
File metadata and controls
128 lines (115 loc) · 5.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from random import randint
from typing import List, Dict, Any, Tuple

import streamlit as st
import tiktoken

from classes import Data
from classes import IdCounter
from interface import render, display_result_prompt, display_result_generated_fewshot
def session_state_init() -> None:
    """Seed Streamlit session state with the keys this page depends on.

    Keys that already exist are left untouched, so values survive reruns:
    ``examples`` starts as an empty list and ``id_counter`` as a fresh
    ``IdCounter``.
    """
    defaults = (("examples", list), ("id_counter", IdCounter))
    for key, factory in defaults:
        if key not in st.session_state:
            setattr(st.session_state, key, factory())
def read_file(file_path: str, encoding: str = "utf-8") -> str:
    """Return the entire text content of *file_path*.

    Args:
        file_path: Path of the text file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the bundled prompt templates decode identically on every platform
            instead of depending on the OS locale's default encoding.

    Returns:
        The decoded file contents.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()
def combine_prompt(data: Data) -> str:
    """Build the few-shot generation prompt from the templates and user inputs.

    The steps are:

    1. Read ``prompt/generate_fewshots_base.md`` (next to this module).
    2. Fill ``{{prompt}}``, ``{{requirements}}`` and ``{{constraints}}``.
    3. Convert every newline in that partially filled text to ``<br>``.
    4. Fill ``{{format_instructions}}`` (from ``prompt/format_instructions.md``),
       ``{{language}}`` and ``{{count_generation}}``. These values keep their
       raw newlines because they are inserted after the conversion.

    Empty or missing requirements/constraints are rendered as the literal text
    "None". A missing prompt or language is rendered as an empty string rather
    than raising ``TypeError`` from ``str.replace``.

    Args:
        data: Form state; read via ``data.get`` for "prompt", "requirements",
            "constraints", "language" and "count_generation".

    Returns:
        The combined prompt text.
    """
    # os.path.join keeps the path relative when dirname(__file__) is empty,
    # instead of producing "/prompt/..." at the filesystem root.
    prompt_dir = os.path.join(os.path.dirname(__file__), "prompt")
    base_prompt = read_file(os.path.join(prompt_dir, "generate_fewshots_base.md"))

    # NOTE(review): substitution is sequential, so placeholder text typed by
    # the user (e.g. "{{constraints}}" inside the prompt) is substituted too.
    prompt = base_prompt.replace("{{prompt}}", data.get("prompt") or "")
    prompt = prompt.replace("{{requirements}}", data.get("requirements") or "None")
    prompt = prompt.replace("{{constraints}}", data.get("constraints") or "None")
    prompt = prompt.replace("\n", "<br>")

    format_instructions = read_file(os.path.join(prompt_dir, "format_instructions.md"))
    prompt = prompt.replace("{{format_instructions}}", format_instructions)
    prompt = prompt.replace("{{language}}", data.get("language") or "")
    prompt = prompt.replace("{{count_generation}}", str(data.get("count_generation")))
    return prompt
from openai import OpenAI
def generate_fewshot(combined_prompt: str, api_key: str, is_validate: bool, data: Data) -> str:
    """Send *combined_prompt* to gpt-4o-mini and return the generated text.

    Args:
        combined_prompt: Fully assembled prompt, sent as the single system
            message.
        api_key: OpenAI API key. An empty string skips the request.
        is_validate: The API is only called when this is truthy.
        data: Source of the sampling hyperparameters ("temperature", "top_p",
            "frequency_penalty", "presence_penalty"). When falsy, only
            ``max_tokens`` is sent and the API defaults apply to the rest.

    Returns:
        The model's reply text. Returns "" when no request is made or when the
        reply carries no text content.
    """
    if api_key == "":
        return ""
    if not is_validate:
        return ""

    # Always bound, so a falsy `data` still produces a valid request instead
    # of an unbound-name error at the create() call.
    hyperparams: Dict[str, Any] = {"max_tokens": 2000}
    if data:
        hyperparams.update(
            temperature=data.get("temperature"),
            top_p=data.get("top_p"),
            frequency_penalty=data.get("frequency_penalty"),
            presence_penalty=data.get("presence_penalty"),
        )

    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": combined_prompt}],
        **hyperparams,
    )
    # message.content is Optional[str] in the OpenAI SDK; keep the str contract.
    return response.choices[0].message.content or ""
# Per-token prices in US dollars used by the cost estimate for gpt-4o-mini:
# 0.15 USD per million input tokens, 0.60 USD per million output tokens.
input_cost_per_token_gpt4o_mini = 1.5e-7
output_cost_per_token_gpt4o_mini = 6e-7
def calculate_cost(prompt: str, data: Data) -> Tuple[float, int]:
    """Estimate the dollar cost and total token count of one generation request.

    The input side is the exact tiktoken count of *prompt*. The output side is
    a heuristic: take the longest example in ``st.session_state.examples``
    (user input plus assistant output, measured in tokens) and multiply it by
    the requested ``count_generation``.

    Args:
        prompt: The combined prompt that would be sent to the model.
        data: Provides "count_generation". A missing value counts as 0.

    Returns:
        ``(estimated_cost_usd, estimated_total_tokens)``. Returns ``(0.0, 0)``
        for an empty prompt, so callers can always unpack a pair.
    """
    if prompt == "":
        return 0.0, 0

    enc = tiktoken.encoding_for_model("gpt-4o-mini")

    def count_tokens(text: str) -> int:
        # disallowed_special=() treats user text such as "<|endoftext|>" as
        # ordinary text instead of raising ValueError.
        return len(enc.encode(text, disallowed_special=()))

    input_token_count = count_tokens(prompt)
    max_example_tokens = max(
        (
            count_tokens(example["user_input"]) + count_tokens(example["assistant_output"])
            for example in st.session_state.examples
        ),
        default=0,
    )
    output_expected_token_count = max_example_tokens * (data.get("count_generation") or 0)

    input_cost = input_token_count * input_cost_per_token_gpt4o_mini
    output_cost = output_expected_token_count * output_cost_per_token_gpt4o_mini
    return input_cost + output_cost, input_token_count + output_expected_token_count
def postprocess(data: Data) -> Data:
    """Derive the combined prompt, cost estimate and optional generation result.

    Stores "result_prompt", "token_cost", "token_count" and
    "result_generation_fewshot" on *data*. The generation result is "" unless
    "fewshot_generation_button" is set. The same *data* object is returned.
    """
    combined = combine_prompt(data)
    cost, count = calculate_cost(combined, data)
    for key, value in (
        ("result_prompt", combined),
        ("token_cost", cost),
        ("token_count", count),
    ):
        data.set(key, value)

    generated = ""
    if data.get("fewshot_generation_button"):
        generated = generate_fewshot(combined, data.get("api_key"), data.get("is_validate"), data)
    data.set("result_generation_fewshot", generated)
    return data
# Placeholder shown in the generated few-shot panel when there is no
# generation result ("Not generated yet.").
not_yet_generated_notice: str = "아직 생성되지 않았습니다."
def update_interface(data: Data) -> None:
    """Push the post-processed values held in *data* into the page.

    The steps are:

    1. Write the token-count and cost notices.
    2. If "example_append" is set, append the entered example to
       ``st.session_state.examples`` when an assistant output is present,
       clear the inputs and call ``st.rerun()``. Otherwise report failure.
    3. Display the combined prompt and the generated few-shot text, or a
       placeholder when nothing was generated.
    """
    token_count_text = f"생성시 필요 입력 토큰 개수: {data.get('token_count')}"
    data.get("count_notice").text(token_count_text)
    cost_text = f"생성 비용: {data.get('token_cost'):.5f} 달러 | 이 비용은 임의적 산출비용입니다. 더 많은 비용이 부과될 수 있습니다."
    data.get("cost_notice").text(cost_text)

    if data.get("example_append"):
        notice = data.get("example_addition_notice")
        assistant_widget = data.get("assistant_output")
        if not assistant_widget.get():
            notice.error("예제 추가 실패")
        else:
            user_widget = data.get("user_input")
            new_example = {
                "id": st.session_state.id_counter.get(),
                "user_input": user_widget.get() or "",
                "assistant_output": assistant_widget.get(),
            }
            st.session_state.examples.append(new_example)
            user_widget.clear()
            assistant_widget.clear()
            # NOTE(review): the success message is drawn right before the
            # rerun; confirm it is still visible on the next run.
            notice.success("예제 추가 성공")
            # st.rerun() stops this script run, so the display calls below
            # are not reached on this path.
            st.rerun()

    display_result_prompt(data.get("result_prompt"), data.get("asset_combined_prompt"))
    display_result_generated_fewshot(
        data.get("result_generation_fewshot") or not_yet_generated_notice,
        data.get("asset_generation_fewshot"),
    )
def run_impl() -> None:
    """Execute one full pass of the page.

    The pass configures the layout, initializes session state, renders the
    inputs, post-processes them and updates the output widgets.
    """
    # Page configuration is applied before anything else is drawn.
    st.set_page_config(layout="wide")
    session_state_init()
    update_interface(postprocess(render()))
class MainController:
    """Application entry point; delegates one page run to ``run_impl``."""

    @staticmethod
    def run() -> None:
        """Run the page once (a single Streamlit script pass)."""
        return run_impl()