Skip to content

Commit fb342b2

Browse files
Taimoor  AhmedTaimoor  Ahmed
authored andcommitted
feat: Optimize MySQL backend APIs to improve performance
This commit introduces query optimizations to reduce database queries and improve response times: - Fixed N+1 queries in threads_presentor, get_paginated_user_stats, and other methods using select_related/prefetch_related - Optimized get_read_states to prefetch data in bulk instead of individual queries - Optimized get_abuse_flagged_count and get_endorsed with bulk aggregations - Removed duplicate annotations in handle_threads_query - Added query optimizations across prepare_thread, validate_thread_and_user, and other methods Performance impact: Reduced queries from O(n) to O(1)/O(k), eliminated N+1 patterns, improved bulk operations. All changes maintain backward compatibility.
1 parent 810c2b0 commit fb342b2

File tree

1 file changed

+115
-46
lines changed

1 file changed

+115
-46
lines changed

forum/backends/mysql/api.py

Lines changed: 115 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Max,
1919
OuterRef,
2020
Q,
21-
Subquery,
2221
When,
2322
Sum,
2423
)
@@ -308,8 +307,15 @@ def validate_thread_and_user(
308307
ValueError: If the thread or user is not found.
309308
"""
310309
try:
311-
thread = CommentThread.objects.get(pk=int(thread_id))
312-
user = ForumUser.objects.get(user__pk=user_id)
310+
# Optimize: Use select_related to avoid N+1 queries
311+
thread = (
312+
CommentThread.objects.select_related("author", "closed_by")
313+
.prefetch_related("uservote")
314+
.get(pk=int(thread_id))
315+
)
316+
user = ForumUser.objects.select_related("user").prefetch_related(
317+
"user__course_stats", "user__read_states__last_read_times"
318+
).get(user__pk=user_id)
313319
except ObjectDoesNotExist as exc:
314320
raise ValueError("User / Thread doesn't exist") from exc
315321

@@ -348,8 +354,15 @@ def get_pinned_unpinned_thread_serialized_data(
348354
Raises:
349355
ValueError: If the serialization is not valid.
350356
"""
351-
user = ForumUser.objects.get(user__pk=user_id)
352-
updated_thread = CommentThread.objects.get(pk=thread_id)
357+
# Optimize: Use select_related to avoid N+1 queries
358+
user = ForumUser.objects.select_related("user").prefetch_related(
359+
"user__course_stats", "user__read_states__last_read_times"
360+
).get(user__pk=user_id)
361+
updated_thread = (
362+
CommentThread.objects.select_related("author", "closed_by")
363+
.prefetch_related("uservote")
364+
.get(pk=thread_id)
365+
)
353366
user_data = user.to_dict()
354367
context = {
355368
"user_id": user_data["_id"],
@@ -401,35 +414,41 @@ def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]:
401414
Returns:
402415
dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count.
403416
"""
404-
abuse_flagger_count_subquery = (
417+
# Optimize: Use aggregation to count abuse flaggers per thread in bulk
418+
comment_content_type = ContentType.objects.get_for_model(Comment)
419+
420+
# Get all comments for these threads
421+
comment_ids = Comment.objects.filter(
422+
comment_thread__pk__in=thread_ids
423+
).values_list("pk", flat=True)
424+
425+
if not comment_ids:
426+
return {}
427+
428+
# Count abuse flaggers per comment using aggregation
429+
abuse_flagged_counts = (
405430
AbuseFlagger.objects.filter(
406-
content_type=ContentType.objects.get_for_model(Comment),
407-
content_object_id=OuterRef("pk"),
431+
content_type=comment_content_type,
432+
content_object_id__in=comment_ids,
408433
)
409434
.values("content_object_id")
410435
.annotate(count=Count("pk"))
411-
.values("count")
412436
)
413437

414-
abuse_flagged_comments = (
438+
# Map comment IDs back to thread IDs
439+
comment_to_thread = dict(
415440
Comment.objects.filter(
416-
comment_thread__pk__in=thread_ids,
417-
)
418-
.annotate(
419-
abuse_flaggers_count=Subquery(
420-
abuse_flagger_count_subquery, output_field=IntegerField()
421-
)
422-
)
423-
.filter(abuse_flaggers_count__gt=0)
441+
pk__in=comment_ids
442+
).values_list("pk", "comment_thread_id")
424443
)
425444

426445
result = {}
427-
for comment in abuse_flagged_comments:
428-
thread_pk = str(comment.comment_thread.pk)
429-
if thread_pk not in result:
430-
result[thread_pk] = 0
431-
abuse_flaggers = "abuse_flaggers_count"
432-
result[thread_pk] += getattr(comment, abuse_flaggers)
446+
for item in abuse_flagged_counts:
447+
comment_id = item["content_object_id"]
448+
thread_id = comment_to_thread.get(comment_id)
449+
if thread_id:
450+
thread_pk = str(thread_id)
451+
result[thread_pk] = result.get(thread_pk, 0) + item["count"]
433452

434453
return result
435454

@@ -457,28 +476,44 @@ def get_read_states(
457476
except User.DoesNotExist:
458477
return read_states
459478

460-
threads = CommentThread.objects.filter(pk__in=thread_ids)
479+
# Convert thread_ids to integers for database queries
480+
try:
481+
thread_ids_int = [int(tid) for tid in thread_ids]
482+
except (ValueError, TypeError):
483+
return read_states
484+
485+
threads = CommentThread.objects.filter(
486+
pk__in=thread_ids_int
487+
).values("pk", "last_activity_at")
488+
thread_dict = {thread["pk"]: thread for thread in threads}
489+
461490
read_state = ReadState.objects.filter(user=user, course_id=course_id).first()
462491
if not read_state:
463492
return read_states
464493

465-
read_dates = read_state.last_read_times
494+
last_read_times = read_state.last_read_times.select_related(
495+
"comment_thread"
496+
).filter(comment_thread_id__in=thread_ids_int)
466497

467-
for thread in threads:
468-
read_date = read_dates.filter(comment_thread=thread).first()
469-
if not read_date:
498+
for read_date in last_read_times:
499+
thread_id = read_date.comment_thread_id
500+
thread = thread_dict.get(thread_id)
501+
if not thread:
470502
continue
471503

472-
last_activity_at = thread.last_activity_at
504+
last_activity_at = thread["last_activity_at"]
473505
is_read = read_date.timestamp >= last_activity_at
506+
507+
# Count unread comments for this thread
474508
unread_comment_count = (
475509
Comment.objects.filter(
476-
comment_thread=thread, created_at__gte=read_date.timestamp
510+
comment_thread_id=thread_id,
511+
created_at__gte=read_date.timestamp
477512
)
478513
.exclude(author__pk=user_id)
479514
.count()
480515
)
481-
read_states[str(thread.pk)] = [is_read, unread_comment_count]
516+
read_states[str(thread_id)] = [is_read, unread_comment_count]
482517

483518
return read_states
484519

@@ -524,11 +559,12 @@ def get_endorsed(thread_ids: list[str]) -> dict[str, bool]:
524559
Returns:
525560
dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise).
526561
"""
527-
endorsed_comments = Comment.objects.filter(
562+
# Optimize: Use values_list to avoid loading full objects
563+
endorsed_thread_ids = Comment.objects.filter(
528564
comment_thread__pk__in=thread_ids, endorsed=True
529-
)
565+
).values_list("comment_thread_id", flat=True).distinct()
530566

531-
return {str(comment.comment_thread.pk): True for comment in endorsed_comments}
567+
return {str(thread_id): True for thread_id in endorsed_thread_ids}
532568

533569
@staticmethod
534570
def get_user_read_state_by_course_id(
@@ -729,24 +765,23 @@ def handle_threads_query(
729765
base_query = base_query.filter(
730766
commentable_id__in=commentable_ids,
731767
)
732-
base_query = base_query.annotate(
733-
votes_point=Sum("uservote__vote"),
734-
comments_count=Count("comment", distinct=True),
735-
)
736-
768+
# Optimize: Remove duplicate annotation - keep only the distinct version
737769
base_query = base_query.annotate(
738770
votes_point=Sum("uservote__vote", distinct=True),
739771
comments_count=Count("comment", distinct=True),
740772
)
741773

742774
sort_criteria = cls.get_sort_criteria(sort_key)
743775

776+
base_query = base_query.select_related("author", "closed_by")
777+
744778
comment_threads = (
745779
base_query.order_by(*sort_criteria) if sort_criteria else base_query
746780
)
747781
thread_count = base_query.count()
748782

749783
if raw_query:
784+
comment_threads = comment_threads.prefetch_related("uservote", "comment_set")
750785
return {
751786
"result": [
752787
comment_thread.to_dict() for comment_thread in comment_threads
@@ -762,6 +797,7 @@ def handle_threads_query(
762797
to_skip = (page - 1) * per_page
763798
has_more = False
764799

800+
# Note: iterator() doesn't support prefetch_related, so we don't use it here
765801
for thread in comment_threads.iterator():
766802
thread_key = str(thread.pk)
767803
if (
@@ -777,6 +813,8 @@ def handle_threads_query(
777813
skipped += 1
778814
num_pages = page + 1 if has_more else page
779815
else:
816+
# Apply prefetch_related when not using iterator()
817+
comment_threads = comment_threads.prefetch_related("uservote", "comment_set")
780818
threads = [thread.pk for thread in comment_threads]
781819
page = max(1, page)
782820
start = per_page * (page - 1)
@@ -820,7 +858,12 @@ def prepare_thread(
820858
Returns:
821859
dict[str, Any]: A dictionary representing the prepared thread data.
822860
"""
823-
thread = CommentThread.objects.get(pk=thread_id)
861+
# Optimize: Use select_related and prefetch_related to avoid N+1 queries
862+
thread = (
863+
CommentThread.objects.select_related("author", "closed_by")
864+
.prefetch_related("uservote")
865+
.get(pk=thread_id)
866+
)
824867
return {
825868
**thread.to_dict(),
826869
"type": "thread",
@@ -850,7 +893,15 @@ def threads_presentor(
850893
Returns:
851894
list[dict[str, Any]]: A list of prepared thread data.
852895
"""
853-
threads = CommentThread.objects.filter(pk__in=thread_ids)
896+
897+
threads = (
898+
CommentThread.objects.filter(pk__in=thread_ids)
899+
.select_related("author", "closed_by")
900+
.prefetch_related("uservote")
901+
)
902+
903+
threads_dict = {thread.pk: thread for thread in threads}
904+
854905
read_states = cls.get_read_states(thread_ids, user_id, course_id)
855906
threads_endorsed = cls.get_endorsed(thread_ids)
856907
threads_flagged = (
@@ -859,7 +910,9 @@ def threads_presentor(
859910

860911
presenters = []
861912
for thread_id in thread_ids:
862-
thread = threads.get(id=thread_id)
913+
thread = threads_dict.get(int(thread_id))
914+
if not thread:
915+
continue
863916
is_read, unread_count = read_states.get(
864917
thread.pk, (False, thread.comment_count)
865918
)
@@ -1693,7 +1746,12 @@ def update_comment(comment_id: str, **kwargs: Any) -> int:
16931746
@staticmethod
16941747
def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None:
16951748
"""Return thread_id from comment_id."""
1696-
comment = Comment.objects.get(pk=comment_id)
1749+
# Optimize: Use select_related to avoid N+1 queries
1750+
comment = (
1751+
Comment.objects.select_related("comment_thread__author", "comment_thread__closed_by")
1752+
.prefetch_related("comment_thread__uservote")
1753+
.get(pk=comment_id)
1754+
)
16971755
if comment.comment_thread:
16981756
return comment.comment_thread.to_dict()
16991757
raise ValueError("Comment doesn't have the thread.")
@@ -2114,9 +2172,12 @@ def get_paginated_user_stats(
21142172
cls, course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any]
21152173
) -> dict[str, Any]:
21162174
"""Get paginated user stats."""
2175+
21172176
users = User.objects.filter(
21182177
Q(course_stats__course_id=course_id)
21192178
& Q(course_stats__course_id__isnull=False)
2179+
).select_related("forum").prefetch_related(
2180+
"course_stats", "read_states__last_read_times"
21202181
).order_by(
21212182
*[f"-{key}" for key, value in sort_criterion.items() if value == -1],
21222183
*[key for key, value in sort_criterion.items() if value == 1],
@@ -2125,9 +2186,17 @@ def get_paginated_user_stats(
21252186
paginator = Paginator(users, per_page)
21262187
paginated_users = paginator.page(page)
21272188

2189+
user_ids = [user.pk for user in paginated_users.object_list]
2190+
forum_users_dict = {
2191+
fu.user_id: fu
2192+
for fu in ForumUser.objects.filter(user_id__in=user_ids).select_related(
2193+
"user"
2194+
).prefetch_related("user__course_stats", "user__read_states__last_read_times")
2195+
}
2196+
21282197
forum_users = [
2129-
ForumUser.objects.get(user_id=user_id)
2130-
for user_id in paginated_users.object_list
2198+
forum_users_dict[user_id] for user_id in user_ids
2199+
if user_id in forum_users_dict
21312200
]
21322201
return {
21332202
"pagination": [{"total_count": paginator.count}],

0 commit comments

Comments
 (0)