Skip to content

Commit 783e449

Browse files
committed
hybrid search v0
1 parent fc72b43 commit 783e449

File tree

17 files changed

+1486
-389
lines changed

17 files changed

+1486
-389
lines changed

frontends/api/src/generated/v1/api.ts

Lines changed: 33 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

frontends/main/src/app/c/[channelType]/[name]/page.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ const Page: React.FC<PageProps<"/c/[channelType]/[name]">> = async ({
9696
)
9797

9898
const searchRequest = getSearchParams({
99+
// @ts-expect-error -- this will error until mitodl/mit-learn-api-axios is updated
99100
requestParams: validateRequestParams(search),
100101
constantSearchParams,
101102
facetNames,

frontends/main/src/app/search/page.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const Page: React.FC<PageProps<"/search">> = async ({ searchParams }) => {
2828
}
2929

3030
const params = getSearchParams({
31+
// @ts-expect-error -- this will error until mitodl/mit-learn-api-axios is updated
3132
requestParams: validateRequestParams(search),
3233
constantSearchParams: {},
3334
facetNames,

learning_resources_search/api.py

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
from learning_resources.models import LearningResource
1414
from learning_resources_search.connection import (
1515
get_default_alias_name,
16+
get_vector_model_id,
1617
)
1718
from learning_resources_search.constants import (
19+
COMBINED_INDEX,
1820
CONTENT_FILE_TYPE,
1921
COURSE_QUERY_FIELDS,
2022
COURSE_TYPE,
2123
DEPARTMENT_QUERY_FIELDS,
24+
HYBRID_SEARCH_MODE,
2225
LEARNING_RESOURCE,
2326
LEARNING_RESOURCE_QUERY_FIELDS,
2427
LEARNING_RESOURCE_SEARCH_SORTBY_OPTIONS,
@@ -66,21 +69,24 @@ def gen_content_file_id(content_file_id):
6669
return f"cf_{content_file_id}"
6770

6871

69-
def relevant_indexes(resource_types, aggregations, endpoint):
72+
def relevant_indexes(resource_types, aggregations, endpoint, use_hybrid_search):
7073
"""
7174
Return list of relevant index types for the query
7275
7376
Args:
7477
resource_types (list): the resource type parameter for the search
7578
aggregations (list): the aggregations parameter for the search
7679
endpoint (string): the endpoint: learning_resource or content_file
80+
use_hybrid_search (bool): whether to use hybrid search
7781
7882
Returns:
7983
Array(string): array of index names
8084
8185
"""
8286
if endpoint == CONTENT_FILE_TYPE:
8387
return [get_default_alias_name(COURSE_TYPE)]
88+
elif use_hybrid_search:
89+
return [get_default_alias_name(COMBINED_INDEX)]
8490

8591
if aggregations and "resource_type" in aggregations:
8692
return map(get_default_alias_name, LEARNING_RESOURCE_TYPES)
@@ -143,18 +149,24 @@ def generate_sort_clause(search_params):
143149
return sort
144150

145151

146-
def wrap_text_clause(text_query, min_score=None):
152+
def wrap_text_clause(
153+
text_query,
154+
use_hybrid_search,
155+
min_score=None,
156+
):
147157
"""
148158
Wrap the text subqueries in a bool query
149159
Shared by generate_content_file_text_clause and
150160
generate_learning_resources_text_clause
151161
152162
Args:
153163
text_query (dict): dictionary with the opensearch text clauses
164+
min_score (float): minimum score for function score query
165+
use_hybrid_search (bool): whether to use hybrid search
154166
Returns:
155167
dict: dictionary with the opensearch text clause
156168
"""
157-
if min_score and text_query:
169+
if not use_hybrid_search and min_score and text_query:
158170
text_bool_clause = [
159171
{"function_score": {"query": {"bool": text_query}, "min_score": min_score}}
160172
]
@@ -207,7 +219,7 @@ def generate_content_file_text_clause(text):
207219
else:
208220
text_query = {}
209221

210-
return wrap_text_clause(text_query)
222+
return wrap_text_clause(text_query, use_hybrid_search=False)
211223

212224

213225
def generate_learning_resources_text_clause(
@@ -222,16 +234,23 @@ def generate_learning_resources_text_clause(
222234
dict: dictionary with the opensearch text clause
223235
"""
224236

237+
use_hybrid_search = search_mode == HYBRID_SEARCH_MODE
238+
225239
query_type = (
226240
"query_string" if text.startswith('"') and text.endswith('"') else "multi_match"
227241
)
228242

229243
extra_params = {}
230244

231-
if query_type == "multi_match" and search_mode:
232-
extra_params["type"] = search_mode
245+
if use_hybrid_search:
246+
text_search_mode = settings.DEFAULT_SEARCH_MODE
247+
else:
248+
text_search_mode = search_mode
249+
250+
if query_type == "multi_match" and text_search_mode:
251+
extra_params["type"] = text_search_mode
233252

234-
if search_mode == "phrase" and slop:
253+
if text_search_mode == "phrase" and slop:
235254
extra_params["slop"] = slop
236255

237256
if content_file_score_weight is not None:
@@ -335,7 +354,7 @@ def generate_learning_resources_text_clause(
335354
else:
336355
text_query = {}
337356

338-
return wrap_text_clause(text_query, min_score)
357+
return wrap_text_clause(text_query, use_hybrid_search, min_score)
339358

340359

341360
def generate_filter_clause(
@@ -573,7 +592,9 @@ def percolate_matches_for_document(document_id):
573592
return percolated_queries
574593

575594

576-
def add_text_query_to_search(search, text, search_params, query_type_query):
595+
def add_text_query_to_search(
596+
search, text, search_params, query_type_query, use_hybrid_search
597+
):
577598
if search_params.get("endpoint") == CONTENT_FILE_TYPE:
578599
text_query = generate_content_file_text_clause(text)
579600
else:
@@ -590,7 +611,7 @@ def add_text_query_to_search(search, text, search_params, query_type_query):
590611
search_params.get("max_incompleteness_penalty", 0) / 100
591612
)
592613

593-
if yearly_decay_percent or max_incompleteness_penalty:
614+
if not use_hybrid_search and (yearly_decay_percent or max_incompleteness_penalty):
594615
script_query = {
595616
"function_score": {
596617
"query": {"bool": {"must": [text_query], "filter": query_type_query}}
@@ -626,14 +647,56 @@ def add_text_query_to_search(search, text, search_params, query_type_query):
626647
"params": params,
627648
}
628649

629-
search = search.query(script_query)
650+
text_query = script_query
651+
else:
652+
text_query = {"bool": {"must": [text_query], "filter": query_type_query}}
653+
654+
if use_hybrid_search:
655+
vector_model_id = get_vector_model_id()
656+
if not vector_model_id:
657+
log.error("Vector model not found. Cannot perform hybrid search.")
658+
error_message = "Vector model not found."
659+
raise ValueError(error_message)
660+
661+
vector_query_description = {
662+
"neural": {
663+
"description_embedding": {
664+
"query_text": text,
665+
"model_id": vector_model_id,
666+
"min_score": 0.015,
667+
},
668+
}
669+
}
670+
671+
vector_query_title = {
672+
"neural": {
673+
"title_embedding": {
674+
"query_text": text,
675+
"model_id": vector_model_id,
676+
"min_score": 0.015,
677+
},
678+
}
679+
}
680+
681+
search = search.extra(
682+
query={
683+
"hybrid": {
684+
"pagination_depth": 10,
685+
"queries": [
686+
text_query,
687+
vector_query_description,
688+
vector_query_title,
689+
],
690+
}
691+
}
692+
)
630693
else:
631-
search = search.query("bool", must=[text_query], filter=query_type_query)
694+
search = search.query(text_query)
632695

633696
return search
634697

635698

636-
def construct_search(search_params):
699+
def construct_search(search_params): # noqa: C901
637700
"""
638701
Construct a learning resources search based on the query
639702
@@ -652,16 +715,20 @@ def construct_search(search_params):
652715
):
653716
search_params["resource_type"] = list(LEARNING_RESOURCE_TYPES)
654717

718+
use_hybrid_search = search_params.get("search_mode") == HYBRID_SEARCH_MODE
719+
655720
indexes = relevant_indexes(
656721
search_params.get("resource_type"),
657722
search_params.get("aggregations"),
658723
search_params.get("endpoint"),
724+
use_hybrid_search,
659725
)
660726

661727
search = Search(index=",".join(indexes))
662728

663729
search = search.source(fields={"excludes": SOURCE_EXCLUDED_FIELDS})
664-
search = search.params(search_type="dfs_query_then_fetch")
730+
if not use_hybrid_search:
731+
search = search.params(search_type="dfs_query_then_fetch")
665732
if search_params.get("offset"):
666733
search = search.extra(from_=search_params.get("offset"))
667734

@@ -683,14 +750,9 @@ def construct_search(search_params):
683750
text = re.sub("[\u201c\u201d]", '"', search_params.get("q"))
684751

685752
search = add_text_query_to_search(
686-
search,
687-
text,
688-
search_params,
689-
query_type_query,
753+
search, text, search_params, query_type_query, use_hybrid_search
690754
)
691755

692-
suggest = generate_suggest_clause(text)
693-
search = search.extra(suggest=suggest)
694756
else:
695757
search = search.query(query_type_query)
696758

@@ -727,6 +789,7 @@ def execute_learn_search(search_params):
727789
search_params["yearly_decay_percent"] = (
728790
settings.DEFAULT_SEARCH_STALENESS_PENALTY
729791
)
792+
730793
if search_params.get("search_mode") is None:
731794
search_params["search_mode"] = settings.DEFAULT_SEARCH_MODE
732795
if search_params.get("slop") is None:
@@ -738,6 +801,25 @@ def execute_learn_search(search_params):
738801
settings.DEFAULT_SEARCH_MAX_INCOMPLETENESS_PENALTY
739802
)
740803
search = construct_search(search_params)
804+
805+
if search_params.get("search_mode") == HYBRID_SEARCH_MODE:
806+
search = search.extra(
807+
search_pipeline={
808+
"description": "Post processor for hybrid search",
809+
"phase_results_processors": [
810+
{
811+
"normalization-processor": {
812+
"normalization": {"technique": "min_max"},
813+
"combination": {
814+
"technique": "arithmetic_mean",
815+
"parameters": {"weights": [0.6, 0.2, 0.2]},
816+
},
817+
}
818+
}
819+
],
820+
}
821+
)
822+
741823
results = search.execute().to_dict()
742824
if results.get("_shards", {}).get("failures"):
743825
log.error(
@@ -904,7 +986,9 @@ def get_similar_topics(
904986
list of str:
905987
list of topic values
906988
"""
907-
indexes = relevant_indexes([COURSE_TYPE], [], endpoint=LEARNING_RESOURCE)
989+
indexes = relevant_indexes(
990+
[COURSE_TYPE], [], endpoint=LEARNING_RESOURCE, use_hybrid_search=False
991+
)
908992
search = Search(index=",".join(indexes))
909993
search = search.filter("term", resource_type=COURSE_TYPE)
910994
search = search.query(
@@ -1051,7 +1135,9 @@ def get_similar_resources_opensearch(
10511135
list of str:
10521136
list of learning resources
10531137
"""
1054-
indexes = relevant_indexes(LEARNING_RESOURCE_TYPES, [], endpoint=LEARNING_RESOURCE)
1138+
indexes = relevant_indexes(
1139+
LEARNING_RESOURCE_TYPES, [], endpoint=LEARNING_RESOURCE, use_hybrid_search=False
1140+
)
10551141
search = Search(index=",".join(indexes))
10561142
if num_resources:
10571143
# adding +1 to num_resources since we filter out existing resource.id

0 commit comments

Comments
 (0)