diff --git a/forum/api/threads.py b/forum/api/threads.py index d966e8f9..a8f11a99 100644 --- a/forum/api/threads.py +++ b/forum/api/threads.py @@ -33,6 +33,7 @@ def _get_thread_data_from_request_data(data: dict[str, Any]) -> dict[str, Any]: "pinned", "group_id", "context", + "user_group_ids", ] result = {field: data.get(field) for field in fields if data.get(field) is not None} @@ -286,6 +287,7 @@ def create_thread( thread_type: str = "discussion", group_id: Optional[int] = None, context: str = "course", + user_group_ids: Optional[list[int]] = None, ) -> dict[str, Any]: """ Create a new thread. @@ -315,6 +317,7 @@ def create_thread( "thread_type": thread_type, "group_id": group_id, "context": context, + "user_group_ids": user_group_ids, } thread_data: dict[str, Any] = _get_thread_data_from_request_data(data) @@ -380,6 +383,7 @@ def get_user_threads( "user_id": user_id, "group_id": group_id, "group_ids": group_ids, + "user_group_ids": kwargs.get("user_group_ids"), } params = {k: v for k, v in params.items() if v is not None} backend.validate_params(params) diff --git a/forum/api/users.py b/forum/api/users.py index c13c59fa..784dbdae 100644 --- a/forum/api/users.py +++ b/forum/api/users.py @@ -197,6 +197,7 @@ def get_user_active_threads( page: Optional[int] = FORUM_DEFAULT_PAGE, per_page: Optional[int] = FORUM_DEFAULT_PER_PAGE, group_id: Optional[str] = None, + **kwargs, ) -> dict[str, Any]: """Get user active threads.""" backend = get_backend(course_id)() @@ -237,6 +238,7 @@ def get_user_active_threads( "user_id": user_id, "course_id": course_id, "group_ids": [int(group_id)] if group_id else [], + "user_group_ids": kwargs.get("user_group_ids"), "author_id": author_id, "thread_type": thread_type, "filter_flagged": flagged, diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index 8d8621f3..88ce0c8f 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -43,7 +43,6 @@ from forum.constants import RETIRED_BODY, RETIRED_TITLE from forum.utils import get_group_ids_from_params - class MySQLBackend(AbstractBackend): """MySQL backend api.""" @@ -606,6 +605,7 @@ def handle_threads_query( per_page: int, context: str = "course", raw_query: bool = False, + **kwargs: Any, # We use kwargs for not modifying the function signature ) -> dict[str, Any]: """ Handles complex thread queries based on various filters and returns paginated results. @@ -658,6 +658,13 @@ def handle_threads_query( Q(group_id__in=group_ids) | Q(group_id__isnull=True) ) + # User group filtering + if kwargs.get("user_group_ids"): + user_groups_filter = Q(user_group_ids__isnull=True) + for group_id in kwargs.get("user_group_ids"): + user_groups_filter |= Q(user_group_ids__contains=group_id) + base_query = base_query.filter(user_groups_filter) + # Author filtering if author_id: base_query = base_query.filter(author__pk=author_id) @@ -1018,6 +1025,7 @@ def validate_params( "commentable_ids", "group_id", "group_ids", + "user_group_ids", ] if not user_id: valid_params.append("user_id") @@ -1071,6 +1079,7 @@ def get_threads( params.get("sort_key", ""), int(params.get("page", 1)), int(params.get("per_page", 100)), + user_group_ids=params.get("user_group_ids"), ) context: dict[str, Any] = { "count_flagged": count_flagged, @@ -1753,6 +1762,8 @@ def create_thread(data: dict[str, Any]) -> str: optional_args = {} if group_id := data.get("group_id"): optional_args["group_id"] = group_id + if user_group_ids := data.get("user_group_ids"): + optional_args["user_group_ids"] = user_group_ids new_thread = CommentThread.objects.create( title=data["title"], body=data["body"], diff --git a/forum/backends/mysql/models.py b/forum/backends/mysql/models.py index d7cfe6ec..e462b786 100644 --- a/forum/backends/mysql/models.py +++ b/forum/backends/mysql/models.py @@ -107,6 +107,9 @@ class Content(models.Model): group_id: models.PositiveIntegerField[int, int] = models.PositiveIntegerField( null=True ) + user_group_ids: models.JSONField[list[int], list[int]] = models.JSONField( + null=True, + ) created_at: models.DateTimeField[datetime, datetime] = models.DateTimeField( auto_now_add=True ) @@ -294,6 +297,7 @@ def to_dict(self) -> dict[str, Any]: "last_activity_at": self.last_activity_at, "edit_history": edit_history, "group_id": self.group_id, + "user_group_ids": self.user_group_ids, } def doc_to_hash(self) -> dict[str, Any]: diff --git a/forum/migrations/0004_comment_user_group_ids_commentthread_user_group_ids.py b/forum/migrations/0004_comment_user_group_ids_commentthread_user_group_ids.py new file mode 100644 index 00000000..eb60ed17 --- /dev/null +++ b/forum/migrations/0004_comment_user_group_ids_commentthread_user_group_ids.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.16 on 2025-06-18 23:58 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("forum", "0003_alter_commentthread_title"), + ] + + operations = [ + migrations.AddField( + model_name="comment", + name="user_group_ids", + field=models.JSONField(null=True), + ), + migrations.AddField( + model_name="commentthread", + name="user_group_ids", + field=models.JSONField(null=True), + ), + ] diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index c2f8b85f..8a321f79 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -64,6 +64,7 @@ class ThreadSerializer(ContentSerializer): resp_total = serializers.SerializerMethodField(required=False) resp_skip = serializers.IntegerField(required=False, default=0) resp_limit = serializers.IntegerField(required=False, default=10) + user_group_ids = serializers.ListField(allow_null=True, default=None) def __init__(self, *args: Any, **kwargs: Any) -> None: """