From 1040591247d8271d58562c69eda436e13b103e02 Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Tue, 29 Oct 2024 23:23:34 +0200 Subject: [PATCH 1/7] fix(test_cases): changed vendor name in aggregation functions --- django_documentdb/fields/duration.py | 2 +- django_documentdb/fields/json.py | 2 +- django_documentdb/lookups.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/django_documentdb/fields/duration.py b/django_documentdb/fields/duration.py index cd0fd55..62828db 100644 --- a/django_documentdb/fields/duration.py +++ b/django_documentdb/fields/duration.py @@ -6,7 +6,7 @@ def get_db_prep_value(self, value, connection, prepared=False): """DurationField stores milliseconds rather than microseconds.""" value = _get_db_prep_value(self, value, connection, prepared) - if connection.vendor == "mongodb" and value is not None: + if connection.vendor == "documentdb" and value is not None: value //= 1000 return value diff --git a/django_documentdb/fields/json.py b/django_documentdb/fields/json.py index 79b9e94..a6cef9c 100644 --- a/django_documentdb/fields/json.py +++ b/django_documentdb/fields/json.py @@ -72,7 +72,7 @@ def json_exact_process_rhs(self, compiler, connection): """Skip JSONExact.process_rhs()'s conversion of None to "null".""" return ( super(JSONExact, self).process_rhs(compiler, connection) - if connection.vendor == "mongodb" + if connection.vendor == "documentdb" else _process_rhs(self, compiler, connection) ) diff --git a/django_documentdb/lookups.py b/django_documentdb/lookups.py index c651dd6..71bd2ea 100644 --- a/django_documentdb/lookups.py +++ b/django_documentdb/lookups.py @@ -23,7 +23,7 @@ def builtin_lookup(self, compiler, connection): def field_resolve_expression_parameter(self, compiler, connection, sql, param): """For MongoDB, this method must call as_mql() instead of as_sql().""" sql, sql_params = _field_resolve_expression_parameter(self, compiler, connection, sql, param) - if connection.vendor == "mongodb": + if connection.vendor == "documentdb": params = [param] if hasattr(param, "resolve_expression"): param = param.resolve_expression(compiler.query) From f67bcaf46735b5e4d13c325f311939ed66cc53eb Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Wed, 30 Oct 2024 22:47:08 +0200 Subject: [PATCH 2/7] fix(test_cases): fixed a few test cases --- django_documentdb/aggregates.py | 5 +++-- django_documentdb/compiler.py | 38 +++++++++++++++++++++++++------- django_documentdb/expressions.py | 10 +++------ django_documentdb/functions.py | 4 ++-- django_documentdb/query.py | 4 ++-- django_documentdb/query_utils.py | 4 +++- django_documentdb/utils.py | 19 +++++++++++++++- 7 files changed, 61 insertions(+), 23 deletions(-) diff --git a/django_documentdb/aggregates.py b/django_documentdb/aggregates.py index 6eae38d..4af987f 100644 --- a/django_documentdb/aggregates.py +++ b/django_documentdb/aggregates.py @@ -3,6 +3,7 @@ from django.db.models.lookups import IsNull from .query_utils import process_lhs +from .utils import prefix_with_dollar # Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower(). MONGO_AGGREGATIONS = {Count: "sum"} @@ -26,9 +27,9 @@ def aggregate( node = self lhs_mql = process_lhs(node, compiler, connection) if resolve_inner_expression: - return lhs_mql + return prefix_with_dollar(lhs_mql) operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) - return {f"${operator}": f"${lhs_mql}"} + return {f"${operator}": prefix_with_dollar(lhs_mql)} def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 diff --git a/django_documentdb/compiler.py b/django_documentdb/compiler.py index af1328e..4a4c21f 100644 --- a/django_documentdb/compiler.py +++ b/django_documentdb/compiler.py @@ -4,10 +4,8 @@ from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import IntegrityError, NotSupportedError -from django.db.models import Count from django.db.models.aggregates import Aggregate, Variance from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When -from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power from django.db.models.lookups import IsNull from django.db.models.sql import compiler @@ -18,6 +16,7 @@ from .base import Cursor from .query import MongoQuery, wrap_database_errors +from .utils import prefix_with_dollar class SQLCompiler(compiler.SQLCompiler): @@ -95,8 +94,8 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group group[alias] = sub_expr.as_mql(self, self.connection) replacing_expr = inner_column # Count must return 0 rather than null. - if isinstance(sub_expr, Count): - replacing_expr = Coalesce(replacing_expr, 0) + # if isinstance(sub_expr, Count): + # replacing_expr = Coalesce(replacing_expr, 0) # Variance = StdDev^2 if isinstance(sub_expr, Variance): replacing_expr = Power(replacing_expr, 2) @@ -185,6 +184,32 @@ def _build_aggregation_pipeline(self, ids, group): if not ids: group["_id"] = None pipeline.append({"$group": group}) + + # Step 2: Add conditional $unionWith to handle the case of no matching records + pipeline.append( + { + "$unionWith": { + "coll": self.collection_name, + "pipeline": [ + {"$limit": 1}, # Ensures we always have a document in the result set + ], + } + } + ) + + # Step 3: Final $group to select first non-null result for each field + pipeline.append( + { + "$group": { + "_id": None, + **{ + key: {"$first": prefix_with_dollar(key)} + for key in group + if key != "_id" + }, + } + }, + ) else: group["_id"] = ids pipeline.append({"$group": group}) @@ -528,10 +553,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False fields[collection][name] = 1 else: mql = expr.as_mql(self, self.connection) - if isinstance(mql, str): - fields[collection][name] = f"${mql}" - else: - fields[collection][name] = mql + fields[collection][name] = prefix_with_dollar(mql) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) diff --git a/django_documentdb/expressions.py b/django_documentdb/expressions.py index 224ea7e..6602d87 100644 --- a/django_documentdb/expressions.py +++ b/django_documentdb/expressions.py @@ -25,7 +25,7 @@ ) from django.db.models.sql import Query -from django_documentdb.utils import IndexNotUsedWarning +from django_documentdb.utils import IndexNotUsedWarning, prefix_with_dollar def case(self, compiler, connection): @@ -73,12 +73,8 @@ def col(self, compiler, connection): # noqa: ARG001 def combined_expression(self, compiler, connection): expressions = [ - f"${self.lhs.as_mql(compiler, connection)}" - if isinstance(self.lhs, Col) - else self.lhs.as_mql(compiler, connection), - f"${self.rhs.as_mql(compiler, connection)}" - if isinstance(self.rhs, Col) - else self.rhs.as_mql(compiler, connection), + prefix_with_dollar(self.lhs.as_mql(compiler, connection)), + prefix_with_dollar(self.rhs.as_mql(compiler, connection)), ] return connection.ops.combine_expression(self.connector, expressions) diff --git a/django_documentdb/functions.py b/django_documentdb/functions.py index f5364d6..cae6e8e 100644 --- a/django_documentdb/functions.py +++ b/django_documentdb/functions.py @@ -33,6 +33,7 @@ ) from .query_utils import process_lhs +from .utils import prefix_with_dollar MONGO_OPERATORS = { Ceil: "ceil", @@ -102,8 +103,7 @@ def func(self, compiler, connection): # Functions are using array syntax and for field name we want to add $ lhs_mql = process_lhs(self, compiler, connection) if isinstance(lhs_mql, list): - field_name = lhs_mql[0] - lhs_mql[0] = f"${field_name}" + lhs_mql = [prefix_with_dollar(field_name) for field_name in lhs_mql] operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) return {f"${operator}": lhs_mql} diff --git a/django_documentdb/query.py b/django_documentdb/query.py index 2f2ce1f..fb8d7de 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -16,7 +16,7 @@ if typing.TYPE_CHECKING: from django_documentdb.base import DatabaseWrapper from django_documentdb.compiler import SQLCompiler -from django_documentdb.utils import IndexNotUsedWarning +from django_documentdb.utils import IndexNotUsedWarning, unprefix_dollar def wrap_database_errors(func): @@ -106,7 +106,7 @@ def get_pipeline(self): if self.extra_fields: pipeline.append({"$addFields": self.extra_fields}) if self.ordering: - pipeline.append({"$sort": self.ordering}) + pipeline.append({"$sort": unprefix_dollar(self.ordering)}) if self.query.low_mark > 0: pipeline.append({"$skip": self.query.low_mark}) if self.query.high_mark is not None: diff --git a/django_documentdb/query_utils.py b/django_documentdb/query_utils.py index 4e2d5c8..b0b0ec2 100644 --- a/django_documentdb/query_utils.py +++ b/django_documentdb/query_utils.py @@ -2,6 +2,8 @@ from django.db.models.aggregates import Aggregate from django.db.models.expressions import Value +from django_documentdb.utils import prefix_with_dollar + def is_direct_value(node): return not hasattr(node, "as_sql") @@ -13,7 +15,7 @@ def process_lhs(node, compiler, connection): result = [] for expr in node.get_source_expressions(): try: - result.append(expr.as_mql(compiler, connection)) + result.append(prefix_with_dollar(expr.as_mql(compiler, connection))) except FullResultSet: result.append(Value(True).as_mql(compiler, connection)) if isinstance(node, Aggregate): diff --git a/django_documentdb/utils.py b/django_documentdb/utils.py index 05ecdc1..ee75596 100644 --- a/django_documentdb/utils.py +++ b/django_documentdb/utils.py @@ -65,12 +65,15 @@ def profile_call(self, func, args=(), kwargs=None): duration = time.monotonic() - start return duration, retval + def to_documentdb_syntax(self, query): + return query.replace("None", "null").replace("True", "true").replace("False", "false") + def log(self, op, duration, args, kwargs=None): # If kwargs are used by any operations in the future, they must be # added to this logging. msg = "(%.3f) %s" args = ", ".join(repr(arg) for arg in args) - operation = f"db.{self.collection_name}{op}({args})" + operation = f"db.{self.collection_name}{op}({self.to_documentdb_syntax(args)})" if len(settings.DATABASES) > 1: msg += f"; alias={self.db.alias}" self.db.queries_log.append( @@ -135,3 +138,17 @@ class NotOptimalOperationWarning(Warning): class IndexNotUsedWarning(DocumentDBIncompatibleWarning, NotOptimalOperationWarning): pass + + +def prefix_with_dollar(field): + if isinstance(field, str): + return f"${field}" if not field.startswith("$") else field + return field + + +def unprefix_dollar(field): + if isinstance(field, dict): + return {k[1:] if k.startswith("$") else k: v for k, v in field.items()} + if isinstance(field, str): + return field[1:] if field.startswith("$") else field + return field From cefb81396b15f19aeb6cc80976b8d5f060c56909 Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:15:35 +0200 Subject: [PATCH 3/7] feat(tests): Added dockerfile to run tests --- ...odb_settings.py => documentdb_settings.py} | 8 +++++++- .gitignore | 1 + tests/Dockerfile | 19 +++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) rename .github/workflows/{mongodb_settings.py => documentdb_settings.py} (59%) create mode 100644 tests/Dockerfile diff --git a/.github/workflows/mongodb_settings.py b/.github/workflows/documentdb_settings.py similarity index 59% rename from .github/workflows/mongodb_settings.py rename to .github/workflows/documentdb_settings.py index f571e11..634f194 100644 --- a/.github/workflows/mongodb_settings.py +++ b/.github/workflows/documentdb_settings.py @@ -2,13 +2,19 @@ "default": { "ENGINE": "django_documentdb", "NAME": "djangotests", + "HOST": "host.docker.internal", + "USER": "root", + "PASSWORD": "mongoadmin", }, "other": { "ENGINE": "django_documentdb", "NAME": "djangotests-other", + "HOST": "host.docker.internal", + "USER": "root", + "PASSWORD": "mongoadmin", }, } DEFAULT_AUTO_FIELD = "django_documentdb.fields.ObjectIdAutoField" PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",) -SECRET_KEY = "django_tests_secret_key" +SECRET_KEY = "django_tests_secret_key" # noqa: S105 USE_TZ = False diff --git a/.gitignore b/.gitignore index 35ac288..0e856fa 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,4 @@ docs/build site/ _development/ +tests/django diff --git a/tests/Dockerfile b/tests/Dockerfile new file mode 100644 index 0000000..150774e --- /dev/null +++ b/tests/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.11.10 + +ENV PYTHONUNBUFFERED 1 + +RUN apt-get update && apt-get install -y git libmemcached-dev + +WORKDIR /package-test + +COPY . django-documentdb +COPY ./tests/django django +COPY .github/workflows/documentdb_settings.py django/tests/ +RUN find django -type f -exec sed -i 's/django_mongodb/django_documentdb/g' {} + + +RUN pip install -e django-documentdb +RUN pip install -e django +RUN pip install -r django/tests/requirements/py3.txt + + +CMD ["python", "django/tests/runtests.py", "--settings", "documentdb_settings", "-v", "2", "aggregation"] From 38cef1af5752e1a6c75002479d64ff227a57e804 Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:04:00 +0200 Subject: [PATCH 4/7] feat(simple lookup): Added simple lookup operation support using .find(...) --- django_documentdb/fields/duration.py | 2 +- django_documentdb/fields/json.py | 2 +- django_documentdb/lookups.py | 2 +- django_documentdb/query.py | 38 +++++++++++++++++++++++++++- django_documentdb/utils.py | 7 +++-- 5 files changed, 45 insertions(+), 6 deletions(-) diff --git a/django_documentdb/fields/duration.py b/django_documentdb/fields/duration.py index cd0fd55..62828db 100644 --- a/django_documentdb/fields/duration.py +++ b/django_documentdb/fields/duration.py @@ -6,7 +6,7 @@ def get_db_prep_value(self, value, connection, prepared=False): """DurationField stores milliseconds rather than microseconds.""" value = _get_db_prep_value(self, value, connection, prepared) - if connection.vendor == "mongodb" and value is not None: + if connection.vendor == "documentdb" and value is not None: value //= 1000 return value diff --git a/django_documentdb/fields/json.py b/django_documentdb/fields/json.py index 79b9e94..a6cef9c 100644 --- a/django_documentdb/fields/json.py +++ b/django_documentdb/fields/json.py @@ -72,7 +72,7 @@ def json_exact_process_rhs(self, compiler, connection): """Skip JSONExact.process_rhs()'s conversion of None to "null".""" return ( super(JSONExact, self).process_rhs(compiler, connection) - if connection.vendor == "mongodb" + if connection.vendor == "documentdb" else _process_rhs(self, compiler, connection) ) diff --git a/django_documentdb/lookups.py b/django_documentdb/lookups.py index c651dd6..71bd2ea 100644 --- a/django_documentdb/lookups.py +++ b/django_documentdb/lookups.py @@ -23,7 +23,7 @@ def builtin_lookup(self, compiler, connection): def field_resolve_expression_parameter(self, compiler, connection, sql, param): """For MongoDB, this method must call as_mql() instead of as_sql().""" sql, sql_params = _field_resolve_expression_parameter(self, compiler, connection, sql, param) - if connection.vendor == "mongodb": + if connection.vendor == "documentdb": params = [param] if hasattr(param, "resolve_expression"): param = param.resolve_expression(compiler.query) diff --git a/django_documentdb/query.py b/django_documentdb/query.py index 2f2ce1f..ebde0c8 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -12,6 +12,8 @@ from django.db.models.sql.datastructures import Join from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError +from pymongo.synchronous.cursor import Cursor +from pymongo.typings import _DocumentType if typing.TYPE_CHECKING: from django_documentdb.base import DatabaseWrapper @@ -78,17 +80,51 @@ def delete(self): return self.collection.delete_many(self.mongo_query).deleted_count @wrap_database_errors - def get_cursor(self): + def get_cursor(self) -> Cursor[_DocumentType]: """ Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ + if self.is_simple_lookup: + pipeline = self.build_simple_lookup() + return self.collection.find(**pipeline) + pipeline = self.get_pipeline() options = {} if hasattr(self.query, "_index_hint"): options["hint"] = self.query._index_hint return self.collection.aggregate(pipeline, **options) + @property + def is_simple_lookup(self) -> bool: + return ( + (self.lookup_pipeline or self.mongo_query) + and not self.subqueries + and not self.combinator_pipeline + and not self.extra_fields + and not self.subquery_lookup + ) + + def build_simple_lookup(self) -> dict: + pipeline = {} + if self.lookup_pipeline: + pipeline["filter"] = self.lookup_pipeline + elif self.mongo_query: + pipeline["filter"] = self.mongo_query + else: + raise ValueError("No lookup pipeline or query found.") + if self.project_fields: + pipeline["projection"] = self.project_fields + if self.ordering: + pipeline["sort"] = self.ordering + if self.query.low_mark > 0: + pipeline["skip"] = self.query.low_mark + if self.query.high_mark is not None: + pipeline["limit"] = self.query.high_mark - self.query.low_mark + if hasattr(self.query, "_index_hint"): + pipeline["hint"] = self.query._index_hint + return pipeline + def get_pipeline(self): pipeline = [] if self.lookup_pipeline: diff --git a/django_documentdb/utils.py b/django_documentdb/utils.py index 05ecdc1..7e1c2b3 100644 --- a/django_documentdb/utils.py +++ b/django_documentdb/utils.py @@ -37,6 +37,7 @@ def set_wrapped_methods(cls): class OperationDebugWrapper: # The PyMongo database and collection methods that this backend uses. wrapped_methods = { + "find", "aggregate", "create_collection", "create_indexes", @@ -70,7 +71,8 @@ def log(self, op, duration, args, kwargs=None): # added to this logging. msg = "(%.3f) %s" args = ", ".join(repr(arg) for arg in args) - operation = f"db.{self.collection_name}{op}({args})" + kwargs = ", ".join(f"{k}={v!r}" for k, v in (kwargs or {}).items()) + operation = f"db.{self.collection_name}{op}({args} {kwargs})" if len(settings.DATABASES) > 1: msg += f"; alias={self.db.alias}" self.db.queries_log.append( @@ -115,7 +117,8 @@ def __init__(self, collected_sql=None, *, collection=None, db=None): def log(self, op, args, kwargs=None): args = ", ".join(repr(arg) for arg in args) - operation = f"db.{self.collection_name}{op}({args})" + kwargs = ", ".join(f"{k}={v!r}" for k, v in (kwargs or {}).items()) + operation = f"db.{self.collection_name}{op}({args} {kwargs})" self.collected_sql.append(operation) def logging_wrapper(method): From 45644b1aca94b4ec9a5da90e3a5ff00a8a2ce808 Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Sat, 2 Nov 2024 18:52:08 +0200 Subject: [PATCH 5/7] feat(DocumentDB Compatible): DocumentDB doesn't support $facet, $unionWith, Variance --- .gitignore | 1 + README.md | 6 ++++ django_documentdb/aggregates.py | 8 ++--- django_documentdb/compiler.py | 29 ++----------------- django_documentdb/query.py | 8 ++--- django_documentdb/schema.py | 1 + tests/Dockerfile | 3 +- .../mongodb_settings.py | 10 +++---- 8 files changed, 23 insertions(+), 43 deletions(-) rename .github/workflows/documentdb_settings.py => tests/mongodb_settings.py (65%) diff --git a/.gitignore b/.gitignore index 0e856fa..0276ad2 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,4 @@ site/ _development/ tests/django +documentdb_settings.py diff --git a/README.md b/README.md index dd45cba..cfe8265 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,9 @@ class TestModel(DocumentModel): ## Forked Project This project, **django-documentdb**, is a fork of the original **django-mongodb** library, which aimed to integrate MongoDB with Django. The fork was created to enhance compatibility with AWS DocumentDB, addressing the limitations of its API support while maintaining the core functionalities of the original library. We appreciate the work of the MongoDB Python Team and aim to build upon their foundation to better serve users needing DocumentDB integration. + +## Run tests + +docker build . -t test:latest -f tests/Dockerfile && docker run -it test:latest + +docker build . -t mongo_test:latest -f tests/mongodb.Dockerfile && docker run -it mongo_test:latest diff --git a/django_documentdb/aggregates.py b/django_documentdb/aggregates.py index 4af987f..28e94d3 100644 --- a/django_documentdb/aggregates.py +++ b/django_documentdb/aggregates.py @@ -65,12 +65,8 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): - if self.function.endswith("_SAMP"): - operator = "stdDevSamp" - elif self.function.endswith("_POP"): - operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator, **extra_context) +def stddev_variance(*args, **kwargs): # noqa: ARG001 + raise NotImplementedError("StdDev and Variance are not supported yet.") def register_aggregates(): diff --git a/django_documentdb/compiler.py b/django_documentdb/compiler.py index 4a4c21f..2c67dd6 100644 --- a/django_documentdb/compiler.py +++ b/django_documentdb/compiler.py @@ -184,32 +184,6 @@ def _build_aggregation_pipeline(self, ids, group): if not ids: group["_id"] = None pipeline.append({"$group": group}) - - # Step 2: Add conditional $unionWith to handle the case of no matching records - pipeline.append( - { - "$unionWith": { - "coll": self.collection_name, - "pipeline": [ - {"$limit": 1}, # Ensures we always have a document in the result set - ], - } - } - ) - - # Step 3: Final $group to select first non-null result for each field - pipeline.append( - { - "$group": { - "_id": None, - **{ - key: {"$first": prefix_with_dollar(key)} - for key in group - if key != "_id" - }, - } - }, - ) else: group["_id"] = ids pipeline.append({"$group": group}) @@ -504,10 +478,11 @@ def get_combinator_queries(self): inner_pipeline.append({"$project": fields}) # Combine query with the current combinator pipeline. if combinator_pipeline: + raise NotSupportedError combinator_pipeline.append( {"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}} ) - else: + else: # noqa: RET506 combinator_pipeline = inner_pipeline if not self.query.combinator_all: ids = defaultdict(dict) diff --git a/django_documentdb/query.py b/django_documentdb/query.py index 461aa05..1f62edc 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -98,7 +98,9 @@ def get_cursor(self) -> Cursor[_DocumentType]: @property def is_simple_lookup(self) -> bool: return ( - (self.lookup_pipeline or self.mongo_query) + self.mongo_query + and not self.lookup_pipeline + and not self.aggregation_pipeline and not self.subqueries and not self.combinator_pipeline and not self.extra_fields @@ -107,9 +109,7 @@ def is_simple_lookup(self) -> bool: def build_simple_lookup(self) -> dict: pipeline = {} - if self.lookup_pipeline: - pipeline["filter"] = self.lookup_pipeline - elif self.mongo_query: + if self.mongo_query: pipeline["filter"] = self.mongo_query else: raise ValueError("No lookup pipeline or query found.") diff --git a/django_documentdb/schema.py b/django_documentdb/schema.py index 8fd902a..3bf1728 100644 --- a/django_documentdb/schema.py +++ b/django_documentdb/schema.py @@ -129,6 +129,7 @@ def remove_field(self, model, field): return # Unset field on existing documents. if column := field.column: + # TODO: documentdb doesn't support $unset self.get_collection(model._meta.db_table).update_many({}, {"$unset": {column: ""}}) if self._field_should_be_indexed(model, field): self._remove_field_index(model, field) diff --git a/tests/Dockerfile b/tests/Dockerfile index 150774e..43c648f 100644 --- a/tests/Dockerfile +++ b/tests/Dockerfile @@ -4,11 +4,12 @@ ENV PYTHONUNBUFFERED 1 RUN apt-get update && apt-get install -y git libmemcached-dev -WORKDIR /package-test +WORKDIR /package-test/ COPY . django-documentdb COPY ./tests/django django COPY .github/workflows/documentdb_settings.py django/tests/ +COPY rds-combined-ca-bundle.pem /package-test/rds-combined-ca-bundle.pem RUN find django -type f -exec sed -i 's/django_mongodb/django_documentdb/g' {} + RUN pip install -e django-documentdb diff --git a/.github/workflows/documentdb_settings.py b/tests/mongodb_settings.py similarity index 65% rename from .github/workflows/documentdb_settings.py rename to tests/mongodb_settings.py index 634f194..6b189e2 100644 --- a/.github/workflows/documentdb_settings.py +++ b/tests/mongodb_settings.py @@ -1,20 +1,20 @@ DATABASES = { "default": { - "ENGINE": "django_documentdb", - "NAME": "djangotests", + "ENGINE": "django_mongodb", + "NAME": "mongotest", "HOST": "host.docker.internal", "USER": "root", "PASSWORD": "mongoadmin", }, "other": { - "ENGINE": "django_documentdb", - "NAME": "djangotests-other", + "ENGINE": "django_mongodb", + "NAME": "mongotest-other", "HOST": "host.docker.internal", "USER": "root", "PASSWORD": "mongoadmin", }, } -DEFAULT_AUTO_FIELD = "django_documentdb.fields.ObjectIdAutoField" +DEFAULT_AUTO_FIELD = "django_mongodb.fields.ObjectIdAutoField" PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",) SECRET_KEY = "django_tests_secret_key" # noqa: S105 USE_TZ = False From ab91026e7f58b264d443ab47917846192599c71b Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Sat, 2 Nov 2024 21:08:50 +0200 Subject: [PATCH 6/7] feat(DocumentDB Compatible): Fixed subquery_lookup and case --- django_documentdb/base.py | 63 +++++++++++++++++++++----------- django_documentdb/expressions.py | 27 +++++++++----- django_documentdb/lookups.py | 12 +++--- django_documentdb/query.py | 26 +++++-------- django_documentdb/query_utils.py | 13 ++++++- tests/Dockerfile | 2 +- 6 files changed, 86 insertions(+), 57 deletions(-) diff --git a/django_documentdb/base.py b/django_documentdb/base.py index 9043ca9..4ac6799 100644 --- a/django_documentdb/base.py +++ b/django_documentdb/base.py @@ -13,7 +13,7 @@ from .operations import DatabaseOperations from .query_utils import regex_match from .schema import DatabaseSchemaEditor -from .utils import IndexNotUsedWarning, OperationDebugWrapper +from .utils import IndexNotUsedWarning, OperationDebugWrapper, prefix_with_dollar # ignore warning from pymongo about DocumentDB warnings.filterwarnings("ignore", "You appear to be connected to a DocumentDB cluster", UserWarning) @@ -87,37 +87,58 @@ class DatabaseWrapper(BaseDatabaseWrapper): "iendswith": "LIKE '%%' || UPPER({})", } - def _isnull_operator(a, b): + def _isnull_operator(a, b, pos: bool = False): if b: - return {a: None} + return {a: None} if not pos else {"$eq": [prefix_with_dollar(a), None]} warnings.warn("You're using $ne, index will not be used", IndexNotUsedWarning, stacklevel=1) - return {a: {"$ne": None}} + return {a: {"$ne": None}} if not pos else {"$ne": [prefix_with_dollar(a), None]} mongo_operators = { - # Where a = field_name, b = value - "exact": lambda a, b: {a: b}, - "gt": lambda a, b: {a: {"$gt": b}}, - "gte": lambda a, b: {a: {"$gte": b}}, - "lt": lambda a, b: {a: {"$lt": b}}, - "lte": lambda a, b: {a: {"$lte": b}}, - "in": lambda a, b: {a: {"$in": b}}, + # Where a = field_name, b = value, pos = positional operator syntax + "exact": lambda a, b, pos: {a: b} if not pos else {"$eq": [prefix_with_dollar(a), b]}, + "gt": lambda a, b, pos: {a: {"$gt": b}} if not pos else {"$gt": [prefix_with_dollar(a), b]}, + "gte": lambda a, b, pos: {a: {"$gte": b}} + if not pos + else {"$gte": [prefix_with_dollar(a), b]}, + "lt": lambda a, b, pos: {a: {"$lt": b}} if not pos else {"$lt": [prefix_with_dollar(a), b]}, + "lte": lambda a, b, pos: {a: {"$lte": b}} + if not pos + else {"$lte": [prefix_with_dollar(a), b]}, + "in": lambda a, b, pos: {a: {"$in": b}} if not pos else {"$in": [prefix_with_dollar(a), b]}, "isnull": _isnull_operator, - "range": lambda a, b: { + "range": lambda a, b, pos: { "$and": [ {"$or": [{a: {"$gte": b[0]}}, {a: None}]}, {"$or": [{a: {"$lte": b[1]}}, {a: None}]}, ] + } + if not pos + else { + "$and": [ + { + "$or": [ + {"$gte": [prefix_with_dollar(a), b[0]]}, + {"$eq": [prefix_with_dollar(a), None]}, + ] + }, + { + "$or": [ + {"$lte": [prefix_with_dollar(a), b[1]]}, + {"$eq": [prefix_with_dollar(a), None]}, + ] + }, + ] }, - "iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True), - "startswith": lambda a, b: regex_match(a, f"^{b}"), - "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), - "endswith": lambda a, b: regex_match(a, f"{b}$"), - "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), - "contains": lambda a, b: regex_match(a, b), - "icontains": lambda a, b: regex_match(a, b, insensitive=True), - "regex": lambda a, b: regex_match(a, b), - "iregex": lambda a, b: regex_match(a, b, insensitive=True), + "iexact": lambda a, b, pos: regex_match(a, f"^{b}$", insensitive=True, pos=pos), + "startswith": lambda a, b, pos: regex_match(a, f"^{b}", pos=pos), + "istartswith": lambda a, b, pos: regex_match(a, f"^{b}", insensitive=True, pos=pos), + "endswith": lambda a, b, pos: regex_match(a, f"{b}$", pos=pos), + "iendswith": lambda a, b, pos: regex_match(a, f"{b}$", insensitive=True, pos=pos), + "contains": lambda a, b, pos: regex_match(a, b, pos=pos), + "icontains": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos), + "regex": lambda a, b, pos: regex_match(a, b, pos=pos), + "iregex": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos), } display_name = "DocumentDB" diff --git a/django_documentdb/expressions.py b/django_documentdb/expressions.py index 6602d87..024fa7d 100644 --- a/django_documentdb/expressions.py +++ b/django_documentdb/expressions.py @@ -33,13 +33,13 @@ def case(self, compiler, connection): for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, positional_operator_syntax=True) except EmptyResultSet: continue except FullResultSet: default_mql = case.result.as_mql(compiler, connection) break - case_mql["then"] = case.result.as_mql(compiler, connection) + case_mql["then"] = prefix_with_dollar(case.result.as_mql(compiler, connection)) case_parts.append(case_mql) else: default_mql = self.default.as_mql(compiler, connection) @@ -117,10 +117,15 @@ def query(self, compiler, connection, lookup_name=None): subquery.subquery_lookup = { "as": table_output, "from": from_table, - "let": { - compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection) - for col, i in subquery_compiler.column_indices.items() - }, + "localField": next( + iter( + [ + col.as_mql(compiler, connection) + for col, i in subquery_compiler.column_indices.items() + ] + ) + ), + "foreignField": next(iter(subquery.mongo_query.keys())), } # The result must be a list of values. The output is compressed with an # aggregation pipeline. @@ -187,16 +192,18 @@ def subquery(self, compiler, connection, lookup_name=None): return self.query.as_mql(compiler, connection, lookup_name=lookup_name) -def exists(self, compiler, connection, lookup_name=None): +def exists(self, compiler, connection, lookup_name=None, positional_operator_syntax: bool = False): try: lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + return connection.mongo_operators["isnull"](lhs_mql, False, pos=positional_operator_syntax) -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) +def when(self, compiler, connection, positional_operator_syntax: bool = False): + return self.condition.as_mql( + compiler, connection, positional_operator_syntax=positional_operator_syntax + ) def value(self, compiler, connection): # noqa: ARG001 diff --git a/django_documentdb/lookups.py b/django_documentdb/lookups.py index 71bd2ea..b2cfa31 100644 --- a/django_documentdb/lookups.py +++ b/django_documentdb/lookups.py @@ -11,10 +11,10 @@ from .query_utils import process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): +def builtin_lookup(self, compiler, connection, positional_operator_syntax: bool = False): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) - return connection.mongo_operators[self.lookup_name](lhs_mql, value) + return connection.mongo_operators[self.lookup_name](lhs_mql, value, positional_operator_syntax) _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -33,7 +33,7 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): +def in_(self, compiler, connection, positional_operator_syntax: bool = False): if isinstance(self.lhs, MultiColSource): raise NotImplementedError("MultiColSource is not supported.") db_rhs = getattr(self.rhs, "_db", None) @@ -42,14 +42,14 @@ def in_(self, compiler, connection): "Subqueries aren't allowed across different databases. Force " "the inner query to be evaluated using `list(inner_query)`." ) - return builtin_lookup(self, compiler, connection) + return builtin_lookup(self, compiler, connection, positional_operator_syntax) -def is_null(self, compiler, connection): +def is_null(self, compiler, connection, positional_operator_syntax: bool = False): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") lhs_mql = process_lhs(self, compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, self.rhs) + return connection.mongo_operators["isnull"](lhs_mql, self.rhs, pos=positional_operator_syntax) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 diff --git a/django_documentdb/query.py b/django_documentdb/query.py index 1f62edc..ba9d2ef 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -150,21 +150,11 @@ def get_pipeline(self): if self.subquery_lookup: table_output = self.subquery_lookup["as"] pipeline = [ - {"$lookup": {**self.subquery_lookup, "pipeline": pipeline}}, + {"$lookup": {**self.subquery_lookup}}, { - "$set": { - table_output: { - "$cond": { - "if": { - "$or": [ - {"$eq": [{"$type": f"${table_output}"}, "missing"]}, - {"$eq": [{"$size": f"${table_output}"}, 0]}, - ] - }, - "then": {}, - "else": {"$arrayElemAt": [f"${table_output}", 0]}, - } - } + "$unwind": { + "path": f"${table_output}", + "preserveNullAndEmptyArrays": True, } }, ] @@ -244,7 +234,7 @@ def join(self: Join, compiler: "SQLCompiler", connection: "DatabaseWrapper"): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, positional_operator_syntax: bool = False): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -267,14 +257,16 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql( + compiler, connection, positional_operator_syntax + ) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, positional_operator_syntax) except EmptyResultSet: empty_needed -= 1 except FullResultSet: diff --git a/django_documentdb/query_utils.py b/django_documentdb/query_utils.py index b0b0ec2..382c2e4 100644 --- a/django_documentdb/query_utils.py +++ b/django_documentdb/query_utils.py @@ -48,7 +48,10 @@ def process_rhs(node, compiler, connection): return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name) -def regex_match(field, regex: str, insensitive=False): +def regex_match(field, regex: str, insensitive=False, pos: bool = False): + """ + - pos = positional operator syntax (e.g. {"$eq": [a, b]}) + """ # warnings.warn( # "It's better to use hint with regex operations.\n" # "See https://docs.aws.amazon.com/documentdb/latest/developerguide/functional-differences.html" @@ -57,4 +60,10 @@ def regex_match(field, regex: str, insensitive=False): # category=NotOptimalOperationWarning, # ) options = "i" if insensitive else "" - return {field: {"$regex": regex, "$options": options}} + return ( + {field: {"$regex": regex, "$options": options}} + if not pos + else { + "$regexMatch": {"input": prefix_with_dollar(field), "regex": regex, "options": options} + } + ) diff --git a/tests/Dockerfile b/tests/Dockerfile index 43c648f..5a8d2e0 100644 --- a/tests/Dockerfile +++ b/tests/Dockerfile @@ -17,4 +17,4 @@ RUN pip install -e django RUN pip install -r django/tests/requirements/py3.txt -CMD ["python", "django/tests/runtests.py", "--settings", "documentdb_settings", "-v", "2", "aggregation"] +CMD ["python", "django/tests/runtests.py", "--settings", "documentdb_settings", "-v", "2", "--failfast", "aggregation"] From dfb08cfa5e0e9a6e2f40465723cbbce0f7198e36 Mon Sep 17 00:00:00 2001 From: iYasha <33287747+iYasha@users.noreply.github.com> Date: Sun, 10 Nov 2024 20:05:44 +0200 Subject: [PATCH 7/7] fix: Distinct performance --- django_documentdb/compiler.py | 16 +++++----------- django_documentdb/query.py | 21 +++++++++++++++------ django_documentdb/utils.py | 24 ++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/django_documentdb/compiler.py b/django_documentdb/compiler.py index e12eb07..76a982c 100644 --- a/django_documentdb/compiler.py +++ b/django_documentdb/compiler.py @@ -18,6 +18,7 @@ from .base import Cursor from .query import MongoQuery, wrap_database_errors +from .utils import Distinct class SQLCompiler(compiler.SQLCompiler): @@ -245,7 +246,8 @@ def execute_sql( else: return self._make_result(obj, columns) # result_type is MULTI - cursor.batch_size(chunk_size) + if not isinstance(cursor, list): + cursor.batch_size(chunk_size) result = self.cursor_iter(cursor, chunk_size, columns) if not chunked_fetch: # If using non-chunked reads, read data into memory. @@ -347,16 +349,8 @@ def build_query(self, columns=None): if self.query.distinct: # If query is distinct, build a $group stage for distinct # fields, then set project fields based on the grouped _id. - distinct_fields = self.get_project_fields( - columns, ordering_fields, force_expression=True - ) - if not query.aggregation_pipeline: - query.aggregation_pipeline = [] - query.aggregation_pipeline.extend( - [ - {"$group": {"_id": distinct_fields}}, - {"$project": {key: f"$_id.{key}" for key in distinct_fields}}, - ] + query.distinct = Distinct( + fields=self.get_project_fields(columns, ordering_fields, force_expression=True) ) else: # Otherwise, project fields without grouping. diff --git a/django_documentdb/query.py b/django_documentdb/query.py index ee3c6bc..b677988 100644 --- a/django_documentdb/query.py +++ b/django_documentdb/query.py @@ -18,7 +18,7 @@ if typing.TYPE_CHECKING: from django_documentdb.base import DatabaseWrapper from django_documentdb.compiler import SQLCompiler -from django_documentdb.utils import IndexNotUsedWarning +from django_documentdb.utils import Distinct, IndexNotUsedWarning def wrap_database_errors(func): @@ -61,6 +61,7 @@ def __init__(self, compiler: "SQLCompiler"): self.mongo_query = getattr(compiler.query, "raw_query", {}) self.subqueries = None self.lookup_pipeline = None + self.distinct: Distinct | None = None self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline self.extra_fields = None @@ -80,16 +81,24 @@ def delete(self): return self.collection.delete_many(self.mongo_query).deleted_count @wrap_database_errors - def get_cursor(self) -> Cursor[_DocumentType]: + def get_cursor(self) -> Cursor[_DocumentType] | list[dict]: """ Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ - if self.is_simple_lookup: + if self.is_simple_lookup and self.distinct and self.distinct.is_simple_distinct: + results = self.collection.distinct( + self.distinct.field, **self.build_simple_lookup(limit=False, offset=False) + ) + return [{self.distinct.field: x} for x in results] + + if self.is_simple_lookup and not self.distinct: pipeline = self.build_simple_lookup() return self.collection.find(**pipeline) pipeline = self.get_pipeline() + if self.distinct: + pipeline.extend(self.distinct.aggregation()) options = {} if hasattr(self.query, "_index_hint"): options["hint"] = self.query._index_hint @@ -107,7 +116,7 @@ def is_simple_lookup(self) -> bool: and not self.subquery_lookup ) - def build_simple_lookup(self) -> dict: + def build_simple_lookup(self, **kwargs) -> dict: pipeline = {} if self.mongo_query: pipeline["filter"] = self.mongo_query @@ -117,9 +126,9 @@ def build_simple_lookup(self) -> dict: pipeline["projection"] = self.project_fields if self.ordering: pipeline["sort"] = self.ordering - if self.query.low_mark > 0: + if self.query.low_mark > 0 and kwargs.get("offset", True): pipeline["skip"] = self.query.low_mark - if self.query.high_mark is not None: + if self.query.high_mark is not None and kwargs.get("limit", True): pipeline["limit"] = self.query.high_mark - self.query.low_mark if hasattr(self.query, "_index_hint"): pipeline["hint"] = self.query._index_hint diff --git a/django_documentdb/utils.py b/django_documentdb/utils.py index 7e1c2b3..f3c5a51 100644 --- a/django_documentdb/utils.py +++ b/django_documentdb/utils.py @@ -1,5 +1,6 @@ import copy import time +from functools import cached_property import django from django.conf import settings @@ -39,6 +40,7 @@ class OperationDebugWrapper: wrapped_methods = { "find", "aggregate", + "distinct", "create_collection", "create_indexes", "drop", @@ -138,3 +140,25 @@ class NotOptimalOperationWarning(Warning): class IndexNotUsedWarning(DocumentDBIncompatibleWarning, NotOptimalOperationWarning): pass + + +class Distinct: + def __init__( + self, + fields: dict[str, str | dict], + ): + self.fields = fields + + def aggregation(self): + return [ + {"$group": {"_id": self.fields}}, + {"$project": {key: f"$_id.{key}" for key in self.fields}}, + ] + + @cached_property + def is_simple_distinct(self): + return len(self.fields) == 1 + + @property + def field(self) -> str: + return list(self.fields.keys())[0] # noqa: RUF015