57 changes: 33 additions & 24 deletions frontends/api/src/generated/v1/api.ts

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions frontends/main/src/app/c/[channelType]/[name]/page.tsx
@@ -96,6 +96,7 @@ const Page: React.FC<PageProps<"/c/[channelType]/[name]">> = async ({
)

const searchRequest = getSearchParams({
// @ts-expect-error -- this will error until mitodl/mit-learn-api-axios is updated
requestParams: validateRequestParams(search),
constantSearchParams,
facetNames,
1 change: 1 addition & 0 deletions frontends/main/src/app/search/page.tsx
@@ -28,6 +28,7 @@ const Page: React.FC<PageProps<"/search">> = async ({ searchParams }) => {
}

const params = getSearchParams({
// @ts-expect-error -- this will error until mitodl/mit-learn-api-axios is updated
requestParams: validateRequestParams(search),
constantSearchParams: {},
facetNames,
130 changes: 108 additions & 22 deletions learning_resources_search/api.py
@@ -13,12 +13,15 @@
from learning_resources.models import LearningResource
from learning_resources_search.connection import (
get_default_alias_name,
get_vector_model_id,
)
from learning_resources_search.constants import (
COMBINED_INDEX,
CONTENT_FILE_TYPE,
COURSE_QUERY_FIELDS,
COURSE_TYPE,
DEPARTMENT_QUERY_FIELDS,
HYBRID_SEARCH_MODE,
LEARNING_RESOURCE,
LEARNING_RESOURCE_QUERY_FIELDS,
LEARNING_RESOURCE_SEARCH_SORTBY_OPTIONS,
@@ -66,21 +69,24 @@ def gen_content_file_id(content_file_id):
return f"cf_{content_file_id}"


def relevant_indexes(resource_types, aggregations, endpoint):
def relevant_indexes(resource_types, aggregations, endpoint, use_hybrid_search):
"""
Return list of relevant index types for the query

Args:
resource_types (list): the resource type parameter for the search
aggregations (list): the aggregations parameter for the search
endpoint (string): the endpoint: learning_resource or content_file
use_hybrid_search (bool): whether to use hybrid search

Returns:
Array(string): array of index names

"""
if endpoint == CONTENT_FILE_TYPE:
return [get_default_alias_name(COURSE_TYPE)]
elif use_hybrid_search:
return [get_default_alias_name(COMBINED_INDEX)]

if aggregations and "resource_type" in aggregations:
return map(get_default_alias_name, LEARNING_RESOURCE_TYPES)
@@ -143,18 +149,24 @@ def generate_sort_clause(search_params):
return sort


def wrap_text_clause(text_query, min_score=None):
def wrap_text_clause(
text_query,
use_hybrid_search,
min_score=None,
):
"""
Wrap the text subqueries in a bool query
Shared by generate_content_file_text_clause and
generate_learning_resources_text_clause

Args:
text_query (dict): dictionary with the opensearch text clauses
use_hybrid_search (bool): whether to use hybrid search
min_score (float): minimum score for function score query
Returns:
dict: dictionary with the opensearch text clause
"""
if min_score and text_query:
if not use_hybrid_search and min_score and text_query:
text_bool_clause = [
{"function_score": {"query": {"bool": text_query}, "min_score": min_score}}
]
@@ -207,7 +219,7 @@ def generate_content_file_text_clause(text):
else:
text_query = {}

return wrap_text_clause(text_query)
return wrap_text_clause(text_query, use_hybrid_search=False)


def generate_learning_resources_text_clause(
@@ -222,16 +234,23 @@ def generate_learning_resources_text_clause(
dict: dictionary with the opensearch text clause
"""

use_hybrid_search = search_mode == HYBRID_SEARCH_MODE

query_type = (
"query_string" if text.startswith('"') and text.endswith('"') else "multi_match"
)

extra_params = {}

if query_type == "multi_match" and search_mode:
extra_params["type"] = search_mode
if use_hybrid_search:
text_search_mode = settings.DEFAULT_SEARCH_MODE
else:
text_search_mode = search_mode

if query_type == "multi_match" and text_search_mode:
extra_params["type"] = text_search_mode

if search_mode == "phrase" and slop:
if text_search_mode == "phrase" and slop:
extra_params["slop"] = slop

if content_file_score_weight is not None:
@@ -335,7 +354,7 @@ def generate_learning_resources_text_clause(
else:
text_query = {}

return wrap_text_clause(text_query, min_score)
return wrap_text_clause(text_query, use_hybrid_search, min_score)


def generate_filter_clause(
@@ -573,7 +592,9 @@ def percolate_matches_for_document(document_id):
return percolated_queries


def add_text_query_to_search(search, text, search_params, query_type_query):
def add_text_query_to_search(
search, text, search_params, query_type_query, use_hybrid_search
):
if search_params.get("endpoint") == CONTENT_FILE_TYPE:
text_query = generate_content_file_text_clause(text)
else:
@@ -590,7 +611,7 @@ def add_text_query_to_search(search, text, search_params, query_type_query):
search_params.get("max_incompleteness_penalty", 0) / 100
)

if yearly_decay_percent or max_incompleteness_penalty:
if not use_hybrid_search and (yearly_decay_percent or max_incompleteness_penalty):
script_query = {
"function_score": {
"query": {"bool": {"must": [text_query], "filter": query_type_query}}
@@ -626,14 +647,56 @@ def add_text_query_to_search(search, text, search_params, query_type_query):
"params": params,
}

search = search.query(script_query)
text_query = script_query
else:
text_query = {"bool": {"must": [text_query], "filter": query_type_query}}

if use_hybrid_search:
vector_model_id = get_vector_model_id()
if not vector_model_id:
log.error("Vector model not found. Cannot perform hybrid search.")
error_message = "Vector model not found."
raise ValueError(error_message)

vector_query_description = {
"neural": {
"description_embedding": {
"query_text": text,
"model_id": vector_model_id,
"min_score": 0.015,
},
}
}

vector_query_title = {
"neural": {
"title_embedding": {
"query_text": text,
"model_id": vector_model_id,
"min_score": 0.015,
},
}
}

search = search.extra(
query={
"hybrid": {
"pagination_depth": 10,
"queries": [
text_query,
vector_query_description,
vector_query_title,
],
}
}
)
else:
search = search.query("bool", must=[text_query], filter=query_type_query)
search = search.query(text_query)

return search


def construct_search(search_params):
def construct_search(search_params): # noqa: C901
"""
Construct a learning resources search based on the query

@@ -652,16 +715,20 @@ def construct_search(search_params):
):
search_params["resource_type"] = list(LEARNING_RESOURCE_TYPES)

use_hybrid_search = search_params.get("search_mode") == HYBRID_SEARCH_MODE

indexes = relevant_indexes(
search_params.get("resource_type"),
search_params.get("aggregations"),
search_params.get("endpoint"),
use_hybrid_search,
)

search = Search(index=",".join(indexes))

search = search.source(fields={"excludes": SOURCE_EXCLUDED_FIELDS})
search = search.params(search_type="dfs_query_then_fetch")
if not use_hybrid_search:
search = search.params(search_type="dfs_query_then_fetch")
if search_params.get("offset"):
search = search.extra(from_=search_params.get("offset"))

@@ -683,14 +750,9 @@ def construct_search(search_params):
text = re.sub("[\u201c\u201d]", '"', search_params.get("q"))

search = add_text_query_to_search(
search,
text,
search_params,
query_type_query,
search, text, search_params, query_type_query, use_hybrid_search
)

suggest = generate_suggest_clause(text)
search = search.extra(suggest=suggest)
else:
search = search.query(query_type_query)

@@ -727,6 +789,7 @@ def execute_learn_search(search_params):
search_params["yearly_decay_percent"] = (
settings.DEFAULT_SEARCH_STALENESS_PENALTY
)

if search_params.get("search_mode") is None:
search_params["search_mode"] = settings.DEFAULT_SEARCH_MODE
if search_params.get("slop") is None:
@@ -738,6 +801,25 @@
settings.DEFAULT_SEARCH_MAX_INCOMPLETENESS_PENALTY
)
search = construct_search(search_params)

if search_params.get("search_mode") == HYBRID_SEARCH_MODE:
search = search.extra(
search_pipeline={
"description": "Post processor for hybrid search",
"phase_results_processors": [
{
"normalization-processor": {
"normalization": {"technique": "min_max"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {"weights": [0.6, 0.2, 0.2]},
},
}
}
],
}
)

results = search.execute().to_dict()
if results.get("_shards", {}).get("failures"):
log.error(
@@ -904,7 +986,9 @@ def get_similar_topics(
list of str:
list of topic values
"""
indexes = relevant_indexes([COURSE_TYPE], [], endpoint=LEARNING_RESOURCE)
indexes = relevant_indexes(
[COURSE_TYPE], [], endpoint=LEARNING_RESOURCE, use_hybrid_search=False
)
search = Search(index=",".join(indexes))
search = search.filter("term", resource_type=COURSE_TYPE)
search = search.query(
@@ -1051,7 +1135,9 @@ def get_similar_resources_opensearch(
list of str:
list of learning resources
"""
indexes = relevant_indexes(LEARNING_RESOURCE_TYPES, [], endpoint=LEARNING_RESOURCE)
indexes = relevant_indexes(
LEARNING_RESOURCE_TYPES, [], endpoint=LEARNING_RESOURCE, use_hybrid_search=False
)
search = Search(index=",".join(indexes))
if num_resources:
# adding +1 to num_resources since we filter out existing resource.id
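
For orientation, here is a minimal sketch of the OpenSearch request body these changes assemble when search_mode is "hybrid". It is illustrative rather than code from the PR: the lexical bool clause produced by generate_learning_resources_text_clause becomes the first sub-query of a hybrid query, two neural sub-queries target the description_embedding and title_embedding fields, and execute_learn_search attaches an inline search_pipeline that min-max-normalizes the three score sets and combines them with weights 0.6/0.2/0.2, which presumably gives the keyword clause the dominant weight. The simplified lexical clause, the sample query text, and the "<model-id>" placeholder are assumptions; the embedding field names, the 0.015 min_score, the pagination_depth, and the pipeline settings mirror the diff above.

import json

# Illustrative stand-in for the full lexical clause built by
# generate_learning_resources_text_clause(); the real clause is a much
# larger bool query over the learning resource fields.
text = "machine learning"
lexical_query = {"multi_match": {"query": text, "fields": ["title", "description"]}}

def neural(field, query_text):
    # Shape of the neural sub-queries added in add_text_query_to_search();
    # "<model-id>" stands in for the id returned by get_vector_model_id().
    return {
        "neural": {
            field: {"query_text": query_text, "model_id": "<model-id>", "min_score": 0.015}
        }
    }

hybrid_body = {
    "query": {
        "hybrid": {
            "pagination_depth": 10,
            "queries": [
                lexical_query,                          # weighted 0.6 by the pipeline below
                neural("description_embedding", text),  # weighted 0.2
                neural("title_embedding", text),        # weighted 0.2
            ],
        }
    },
    # Inline pipeline added by execute_learn_search(): normalize each
    # sub-query's scores with min_max, then combine them with a weighted
    # arithmetic mean.
    "search_pipeline": {
        "description": "Post processor for hybrid search",
        "phase_results_processors": [
            {
                "normalization-processor": {
                    "normalization": {"technique": "min_max"},
                    "combination": {
                        "technique": "arithmetic_mean",
                        "parameters": {"weights": [0.6, 0.2, 0.2]},
                    },
                }
            }
        ],
    },
}

print(json.dumps(hybrid_body, indent=2))

Per construct_search, a hybrid request is sent against the combined index alias (COMBINED_INDEX) rather than the per-resource-type aliases, and the dfs_query_then_fetch search_type is skipped for it.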