Skip to content

Commit e1bc1ac

Browse files
authored
remove asyncio from get_models (#220)
1 parent 0d5f4c6 commit e1bc1ac

File tree

3 files changed

+50
-51
lines changed

3 files changed

+50
-51
lines changed

src/client/content/testbed.py

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ def main():
260260
if not db_avail or not available_ll_models:
261261
st.stop()
262262

263-
264263
# If there is no eligible (OpenAI Compat.) Embedding Model; disable Generate Test Set
265264
gen_testset_disabled = False
266265
embed_models_enabled = st_common.enabled_models_lookup("embed")
@@ -398,49 +397,51 @@ def main():
398397
# Load TestSets (and Evaluations if from DB)
399398
if col_left.button(button_text, key="load_tests", use_container_width=True, disabled=state.running):
400399
placeholder = st.empty()
401-
with placeholder:
402-
st.info("Processing Q&A... please be patient.", icon="⚠️")
403-
if testset_source != "Database":
404-
api_params["name"] = (state.testbed["testset_name"],)
405-
files = st_common.local_file_payload(state[f"selected_uploader_{state.testbed['uploader_key']}"])
406-
api_payload = {"files": files}
407-
try:
408-
response = api_call.post(endpoint=endpoint, params=api_params, payload=api_payload, timeout=3600)
409-
get_testbed_db_testsets.clear()
410-
state.testbed_db_testsets = get_testbed_db_testsets()
400+
with st.spinner("Processing Q&A... please be patient.", show_time=True):
401+
if testset_source != "Database":
402+
api_params["name"] = (state.testbed["testset_name"],)
403+
files = st_common.local_file_payload(state[f"selected_uploader_{state.testbed['uploader_key']}"])
404+
api_payload = {"files": files}
405+
try:
406+
response = api_call.post(endpoint=endpoint, params=api_params, payload=api_payload, timeout=3600)
407+
get_testbed_db_testsets.clear()
408+
state.testbed_db_testsets = get_testbed_db_testsets()
409+
state.testbed["testset_id"] = next(
410+
(
411+
d["tid"]
412+
for d in state.testbed_db_testsets
413+
if d.get("name") == state.testbed["testset_name"]
414+
),
415+
None,
416+
)
417+
except api_call.ApiError as ex:
418+
st.error(f"Error Generating TestSet: {ex}", icon="🚨")
419+
st.stop()
420+
except Exception as ex:
421+
logger.error("Exception: %s", ex)
422+
st.error(f"Looks like you found a bug: {ex}", icon="🚨")
423+
st.stop()
424+
else:
425+
# Set required state from splitting selected DB TestSet
426+
testset_name, testset_created = state.selected_db_testset.split(" -- Created: ", 1)
427+
state.testbed["testset_name"] = testset_name
411428
state.testbed["testset_id"] = next(
412-
(d["tid"] for d in state.testbed_db_testsets if d.get("name") == state.testbed["testset_name"]),
429+
(
430+
d["tid"]
431+
for d in state.testbed_db_testsets
432+
if d["name"] == testset_name and d["created"] == testset_created
433+
),
413434
None,
414435
)
415-
except api_call.ApiError as ex:
416-
st.error(f"Error Generating TestSet: {ex}", icon="🚨")
417-
st.stop()
418-
except Exception as ex:
419-
logger.error("Exception: %s", ex)
420-
st.error(f"Looks like you found a bug: {ex}", icon="🚨")
421-
st.stop()
422-
else:
423-
# Set required state from splitting selected DB TestSet
424-
testset_name, testset_created = state.selected_db_testset.split(" -- Created: ", 1)
425-
state.testbed["testset_name"] = testset_name
426-
state.testbed["testset_id"] = next(
427-
(
428-
d["tid"]
429-
for d in state.testbed_db_testsets
430-
if d["name"] == testset_name and d["created"] == testset_created
431-
),
432-
None,
433-
)
434-
api_params = {"tid": state.testbed["testset_id"]}
435-
# Retrieve TestSet Data
436-
response = api_call.get(endpoint=endpoint, params=api_params)
436+
api_params = {"tid": state.testbed["testset_id"]}
437+
# Retrieve TestSet Data
438+
response = api_call.get(endpoint=endpoint, params=api_params)
437439
try:
438440
state.testbed_qa = response["qa_data"]
439441
st.success(f"{len(state.testbed_qa)} Q&A Loaded.", icon="✅")
440442
except UnboundLocalError as ex:
441443
logger.exception("Failed to load Tests: %s", ex)
442444
st.error("Unable to process Tests", icon="🚨")
443-
placeholder.empty()
444445
col_center.button(
445446
"Reset",
446447
key="reset_test_framework",
@@ -534,16 +535,13 @@ def main():
534535
help="Evaluation will automatically save the TestSet to the Database",
535536
on_click=qa_update_db,
536537
):
537-
st_common.clear_state_key("testbed_evaluations")
538-
placeholder = st.empty()
539-
with placeholder:
540-
st.warning("Starting Q&A evaluation... please be patient.", icon="⚠️")
541-
st_common.patch_settings()
542-
endpoint = "v1/testbed/evaluate"
543-
api_params = {"tid": state.testbed["testset_id"], "judge": state.selected_evaluate_judge}
544-
evaluate = api_call.post(endpoint=endpoint, params=api_params, timeout=1200)
538+
with st.spinner("Starting Q&A evaluation... please be patient.", show_time=True):
539+
st_common.clear_state_key("testbed_evaluations")
540+
st_common.patch_settings()
541+
endpoint = "v1/testbed/evaluate"
542+
api_params = {"tid": state.testbed["testset_id"], "judge": state.selected_evaluate_judge}
543+
evaluate = api_call.post(endpoint=endpoint, params=api_params, timeout=1200)
545544
st.success("Evaluation Complete!", icon="✅")
546-
placeholder.empty()
547545

548546
###################################
549547
# Results

src/server/api/core/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giska
150150
# schema.Model Classes
151151
model_classes = {}
152152
if not embedding:
153-
logger.debug("Configuring LL schema.Model")
153+
logger.debug("Configuring LL Model")
154154
ll_common_params = {}
155155
for key in [
156156
"temperature",
@@ -166,7 +166,7 @@ def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giska
166166
except KeyError:
167167
# Mainly for embeddings
168168
continue
169-
logger.debug("LL schema.Model Parameters: %s", ll_common_params)
169+
logger.debug("LL Model Parameters: %s", ll_common_params)
170170
model_classes = {
171171
"OpenAI": lambda: ChatOpenAI(model=model_id, api_key=model_api_key, **ll_common_params),
172172
"CompatOpenAI": lambda: ChatOpenAI(

src/server/api/v1/testbed.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from fastapi.responses import JSONResponse
1818
import litellm
1919
from langchain_core.messages import ChatMessage
20-
from server.api.core import settings, databases, models, oci
20+
import server.api.core.models as core_models
21+
from server.api.core import settings, databases, oci
2122
from server.api.utils import embed, testbed
2223
from server.api.v1 import chat
2324

@@ -140,8 +141,8 @@ async def testbed_generate_qa(
140141
) -> schema.TestSetQA:
141142
"""Retrieve contents from a local file uploaded and generate Q&A"""
142143
# Setup Models
143-
giskard_ll_model = models.get_model(model_id=ll_model, model_type="ll")
144-
giskard_embed_model = models.get_model(model_id=embed_model, model_type="embed")
144+
giskard_ll_model = core_models.get_model(model_id=ll_model, model_type="ll")
145+
giskard_embed_model = core_models.get_model(model_id=embed_model, model_type="embed")
145146
temp_directory = embed.get_temp_directory(client, "testbed")
146147
full_testsets = temp_directory / "all_testsets.jsonl"
147148

@@ -156,7 +157,7 @@ async def testbed_generate_qa(
156157

157158
# Process file for knowledge base
158159
text_nodes = testbed.load_and_split(filename)
159-
test_set = testbed.build_knowledge_base(text_nodes, questions, giskard_ll_model[0], giskard_embed_model[0])
160+
test_set = testbed.build_knowledge_base(text_nodes, questions, giskard_ll_model, giskard_embed_model)
160161
# Save test set
161162
test_set_filename = temp_directory / f"{name}.jsonl"
162163
test_set.save(test_set_filename)
@@ -222,7 +223,7 @@ def get_answer(question: str):
222223
# Setup Judge Model
223224
logger.debug("Starting evaluation with Judge: %s", judge)
224225
oci_config = oci.get_oci(client)
225-
judge_client = asyncio.run(models.get_client({"model": judge}, oci_config, True))
226+
judge_client = core_models.get_client({"model": judge}, oci_config, True)
226227
try:
227228
report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=[correctness_metric])
228229
except KeyError as ex:

0 commit comments

Comments
 (0)