diff --git a/.github/workflows/mongodb_settings.py b/.github/workflows/mongodb_settings.py
deleted file mode 100644
index f571e11..0000000
--- a/.github/workflows/mongodb_settings.py
+++ /dev/null
@@ -1,14 +0,0 @@
-DATABASES = {
-    "default": {
-        "ENGINE": "django_documentdb",
-        "NAME": "djangotests",
-    },
-    "other": {
-        "ENGINE": "django_documentdb",
-        "NAME": "djangotests-other",
-    },
-}
-DEFAULT_AUTO_FIELD = "django_documentdb.fields.ObjectIdAutoField"
-PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",)
-SECRET_KEY = "django_tests_secret_key"
-USE_TZ = False
diff --git a/.gitignore b/.gitignore
index 35ac288..0276ad2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,3 +51,5 @@
 docs/build
 site/
 _development/
+tests/django
+documentdb_settings.py
diff --git a/README.md b/README.md
index dd45cba..cfe8265 100644
--- a/README.md
+++ b/README.md
@@ -110,3 +110,13 @@ class TestModel(DocumentModel):
 ## Forked Project
 
 This project, **django-documentdb**, is a fork of the original **django-mongodb** library, which aimed to integrate MongoDB with Django. The fork was created to enhance compatibility with AWS DocumentDB, addressing the limitations of its API support while maintaining the core functionalities of the original library. We appreciate the work of the MongoDB Python Team and aim to build upon their foundation to better serve users needing DocumentDB integration.
+
+## Running tests
+
+Build the Docker test image and run the Django test suite against DocumentDB:
+
+    docker build . -t test:latest -f tests/Dockerfile && docker run -it test:latest
+
+To run against MongoDB instead:
+
+    docker build . -t mongo_test:latest -f tests/mongodb.Dockerfile && docker run -it mongo_test:latest
diff --git a/django_documentdb/aggregates.py b/django_documentdb/aggregates.py
index 6eae38d..28e94d3 100644
--- a/django_documentdb/aggregates.py
+++ b/django_documentdb/aggregates.py
@@ -3,6 +3,7 @@
 from django.db.models.lookups import IsNull
 
 from .query_utils import process_lhs
+from .utils import prefix_with_dollar
 
 # Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
MONGO_AGGREGATIONS = {Count: "sum"} @@ -26,9 +27,9 @@ def aggregate( node = self lhs_mql = process_lhs(node, compiler, connection) if resolve_inner_expression: - return lhs_mql + return prefix_with_dollar(lhs_mql) operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) - return {f"${operator}": f"${lhs_mql}"} + return {f"${operator}": prefix_with_dollar(lhs_mql)} def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 @@ -64,12 +65,8 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): - if self.function.endswith("_SAMP"): - operator = "stdDevSamp" - elif self.function.endswith("_POP"): - operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator, **extra_context) +def stddev_variance(*args, **kwargs): # noqa: ARG001 + raise NotImplementedError("StdDev and Variance are not supported yet.") def register_aggregates(): diff --git a/django_documentdb/base.py b/django_documentdb/base.py index 9043ca9..4ac6799 100644 --- a/django_documentdb/base.py +++ b/django_documentdb/base.py @@ -13,7 +13,7 @@ from .operations import DatabaseOperations from .query_utils import regex_match from .schema import DatabaseSchemaEditor -from .utils import IndexNotUsedWarning, OperationDebugWrapper +from .utils import IndexNotUsedWarning, OperationDebugWrapper, prefix_with_dollar # ignore warning from pymongo about DocumentDB warnings.filterwarnings("ignore", "You appear to be connected to a DocumentDB cluster", UserWarning) @@ -87,37 +87,58 @@ class DatabaseWrapper(BaseDatabaseWrapper): "iendswith": "LIKE '%%' || UPPER({})", } - def _isnull_operator(a, b): + def _isnull_operator(a, b, pos: bool = False): if b: - return {a: None} + return {a: None} if not pos else {"$eq": [prefix_with_dollar(a), None]} warnings.warn("You're using $ne, index will not be used", IndexNotUsedWarning, stacklevel=1) - return {a: {"$ne": None}} + return {a: {"$ne": None}} if not pos else {"$ne": [prefix_with_dollar(a), None]} mongo_operators = { - # Where a = field_name, b = value - "exact": lambda a, b: {a: b}, - "gt": lambda a, b: {a: {"$gt": b}}, - "gte": lambda a, b: {a: {"$gte": b}}, - "lt": lambda a, b: {a: {"$lt": b}}, - "lte": lambda a, b: {a: {"$lte": b}}, - "in": lambda a, b: {a: {"$in": b}}, + # Where a = field_name, b = value, pos = positional operator syntax + "exact": lambda a, b, pos: {a: b} if not pos else {"$eq": [prefix_with_dollar(a), b]}, + "gt": lambda a, b, pos: {a: {"$gt": b}} if not pos else {"$gt": [prefix_with_dollar(a), b]}, + "gte": lambda a, b, pos: {a: {"$gte": b}} + if not pos + else {"$gte": [prefix_with_dollar(a), b]}, + "lt": lambda a, b, pos: {a: {"$lt": b}} if not pos else {"$lt": [prefix_with_dollar(a), b]}, + "lte": lambda a, b, pos: {a: {"$lte": b}} + if not pos + else {"$lte": [prefix_with_dollar(a), b]}, + "in": lambda a, b, pos: {a: {"$in": b}} if not pos else {"$in": [prefix_with_dollar(a), b]}, "isnull": _isnull_operator, - "range": lambda a, b: { + "range": lambda a, b, pos: { "$and": [ {"$or": [{a: {"$gte": b[0]}}, {a: None}]}, {"$or": [{a: {"$lte": b[1]}}, {a: None}]}, ] + } + if not pos + else { + "$and": [ + { + "$or": [ + {"$gte": [prefix_with_dollar(a), b[0]]}, + {"$eq": [prefix_with_dollar(a), None]}, + ] + }, + { + "$or": [ + {"$lte": [prefix_with_dollar(a), b[1]]}, + {"$eq": [prefix_with_dollar(a), None]}, + ] + }, + ] }, - "iexact": 
lambda a, b: regex_match(a, f"^{b}$", insensitive=True), - "startswith": lambda a, b: regex_match(a, f"^{b}"), - "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), - "endswith": lambda a, b: regex_match(a, f"{b}$"), - "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), - "contains": lambda a, b: regex_match(a, b), - "icontains": lambda a, b: regex_match(a, b, insensitive=True), - "regex": lambda a, b: regex_match(a, b), - "iregex": lambda a, b: regex_match(a, b, insensitive=True), + "iexact": lambda a, b, pos: regex_match(a, f"^{b}$", insensitive=True, pos=pos), + "startswith": lambda a, b, pos: regex_match(a, f"^{b}", pos=pos), + "istartswith": lambda a, b, pos: regex_match(a, f"^{b}", insensitive=True, pos=pos), + "endswith": lambda a, b, pos: regex_match(a, f"{b}$", pos=pos), + "iendswith": lambda a, b, pos: regex_match(a, f"{b}$", insensitive=True, pos=pos), + "contains": lambda a, b, pos: regex_match(a, b, pos=pos), + "icontains": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos), + "regex": lambda a, b, pos: regex_match(a, b, pos=pos), + "iregex": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos), } display_name = "DocumentDB" diff --git a/django_documentdb/compiler.py b/django_documentdb/compiler.py index e12eb07..37ab691 100644 --- a/django_documentdb/compiler.py +++ b/django_documentdb/compiler.py @@ -4,10 +4,8 @@ from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import IntegrityError, NotSupportedError -from django.db.models import Count from django.db.models.aggregates import Aggregate, Variance from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When -from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power from django.db.models.lookups import IsNull from django.db.models.sql import compiler @@ -18,6 +16,7 @@ from .base import Cursor from .query import MongoQuery, wrap_database_errors +from .utils import Distinct, prefix_with_dollar class SQLCompiler(compiler.SQLCompiler): @@ -95,8 +94,8 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group group[alias] = sub_expr.as_mql(self, self.connection) replacing_expr = inner_column # Count must return 0 rather than null. - if isinstance(sub_expr, Count): - replacing_expr = Coalesce(replacing_expr, 0) + # if isinstance(sub_expr, Count): + # replacing_expr = Coalesce(replacing_expr, 0) # Variance = StdDev^2 if isinstance(sub_expr, Variance): replacing_expr = Power(replacing_expr, 2) @@ -245,7 +244,8 @@ def execute_sql( else: return self._make_result(obj, columns) # result_type is MULTI - cursor.batch_size(chunk_size) + if not isinstance(cursor, list): + cursor.batch_size(chunk_size) result = self.cursor_iter(cursor, chunk_size, columns) if not chunked_fetch: # If using non-chunked reads, read data into memory. @@ -347,24 +347,16 @@ def build_query(self, columns=None): if self.query.distinct: # If query is distinct, build a $group stage for distinct # fields, then set project fields based on the grouped _id. 
- distinct_fields = self.get_project_fields( - columns, ordering_fields, force_expression=True - ) - if not query.aggregation_pipeline: - query.aggregation_pipeline = [] - query.aggregation_pipeline.extend( - [ - {"$group": {"_id": distinct_fields}}, - {"$project": {key: f"$_id.{key}" for key in distinct_fields}}, - ] + query.distinct = Distinct( + fields=self.get_project_fields(columns, ordering_fields, force_expression=True) ) else: # Otherwise, project fields without grouping. query.project_fields = self.get_project_fields(columns, ordering_fields) # If columns is None, then get_project_fields() won't add # ordering_fields to $project. Use $addFields (extra_fields) instead. - # if columns is None: - # extra_fields += ordering_fields + if columns is None: + extra_fields += ordering_fields query.lookup_pipeline = self.get_lookup_pipeline() where = self.get_where() try: @@ -479,10 +471,11 @@ def get_combinator_queries(self): inner_pipeline.append({"$project": fields}) # Combine query with the current combinator pipeline. if combinator_pipeline: + raise NotSupportedError combinator_pipeline.append( {"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}} ) - else: + else: # noqa: RET506 combinator_pipeline = inner_pipeline if not self.query.combinator_all: ids = defaultdict(dict) @@ -528,10 +521,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False fields[collection][name] = 1 else: mql = expr.as_mql(self, self.connection) - if isinstance(mql, str): - fields[collection][name] = f"${mql}" - else: - fields[collection][name] = mql + fields[collection][name] = prefix_with_dollar(mql) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) diff --git a/django_documentdb/expressions.py b/django_documentdb/expressions.py index 224ea7e..024fa7d 100644 --- a/django_documentdb/expressions.py +++ b/django_documentdb/expressions.py @@ -25,7 +25,7 @@ ) from django.db.models.sql import Query -from django_documentdb.utils import IndexNotUsedWarning +from django_documentdb.utils import IndexNotUsedWarning, prefix_with_dollar def case(self, compiler, connection): @@ -33,13 +33,13 @@ def case(self, compiler, connection): for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, positional_operator_syntax=True) except EmptyResultSet: continue except FullResultSet: default_mql = case.result.as_mql(compiler, connection) break - case_mql["then"] = case.result.as_mql(compiler, connection) + case_mql["then"] = prefix_with_dollar(case.result.as_mql(compiler, connection)) case_parts.append(case_mql) else: default_mql = self.default.as_mql(compiler, connection) @@ -73,12 +73,8 @@ def col(self, compiler, connection): # noqa: ARG001 def combined_expression(self, compiler, connection): expressions = [ - f"${self.lhs.as_mql(compiler, connection)}" - if isinstance(self.lhs, Col) - else self.lhs.as_mql(compiler, connection), - f"${self.rhs.as_mql(compiler, connection)}" - if isinstance(self.rhs, Col) - else self.rhs.as_mql(compiler, connection), + prefix_with_dollar(self.lhs.as_mql(compiler, connection)), + prefix_with_dollar(self.rhs.as_mql(compiler, connection)), ] return connection.ops.combine_expression(self.connector, expressions) @@ -121,10 +117,15 @@ def query(self, compiler, connection, lookup_name=None): subquery.subquery_lookup = { "as": table_output, "from": from_table, - "let": { - 
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection) - for col, i in subquery_compiler.column_indices.items() - }, + "localField": next( + iter( + [ + col.as_mql(compiler, connection) + for col, i in subquery_compiler.column_indices.items() + ] + ) + ), + "foreignField": next(iter(subquery.mongo_query.keys())), } # The result must be a list of values. The output is compressed with an # aggregation pipeline. @@ -191,16 +192,18 @@ def subquery(self, compiler, connection, lookup_name=None): return self.query.as_mql(compiler, connection, lookup_name=lookup_name) -def exists(self, compiler, connection, lookup_name=None): +def exists(self, compiler, connection, lookup_name=None, positional_operator_syntax: bool = False): try: lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + return connection.mongo_operators["isnull"](lhs_mql, False, pos=positional_operator_syntax) -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) +def when(self, compiler, connection, positional_operator_syntax: bool = False): + return self.condition.as_mql( + compiler, connection, positional_operator_syntax=positional_operator_syntax + ) def value(self, compiler, connection): # noqa: ARG001 diff --git a/django_documentdb/functions.py b/django_documentdb/functions.py index f5364d6..cae6e8e 100644 --- a/django_documentdb/functions.py +++ b/django_documentdb/functions.py @@ -33,6 +33,7 @@ ) from .query_utils import process_lhs +from .utils import prefix_with_dollar MONGO_OPERATORS = { Ceil: "ceil", @@ -102,8 +103,7 @@ def func(self, compiler, connection): # Functions are using array syntax and for field name we want to add $ lhs_mql = process_lhs(self, compiler, connection) if isinstance(lhs_mql, list): - field_name = lhs_mql[0] - lhs_mql[0] = f"${field_name}" + lhs_mql = [prefix_with_dollar(field_name) for field_name in lhs_mql] operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) return {f"${operator}": lhs_mql} diff --git a/django_documentdb/lookups.py b/django_documentdb/lookups.py index 71bd2ea..b2cfa31 100644 --- a/django_documentdb/lookups.py +++ b/django_documentdb/lookups.py @@ -11,10 +11,10 @@ from .query_utils import process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): +def builtin_lookup(self, compiler, connection, positional_operator_syntax: bool = False): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) - return connection.mongo_operators[self.lookup_name](lhs_mql, value) + return connection.mongo_operators[self.lookup_name](lhs_mql, value, positional_operator_syntax) _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -33,7 +33,7 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): +def in_(self, compiler, connection, positional_operator_syntax: bool = False): if isinstance(self.lhs, MultiColSource): raise NotImplementedError("MultiColSource is not supported.") db_rhs = getattr(self.rhs, "_db", None) @@ -42,14 +42,14 @@ def in_(self, compiler, connection): "Subqueries aren't allowed across different databases. Force " "the inner query to be evaluated using `list(inner_query)`." 
) - return builtin_lookup(self, compiler, connection) + return builtin_lookup(self, compiler, connection, positional_operator_syntax) -def is_null(self, compiler, connection): +def is_null(self, compiler, connection, positional_operator_syntax: bool = False): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") lhs_mql = process_lhs(self, compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, self.rhs) + return connection.mongo_operators["isnull"](lhs_mql, self.rhs, pos=positional_operator_syntax) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 diff --git a/django_documentdb/query.py b/django_documentdb/query.py index ee3c6bc..9a3fdf3 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -18,7 +18,7 @@ if typing.TYPE_CHECKING: from django_documentdb.base import DatabaseWrapper from django_documentdb.compiler import SQLCompiler -from django_documentdb.utils import IndexNotUsedWarning +from django_documentdb.utils import Distinct, IndexNotUsedWarning, unprefix_dollar def wrap_database_errors(func): @@ -61,6 +61,7 @@ def __init__(self, compiler: "SQLCompiler"): self.mongo_query = getattr(compiler.query, "raw_query", {}) self.subqueries = None self.lookup_pipeline = None + self.distinct: Distinct | None = None self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline self.extra_fields = None @@ -80,16 +81,24 @@ def delete(self): return self.collection.delete_many(self.mongo_query).deleted_count @wrap_database_errors - def get_cursor(self) -> Cursor[_DocumentType]: + def get_cursor(self) -> Cursor[_DocumentType] | list[dict]: """ Return a pymongo CommandCursor that can be iterated on to give the results of the query. 
""" - if self.is_simple_lookup: + if self.is_simple_lookup and self.distinct and self.distinct.is_simple_distinct: + results = self.collection.distinct( + self.distinct.field, **self.build_simple_lookup(limit=False, offset=False) + ) + return [{self.distinct.field: x} for x in results] + + if self.is_simple_lookup and not self.distinct: pipeline = self.build_simple_lookup() return self.collection.find(**pipeline) pipeline = self.get_pipeline() + if self.distinct: + pipeline.extend(self.distinct.aggregation()) options = {} if hasattr(self.query, "_index_hint"): options["hint"] = self.query._index_hint @@ -107,7 +116,7 @@ def is_simple_lookup(self) -> bool: and not self.subquery_lookup ) - def build_simple_lookup(self) -> dict: + def build_simple_lookup(self, **kwargs) -> dict: pipeline = {} if self.mongo_query: pipeline["filter"] = self.mongo_query @@ -117,9 +126,9 @@ def build_simple_lookup(self) -> dict: pipeline["projection"] = self.project_fields if self.ordering: pipeline["sort"] = self.ordering - if self.query.low_mark > 0: + if self.query.low_mark > 0 and kwargs.get("offset", True): pipeline["skip"] = self.query.low_mark - if self.query.high_mark is not None: + if self.query.high_mark is not None and kwargs.get("limit", True): pipeline["limit"] = self.query.high_mark - self.query.low_mark if hasattr(self.query, "_index_hint"): pipeline["hint"] = self.query._index_hint @@ -142,7 +151,7 @@ def get_pipeline(self): if self.extra_fields: pipeline.append({"$addFields": self.extra_fields}) if self.ordering: - pipeline.append({"$sort": self.ordering}) + pipeline.append({"$sort": unprefix_dollar(self.ordering)}) if self.query.low_mark > 0: pipeline.append({"$skip": self.query.low_mark}) if self.query.high_mark is not None: @@ -150,21 +159,11 @@ def get_pipeline(self): if self.subquery_lookup: table_output = self.subquery_lookup["as"] pipeline = [ - {"$lookup": {**self.subquery_lookup, "pipeline": pipeline}}, + {"$lookup": {**self.subquery_lookup}}, { - "$set": { - table_output: { - "$cond": { - "if": { - "$or": [ - {"$eq": [{"$type": f"${table_output}"}, "missing"]}, - {"$eq": [{"$size": f"${table_output}"}, 0]}, - ] - }, - "then": {}, - "else": {"$arrayElemAt": [f"${table_output}", 0]}, - } - } + "$unwind": { + "path": f"${table_output}", + "preserveNullAndEmptyArrays": True, } }, ] @@ -244,7 +243,7 @@ def join(self: Join, compiler: "SQLCompiler", connection: "DatabaseWrapper"): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, positional_operator_syntax: bool = False): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -267,14 +266,16 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql( + compiler, connection, positional_operator_syntax + ) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, positional_operator_syntax) except EmptyResultSet: empty_needed -= 1 except FullResultSet: diff --git a/django_documentdb/query_utils.py b/django_documentdb/query_utils.py index 4e2d5c8..382c2e4 100644 --- a/django_documentdb/query_utils.py +++ b/django_documentdb/query_utils.py @@ -2,6 +2,8 @@ from django.db.models.aggregates import Aggregate from django.db.models.expressions import Value 
+from django_documentdb.utils import prefix_with_dollar + def is_direct_value(node): return not hasattr(node, "as_sql") @@ -13,7 +15,7 @@ def process_lhs(node, compiler, connection): result = [] for expr in node.get_source_expressions(): try: - result.append(expr.as_mql(compiler, connection)) + result.append(prefix_with_dollar(expr.as_mql(compiler, connection))) except FullResultSet: result.append(Value(True).as_mql(compiler, connection)) if isinstance(node, Aggregate): @@ -46,7 +48,10 @@ def process_rhs(node, compiler, connection): return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name) -def regex_match(field, regex: str, insensitive=False): +def regex_match(field, regex: str, insensitive=False, pos: bool = False): + """ + - pos = positional operator syntax (e.g. {"$eq": [a, b]}) + """ # warnings.warn( # "It's better to use hint with regex operations.\n" # "See https://docs.aws.amazon.com/documentdb/latest/developerguide/functional-differences.html" @@ -55,4 +60,10 @@ def regex_match(field, regex: str, insensitive=False): # category=NotOptimalOperationWarning, # ) options = "i" if insensitive else "" - return {field: {"$regex": regex, "$options": options}} + return ( + {field: {"$regex": regex, "$options": options}} + if not pos + else { + "$regexMatch": {"input": prefix_with_dollar(field), "regex": regex, "options": options} + } + ) diff --git a/django_documentdb/schema.py b/django_documentdb/schema.py index 8fd902a..3bf1728 100644 --- a/django_documentdb/schema.py +++ b/django_documentdb/schema.py @@ -129,6 +129,7 @@ def remove_field(self, model, field): return # Unset field on existing documents. if column := field.column: + # TODO: documentdb doesn't support $unset self.get_collection(model._meta.db_table).update_many({}, {"$unset": {column: ""}}) if self._field_should_be_indexed(model, field): self._remove_field_index(model, field) diff --git a/django_documentdb/utils.py b/django_documentdb/utils.py index 7e1c2b3..2b31ec8 100644 --- a/django_documentdb/utils.py +++ b/django_documentdb/utils.py @@ -1,5 +1,6 @@ import copy import time +from functools import cached_property import django from django.conf import settings @@ -39,6 +40,7 @@ class OperationDebugWrapper: wrapped_methods = { "find", "aggregate", + "distinct", "create_collection", "create_indexes", "drop", @@ -66,13 +68,16 @@ def profile_call(self, func, args=(), kwargs=None): duration = time.monotonic() - start return duration, retval + def to_documentdb_syntax(self, query): + return query.replace("None", "null").replace("True", "true").replace("False", "false") + def log(self, op, duration, args, kwargs=None): # If kwargs are used by any operations in the future, they must be # added to this logging. 
         msg = "(%.3f) %s"
         args = ", ".join(repr(arg) for arg in args)
         kwargs = ", ".join(f"{k}={v!r}" for k, v in (kwargs or {}).items())
-        operation = f"db.{self.collection_name}{op}({args} {kwargs})"
+        operation = f"db.{self.collection_name}{op}({self.to_documentdb_syntax(args)} {kwargs})"
         if len(settings.DATABASES) > 1:
             msg += f"; alias={self.db.alias}"
         self.db.queries_log.append(
@@ -138,3 +143,39 @@ class NotOptimalOperationWarning(Warning):
 
 class IndexNotUsedWarning(DocumentDBIncompatibleWarning, NotOptimalOperationWarning):
     pass
+
+
+class Distinct:
+    def __init__(
+        self,
+        fields: dict[str, str | dict],
+    ):
+        self.fields = fields
+
+    def aggregation(self):
+        return [
+            {"$group": {"_id": self.fields}},
+            {"$project": {key: f"$_id.{key}" for key in self.fields}},
+        ]
+
+    @cached_property
+    def is_simple_distinct(self):
+        return len(self.fields) == 1
+
+    @property
+    def field(self) -> str:
+        return list(self.fields.keys())[0]  # noqa: RUF015
+
+
+def prefix_with_dollar(field):
+    if isinstance(field, str):
+        return f"${field}" if not field.startswith("$") else field
+    return field
+
+
+def unprefix_dollar(field):
+    if isinstance(field, dict):
+        return {k[1:] if k.startswith("$") else k: v for k, v in field.items()}
+    if isinstance(field, str):
+        return field[1:] if field.startswith("$") else field
+    return field
diff --git a/tests/Dockerfile b/tests/Dockerfile
new file mode 100644
index 0000000..5a8d2e0
--- /dev/null
+++ b/tests/Dockerfile
@@ -0,0 +1,20 @@
+FROM python:3.11.10
+
+ENV PYTHONUNBUFFERED=1
+
+RUN apt-get update && apt-get install -y git libmemcached-dev
+
+WORKDIR /package-test/
+
+COPY . django-documentdb
+COPY ./tests/django django
+COPY .github/workflows/documentdb_settings.py django/tests/
+COPY rds-combined-ca-bundle.pem /package-test/rds-combined-ca-bundle.pem
+RUN find django -type f -exec sed -i 's/django_mongodb/django_documentdb/g' {} +
+
+RUN pip install -e django-documentdb
+RUN pip install -e django
+RUN pip install -r django/tests/requirements/py3.txt
+
+
+CMD ["python", "django/tests/runtests.py", "--settings", "documentdb_settings", "-v", "2", "--failfast", "aggregation"]
diff --git a/tests/mongodb_settings.py b/tests/mongodb_settings.py
new file mode 100644
index 0000000..6b189e2
--- /dev/null
+++ b/tests/mongodb_settings.py
@@ -0,0 +1,20 @@
+DATABASES = {
+    "default": {
+        "ENGINE": "django_mongodb",
+        "NAME": "mongotest",
+        "HOST": "host.docker.internal",
+        "USER": "root",
+        "PASSWORD": "mongoadmin",
+    },
+    "other": {
+        "ENGINE": "django_mongodb",
+        "NAME": "mongotest-other",
+        "HOST": "host.docker.internal",
+        "USER": "root",
+        "PASSWORD": "mongoadmin",
+    },
+}
+DEFAULT_AUTO_FIELD = "django_mongodb.fields.ObjectIdAutoField"
+PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",)
+SECRET_KEY = "django_tests_secret_key"  # noqa: S105
+USE_TZ = False
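Note on the pattern that runs through this patch (illustrative, not part of the diff): the `positional_operator_syntax` flag (`pos`) threaded through `mongo_operators`, the lookups, `where_node`, `When`, and `regex_match` selects between a find()-style filter document and the aggregation-operator form required inside `$expr`, `$cond` cases, and `$regexMatch`. A minimal sketch of the two shapes, mirroring the `"exact"` lambda from `django_documentdb/base.py` and using the real `prefix_with_dollar` helper from `django_documentdb/utils.py` (the field name and value below are hypothetical):

```python
# Illustrative sketch only, not part of the patch. The conditional mirrors the
# "exact" entry added to DatabaseWrapper.mongo_operators in base.py.
from django_documentdb.utils import prefix_with_dollar


def exact(field, value, pos=False):
    # find()-style filter document vs. aggregation ($expr-style) operator form.
    return {field: value} if not pos else {"$eq": [prefix_with_dollar(field), value]}


assert exact("age", 30) == {"age": 30}  # usable with find() / $match
assert exact("age", 30, pos=True) == {"$eq": ["$age", 30]}  # usable inside $expr / $cond
assert prefix_with_dollar("$age") == "$age"  # already-prefixed names pass through unchanged
```

The `Case`/`When` and `where_node` changes pass `positional_operator_syntax=True` down to these same operators so that conditions emitted inside `$cond` use the second form.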