diff --git a/flask_mongorest/__init__.py b/flask_mongorest/__init__.py
index f7da37f2..7f5d7d8d 100644
--- a/flask_mongorest/__init__.py
+++ b/flask_mongorest/__init__.py
@@ -1,6 +1,46 @@
 from flask import Blueprint
-from flask_mongorest.methods import Create, BulkUpdate, List
+from flask_mongorest.methods import *

+def register_class(app, klass, **kwargs):
+    # Construct a url based on a 'name' kwarg with a fallback to the
+    # view's class name. Note that the name must be unique.
+    name = kwargs.pop('name', klass.__name__)
+    view_func = klass.as_view(name)
+    url = kwargs.pop('url', None)
+    if not url:
+        document_name = klass.resource.document.__name__.lower()
+        url = f'/{document_name}/'
+
+    # Insert the url prefix, if it exists
+    url_prefix = kwargs.pop('url_prefix', '')
+    if url_prefix:
+        url = f'{url_prefix}{url}'
+
+    # Add url rules
+    klass_methods = set(klass.methods)
+    if Create in klass_methods and BulkCreate in klass_methods:
+        raise ValueError('Use either Create or BulkCreate!')
+
+    for x in klass_methods & {Fetch, Update, Delete}:
+        endpoint = view_func.__name__ + x.__name__
+        app.add_url_rule(
+            f'{url}<pk>/', defaults={'short_mime': None},
+            view_func=view_func, methods=[x.method], endpoint=endpoint, **kwargs
+        )
+
+    for x in klass_methods & {Create, BulkFetch, BulkCreate, BulkUpdate, BulkDelete}:
+        endpoint = view_func.__name__ + x.__name__
+        app.add_url_rule(
+            url, defaults={'pk': None, 'short_mime': None},
+            view_func=view_func, methods=[x.method], endpoint=endpoint, **kwargs
+        )
+
+    if Download in klass_methods:
+        endpoint = view_func.__name__ + Download.__name__
+        app.add_url_rule(
+            f'{url}download/<string:short_mime>/', defaults={'pk': None},
+            view_func=view_func, methods=[Download.method], endpoint=endpoint, **kwargs
+        )

 class MongoRest(object):
     def __init__(self, app, **kwargs):
@@ -10,26 +50,7 @@ def __init__(self, app, **kwargs):

     def register(self, **kwargs):
         def decorator(klass):
-            # Construct a url based on a 'name' kwarg with a fallback to the
-            # view's class name. Note that the name must be unique.
- name = kwargs.pop('name', klass.__name__) - url = kwargs.pop('url', None) - if not url: - document_name = klass.resource.document.__name__.lower() - url = '/%s/' % document_name - - # Insert the url prefix, if it exists - if self.url_prefix: - url = '%s%s' % (self.url_prefix, url) - - # Add url rules - pk_type = kwargs.pop('pk_type', 'string') - view_func = klass.as_view(name) - if List in klass.methods: - self.app.add_url_rule(url, defaults={'pk': None}, view_func=view_func, methods=[List.method], **kwargs) - if Create in klass.methods or BulkUpdate in klass.methods: - self.app.add_url_rule(url, view_func=view_func, methods=[x.method for x in klass.methods if x in (Create, BulkUpdate)], **kwargs) - self.app.add_url_rule('%s<%s:%s>/' % (url, pk_type, 'pk'), view_func=view_func, methods=[x.method for x in klass.methods if x not in (List, BulkUpdate)], **kwargs) + register_class(self.app, klass, **kwargs) return klass return decorator diff --git a/flask_mongorest/methods.py b/flask_mongorest/methods.py index 5357ab45..21618fd5 100644 --- a/flask_mongorest/methods.py +++ b/flask_mongorest/methods.py @@ -1,17 +1,34 @@ +import sys +import inspect + +class Fetch: + method = 'GET' + class Create: method = 'POST' class Update: method = 'PUT' +class Delete: + method = 'DELETE' + + +class BulkFetch: + method = 'GET' + +class BulkCreate: + method = 'POST' + class BulkUpdate: method = 'PUT' -class Fetch: - method = 'GET' +class BulkDelete: + method = 'DELETE' + -class List: +class Download: method = 'GET' -class Delete: - method = 'DELETE' +members = inspect.getmembers(sys.modules[__name__], inspect.isclass) +__all__ = [m[0] for m in members] diff --git a/flask_mongorest/operators.py b/flask_mongorest/operators.py index 9e0ed823..85a6000b 100644 --- a/flask_mongorest/operators.py +++ b/flask_mongorest/operators.py @@ -53,6 +53,7 @@ class Operator(object): """Base class that all the other operators should inherit from.""" op = 'exact' + typ = 'string' # Can be overridden via constructor. 
allow_negation = False @@ -75,20 +76,42 @@ def apply(self, queryset, field, value, negate=False): kwargs = self.prepare_queryset_kwargs(field, value, negate) return queryset.filter(**kwargs) +def try_float(value): + try: + return float(value) + except ValueError: + return value + class Ne(Operator): op = 'ne' class Lt(Operator): op = 'lt' + typ = 'number' + + def prepare_queryset_kwargs(self, field, value, negate): + return {'__'.join(filter(None, [field, self.op])): try_float(value)} class Lte(Operator): op = 'lte' + typ = 'number' + + def prepare_queryset_kwargs(self, field, value, negate): + return {'__'.join(filter(None, [field, self.op])): try_float(value)} class Gt(Operator): op = 'gt' + typ = 'number' + + def prepare_queryset_kwargs(self, field, value, negate): + return {'__'.join(filter(None, [field, self.op])): try_float(value)} class Gte(Operator): op = 'gte' + typ = 'number' + + def prepare_queryset_kwargs(self, field, value, negate): + return {'__'.join(filter(None, [field, self.op])): try_float(value)} class Exact(Operator): op = 'exact' @@ -106,6 +129,7 @@ class IExact(Operator): class In(Operator): op = 'in' + typ = 'array' def prepare_queryset_kwargs(self, field, value, negate): # this is null if the user submits an empty in expression (like @@ -140,6 +164,7 @@ class IEndswith(Operator): class Boolean(Operator): op = 'exact' + typ = 'boolean' def prepare_queryset_kwargs(self, field, value, negate): if value == 'false': diff --git a/flask_mongorest/resources.py b/flask_mongorest/resources.py index d293c057..0c8db395 100644 --- a/flask_mongorest/resources.py +++ b/flask_mongorest/resources.py @@ -1,6 +1,9 @@ import json import mongoengine +from fastnumbers import fast_int +from unflatten import unflatten +from typing import Pattern from bson.dbref import DBRef from bson.objectid import ObjectId from flask import has_request_context, request, url_for @@ -16,15 +19,38 @@ DocumentProxy = None SafeReferenceField = None -from mongoengine.fields import EmbeddedDocumentField, ListField, ReferenceField, GenericReferenceField +from mongoengine.fields import EmbeddedDocumentField, ListField +from mongoengine.fields import GenericReferenceField, ReferenceField +from mongoengine.fields import GenericLazyReferenceField, LazyReferenceField from mongoengine.fields import DictField -from cleancat import ValidationError as SchemaValidationError +try: + from cleancat import Schema as CleancatSchema + from cleancat import ValidationError as SchemaValidationError +except ImportError: + CleancatSchema = None + +try: + from marshmallow_mongoengine import ModelSchema + from marshmallow.exceptions import ValidationError as MarshmallowValidationError + from marshmallow.utils import get_value, set_value, _Missing +except ImportError: + ModelSchema = None + from glom import glom, assign + from glom.core import PathAccessError + from flask_mongorest import methods from flask_mongorest.exceptions import ValidationError, UnknownFieldError from flask_mongorest.utils import cmp_fields, isbound, isint, equal +def get_with_list_index(o, k): + try: + return o[fast_int(k)] + except ValueError: + return o[k] + + class ResourceMeta(type): def __init__(cls, name, bases, classdict): if classdict.get('__metaclass__') is not ResourceMeta: @@ -33,6 +59,7 @@ def __init__(cls, name, bases, classdict): cls.child_document_resources[document] = cls type.__init__(cls, name, bases, classdict) + class Resource(object): # MongoEngine Document class related to this resource (required) document = None @@ -63,7 +90,10 @@ class 
Resource(object):
     max_limit = 100

     # Maximum number of objects which can be bulk-updated by a single request
-    bulk_update_limit = 1000
+    bulk_update_limit = 1000  # NOTE also used for bulk delete
+
+    # Map of field names to paginate, each with its [default, maximum] per-page limit
+    fields_to_paginate = {}

     # Map of field names and Resource classes that should be used to handle
     # these fields (for serialization, saving, etc.).
@@ -95,6 +125,9 @@ class Resource(object):
     # filtered query set, pulling all the references efficiently.
     select_related = False

+    # Download formats this resource is allowed to serve via `?format=`
+    download_formats = []
+
     # Must start and end with a "/"
     uri_prefix = None

@@ -114,12 +147,14 @@ def __init__(self, view_method=None):
             self._reverse_rename_fields[v] = k
         assert len(self._rename_fields) == len(self._reverse_rename_fields), \
             'Cannot rename multiple fields to the same name'
-        self._filters = self.get_filters()
+        self._normal_filters, self._regex_filters = self.get_filters()
        self._child_document_resources = self.get_child_document_resources()
        self._default_child_resource_document = self.get_default_child_resource_document()
        self.data = None
        self._dirty_fields = None
        self.view_method = view_method
+        self._normal_allowed_ordering = [o for o in self.allowed_ordering if not isinstance(o, Pattern)]
+        self._regex_allowed_ordering = [o for o in self.allowed_ordering if isinstance(o, Pattern)]

     @property
     def params(self):
@@ -173,10 +208,15 @@ def raw_data(self):
             raise ValidationError({'error': "Chunked Transfer-Encoding is not supported."})

         try:
-            self._raw_data = json.loads(request.data.decode('utf-8'), parse_constant=self._enforce_strict_json)
+            self._raw_data = json.loads(
+                request.data.decode('utf-8'),
+                parse_constant=self._enforce_strict_json
+            )
+            if request.method == 'PUT':
+                self._raw_data = unflatten(self._raw_data)
         except ValueError:
             raise ValidationError({'error': 'The request contains invalid JSON.'})
-        if not isinstance(self._raw_data, dict):
+        if request.method == 'PUT' and not isinstance(self._raw_data, dict):
             raise ValidationError({'error': 'JSON data must be a dict.'})
         else:
             self._raw_data = {}
@@ -209,7 +249,8 @@ def get_fields(self):
         """
         return self.fields

-    def get_optional_fields(self):
+    @staticmethod
+    def get_optional_fields():
         """
         Return a list of fields that can optionally be included in the
         response (but only if a `_fields` param mentioned them explicitly).
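
Reviewer note: the new `fields_to_paginate` and `download_formats` options are easiest to understand from a resource definition. A minimal sketch, assuming a hypothetical `Post` document with a `comments` list field (the document, field names, and import path `myapp.documents` are illustrative, not part of this diff):

```python
from flask_mongorest import operators as ops
from flask_mongorest.resources import Resource
from myapp.documents import Post  # hypothetical example document

class PostResource(Resource):
    document = Post
    fields = ['title', 'comments', 'views']
    # {field: [default per_page, max per_page]} -- apply_field_pagination()
    # turns ?comments_page=2&comments_per_page=20 into a $slice projection
    fields_to_paginate = {'comments': [20, 100]}
    # formats the new Download view method may serve (?format=json, ?format=csv)
    download_formats = ['json', 'csv']
    filters = {'views': [ops.Gt, ops.Lte]}
```

A request like `GET /post/?comments_page=2&comments_per_page=20` then returns documents whose `comments` array is sliced server-side instead of shipping the whole list.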
@@ -228,15 +269,18 @@ def get_requested_fields(self, **kwargs):

         include_all = False

+        # NOTE use list(dict.fromkeys()) below instead of set() to maintain order
         if 'fields' in kwargs:
             fields = kwargs['fields']
-            all_fields_set = set(fields)
+            all_fields_set = list(dict.fromkeys(fields))
         else:
-            fields = self.get_fields()
-            all_fields_set = set(fields) | set(self.get_optional_fields())
+            fields = list(self.get_fields())
+            all_fields = fields + self.get_optional_fields()
+            all_fields_set = list(dict.fromkeys(all_fields))

         if params and '_fields' in params:
-            only_fields = set(params['_fields'].split(','))
+            params_fields = params['_fields'].split(',')
+            only_fields = list(dict.fromkeys(params_fields))
             if '_all' in only_fields:
                 include_all = True
             else:
@@ -244,7 +288,7 @@ def get_requested_fields(self, **kwargs):

         requested_fields = []
         if include_all or only_fields is None:
-            if include_all:
+            if include_all or self.view_method == methods.Download:
                 field_selection = all_fields_set
             else:
                 field_selection = fields
@@ -253,7 +297,7 @@ def get_requested_fields(self, **kwargs):
         else:
             for field in only_fields:
                 actual_field = self._reverse_rename_fields.get(field, field)
-                if actual_field in all_fields_set:
+                if actual_field in all_fields_set or any(actual_field.startswith(f'{f}.') for f in all_fields_set):
                     requested_fields.append(actual_field)

         return requested_fields
@@ -309,15 +353,20 @@ def get_filters(self):
         `?date__gte=value` to the 'date' field and the 'gte' suffix: 'gte',
         and hence use the Gte operator to filter the data.
         """
-        filters = {}
+        normal_filters, regex_filters = {}, {}
         for field, operators in getattr(self, 'filters', {}).items():
             field_filters = {}
+
             for op in operators:
                 if op.op == 'exact':
                     field_filters[''] = op
                 field_filters[op.op] = op
-            filters[field] = field_filters
-        return filters
+
+            if isinstance(field, Pattern):
+                regex_filters[field] = field_filters
+            else:
+                normal_filters[field] = field_filters
+        return normal_filters, regex_filters

     def serialize_field(self, obj, **kwargs):
         if self.uri_prefix and hasattr(obj, "id"):
@@ -355,12 +404,17 @@ def get_field_value(self, obj, field_name, field_instance=None, **kwargs):
         # Determine the field value
         if has_field_instance:
             field_value = obj
-        elif isinstance(obj, dict):
-            return obj[field_name]
-        else:
+        elif ModelSchema is None:
             try:
                 field_value = getattr(obj, field_name)
             except AttributeError:
+                try:
+                    field_value = glom(obj, field_name)  # slow
+                except PathAccessError:
+                    raise UnknownFieldError
+        else:
+            field_value = get_value(obj, field_name)
+            if isinstance(field_value, _Missing):
                 raise UnknownFieldError

         return self.serialize_field_value(obj, field_name, field_instance, field_value, **kwargs)

@@ -371,6 +425,9 @@ def serialize_field_value(self, obj, field_name, field_instance, field_value, **kwargs):
         field_value is an actual value to be serialized. For other fields,
         see get_field_value method.
""" + if isinstance(field_instance, (LazyReferenceField, GenericLazyReferenceField)): + return field_value and field_value.pk + if isinstance(field_instance, (ReferenceField, GenericReferenceField, EmbeddedDocumentField)): return self.serialize_document_field(field_name, field_value, **kwargs) @@ -382,6 +439,7 @@ def serialize_field_value(self, obj, field_name, field_instance, field_value, ** elif callable(field_instance): return self.serialize_callable_field(obj, field_instance, field_name, field_value, **kwargs) + return field_value def serialize_callable_field(self, obj, field_instance, field_name, field_value, **kwargs): @@ -398,12 +456,13 @@ def serialize_callable_field(self, obj, field_instance, field_name, field_value, else: value = field_instance(obj) if field_name in self._related_resources: + res = self._related_resources[field_name](view_method=self.view_method) if isinstance(value, list): - return [self._related_resources[field_name]().serialize_field(o, **kwargs) for o in value] + return [res.serialize_field(o, **kwargs) for o in value] elif value is None: return None else: - return self._related_resources[field_name]().serialize_field(value, **kwargs) + return res.serialize_field(value, **kwargs) return value def serialize_dict_field(self, field_instance, field_name, field_value, **kwargs): @@ -423,18 +482,27 @@ def serialize_dict_field(self, field_instance, field_name, field_value, **kwargs def serialize_list_field(self, field_instance, field_name, field_value, **kwargs): """Serialize each item in the list separately.""" - return [val for val in [self.get_field_value(elem, field_name, field_instance=field_instance.field, **kwargs) for elem in field_value] if val] + if not field_value: + return [] + + field_values = [] + for elem in field_value: + fv = self.get_field_value( + elem, field_name, field_instance=field_instance.field, **kwargs + ) + if fv is not None: + field_values.append(fv) + + return field_values def serialize_document_field(self, field_name, field_value, **kwargs): """If this field is a reference or an embedded document, either return a DBRef or serialize it using a resource found in `related_resources`. 
""" if field_name in self._related_resources: - return ( - field_value and - not isinstance(field_value, DBRef) and - self._related_resources[field_name]().serialize_field(field_value, **kwargs) - ) + if field_value and not isinstance(field_value, DBRef): + res = self._related_resources[field_name](view_method=self.view_method) + return res.serialize_field(field_value, **kwargs) else: if DocumentProxy and isinstance(field_value, DocumentProxy): # Don't perform a DBRef isinstance check below since @@ -476,13 +544,16 @@ def serialize(self, obj, **kwargs): renamed_field = self._rename_fields.get(field, field) # if the field is callable, execute it with `obj` as the param + value = None if hasattr(self, field) and callable(getattr(self, field)): value = getattr(self, field)(obj) # if the field is associated with a specific resource (via the # `related_resources` map), use that resource to serialize it if field in self._related_resources and value is not None: - related_resource = self._related_resources[field]() + related_resource = self._related_resources[field]( + view_method=self.view_method + ) if isinstance(value, mongoengine.document.Document): value = related_resource.serialize_field(value) elif isinstance(value, dict): @@ -491,16 +562,21 @@ def serialize(self, obj, **kwargs): else: # assume queryset or list value = [related_resource.serialize_field(o) for o in value] - data[renamed_field] = value else: try: - data[renamed_field] = self.get_field_value(obj, field, **kwargs) + value = self.get_field_value(obj, field, **kwargs) except UnknownFieldError: try: - data[renamed_field] = self.value_for_field(obj, field) + value = self.value_for_field(obj, field) except UnknownFieldError: pass + if value is not None: + if ModelSchema is None: + assign(data, renamed_field, value, missing=dict) # slow + else: + set_value(data, renamed_field, value) + return data def handle_serialization_error(self, exc, obj): @@ -556,23 +632,42 @@ def validate_request(self, obj=None): # If CleanCat schema exists on this resource, use it to perform the # validation if self.schema: + if CleancatSchema is None and ModelSchema is None: + raise ImportError('Cannot validate schema without CleanCat or Marshmallow!') + if request.method == 'PUT' and obj is not None: obj_data = dict([(key, getattr(obj, key)) for key in obj._fields.keys()]) else: obj_data = None - schema = self.schema(self.data, obj_data) - try: - self.data = schema.full_clean() - except SchemaValidationError: - raise ValidationError({'field-errors': schema.field_errors, 'errors': schema.errors }) + if CleancatSchema is not None: + try: + schema = self.schema(self.data, obj_data) + self.data = schema.full_clean() + except SchemaValidationError: + raise ValidationError({'field-errors': schema.field_errors, 'errors': schema.errors }) + elif ModelSchema is not None: + try: + partial = bool(request.method == 'PUT' and obj is not None) + self.data = self.schema().load(self.data, partial=partial) + except MarshmallowValidationError as ex: + raise ValidationError({'errors': ex.messages}) def get_queryset(self): """ Return a MongoEngine queryset that will later be used to return matching documents. 
""" - return self.document.objects + if request.method == 'PUT': + return self.document.objects # get full documents for updates + else: + document_fields = set(self.fields + self.get_optional_fields()) + requested_fields = self.get_requested_fields(params=self.params) + requested_root_fields = set(f.split('.', 1)[0] for f in requested_fields) + mask = requested_root_fields & document_fields + if self.view_method == methods.Download: + mask.add("last_modified") + return self.document.objects.only(*mask) def get_object(self, pk, qfilter=None): """ @@ -584,6 +679,10 @@ def get_object(self, pk, qfilter=None): # get a new one out if qfilter: qs = qfilter(qs) + + if self.view_method != methods.Download: + qs = self.apply_field_pagination(qs) + obj = qs.get(pk=pk) # We don't need to fetch related resources for DELETE requests because @@ -597,6 +696,34 @@ def get_object(self, pk, qfilter=None): return obj + def apply_field_pagination(self, qs, params=None): + """apply field pagination according to `fields_to_paginate`""" + if params is None: + params = self.params + + field_attrs = {} + for field, limits in self.fields_to_paginate.items(): + page = params.get(f'{field}_page', 1) + per_page = params.get(f'{field}_per_page', limits[0]) + if not isint(page): + raise ValidationError({'error': f'{field}_page must be an integer.'}) + if not isint(per_page): + raise ValidationError({'error': f'{field}_per_page must be an integer.'}) + + page, per_page = int(page), int(per_page) + if per_page > limits[1]: + raise ValidationError({ + 'error': f"Per-page limit ({per_page}) for {field} too large ({limits[1]})." + }) + if page < 0: + raise ValidationError({'error': f'{field}_page must be a non-negative integer.'}) + + per_page = min(per_page, limits[1]) + start_index = (page - 1) * per_page + field_attrs[field] = {"$slice": [start_index, per_page]} + + return qs.fields(**field_attrs) + def fetch_related_resources(self, objs, only_fields=None): """ Given a list of objects and an optional list of the only fields we @@ -711,7 +838,16 @@ def apply_filters(self, qs, params=None): parts = key.split('__') for i in range(len(parts) + 1, 0, -1): field = '__'.join(parts[:i]) - allowed_operators = self._filters.get(field) + try: + allowed_operators = self._normal_filters[field] + except KeyError: + for k, v in self._regex_filters.items(): + m = k.match(field) + if m: + allowed_operators = v + break + else: + allowed_operators = None if allowed_operators: parts = parts[i:] break @@ -750,9 +886,17 @@ def apply_ordering(self, qs, params=None): """ if params is None: params = self.params - if self.allowed_ordering and params.get('_order_by') in self.allowed_ordering: - order_params = [self._reverse_rename_fields.get(p, p) for p in params['_order_by'].split(',')] - qs = qs.order_by(*order_params) + if self.allowed_ordering: + oby = params.get('_order_by') + if oby: + order_params = None + if oby in self._normal_allowed_ordering: + order_params = [self._reverse_rename_fields.get(p, p) for p in oby.split(',')] + elif any(p.match(oby) for p in self._regex_allowed_ordering): + order_params = [oby] + if order_params: + order_sign = '-' if params.get('order') == 'desc' else '+' + qs = qs.order_by(*[f'{order_sign}{p}' for p in order_params]) return qs def get_skip_and_limit(self, params=None): @@ -765,18 +909,31 @@ def get_skip_and_limit(self, params=None): params = self.params if self.paginate: # _limit and _skip validation - if not isint(params.get('_limit', 1)): - raise ValidationError({'error': '_limit must be an integer (got 
"%s" instead).' % params['_limit']}) - if not isint(params.get('_skip', 1)): - raise ValidationError({'error': '_skip must be an integer (got "%s" instead).' % params['_skip']}) - if params.get('_limit') and int(params['_limit']) > max_limit: - raise ValidationError({'error': "The limit you set is larger than the maximum limit for this resource (max_limit = %d)." % max_limit}) - if params.get('_skip') and int(params['_skip']) < 0: - raise ValidationError({'error': '_skip must be a non-negative integer (got "%s" instead).' % params['_skip']}) - - limit = min(int(params.get('_limit', self.default_limit)), max_limit) + for par in ['_limit', 'per_page']: + if par in params: + if not isint(params[par]): + raise ValidationError({'error': f'{par} must be an integer (got "%s" instead).' % params[par]}) + if params[par] and int(params[par]) > max_limit: + raise ValidationError({'error': "The limit you set is larger than the maximum limit for this \ + resource (max_limit = %d)." % max_limit}) + limit = min(int(params[par]), max_limit) + break + else: + limit = min(int(self.default_limit), max_limit) + + for par in ['_skip', 'page']: + if par in params: + if not isint(params[par]): + raise ValidationError({'error': f'{par} must be an integer (got "%s" instead).' % params[par]}) + if params[par] and int(params[par]) < 0: + raise ValidationError({'error': f'{par} must be a non-negative integer (got "%s" instead).' % params[par]}) + skip = int(params[par]) if par == '_skip' else (int(params[par])-1) * limit + break + else: + skip = 0 + # Fetch one more so we know if there are more results. - return int(params.get('_skip', 0)), limit + return skip, limit else: return 0, max_limit @@ -791,31 +948,41 @@ def get_objects(self, qs=None, qfilter=None): - Pass `qfilter` function to modify the queryset. 
""" params = self.params + extra = {} + + if self.view_method == methods.Download: + fmt = params.get('format') + if fmt not in self.download_formats: + raise ValueError(f'`format` must be one of {self.download_formats}') custom_qs = True if qs is None: custom_qs = False qs = self.get_queryset() + # Apply filters and ordering, based on the params supplied by the request + qs = self.apply_filters(qs, params) + qs = self.apply_ordering(qs, params) + # If a queryset filter was provided, pass our current queryset in and # get a new one out if qfilter: qs = qfilter(qs) - # Apply filters and ordering, based on the params supplied by the - # request - qs = self.apply_filters(qs, params) - qs = self.apply_ordering(qs, params) + # set total count + extra['total_count'] = qs.count() - # Apply limit and skip to the queryset + # Apply pagination to the queryset (if not Download and no custom queryset provided) limit = None - if self.view_method == methods.BulkUpdate: + if self.view_method in [methods.BulkUpdate, methods.BulkDelete]: # limit the number of objects that can be bulk-updated at a time qs = qs.limit(self.bulk_update_limit) - elif not custom_qs: + elif not custom_qs and self.view_method != methods.Download: # no need to skip/limit if a custom `qs` was provided skip, limit = self.get_skip_and_limit(params) - qs = qs.skip(skip).limit(limit+1) + qs = qs.skip(skip).limit(limit+1) # get one extra to determine has_more + qs = self.apply_field_pagination(qs, params) + extra['total_pages'] = int(extra['total_count']/limit) + bool(extra['total_count'] % limit) # Needs to be at the end as it returns a list, not a queryset if self.select_related: @@ -823,30 +990,21 @@ def get_objects(self, qs=None, qfilter=None): # Evaluate the queryset objs = list(qs) + has_more = None + if self.view_method not in [methods.BulkUpdate, methods.BulkDelete, methods.Download] and self.paginate: + has_more = bool(len(objs) > limit) - # Raise a validation error if bulk update would result in more than - # bulk_update_limit updates - if self.view_method == methods.BulkUpdate and len(objs) >= self.bulk_update_limit: - raise ValidationError({ - 'errors': ["It's not allowed to update more than %d objects at once" % self.bulk_update_limit] - }) - - # Determine the value of has_more - if self.view_method != methods.BulkUpdate and self.paginate: - has_more = len(objs) > limit - if has_more: - objs = objs[:-1] - else: - has_more = None + if has_more: + objs = objs[:-1] # bulk-fetch related resources for moar speed self.fetch_related_resources( objs, self.get_requested_fields(params=params) ) - return objs, has_more + return objs, has_more, extra - def save_related_objects(self, obj, parent_resources=None): + def save_related_objects(self, obj, parent_resources=None, **kwargs): if not parent_resources: parent_resources = [self] else: @@ -881,7 +1039,7 @@ def save_related_objects(self, obj, parent_resources=None): def save_object(self, obj, **kwargs): self.save_related_objects(obj, **kwargs) - obj.save() + obj.save(**kwargs) obj.reload() self._dirty_fields = None # No longer dirty. 
@@ -905,7 +1063,7 @@ def create_object(self, data=None, save=True, parent_resources=None):
         obj = self.document(**update_dict)
         self._dirty_fields = update_dict.keys()
         if save:
-            self.save_object(obj)
+            self.save_object(obj, force_insert=True)
         return obj

     def update_object(self, obj, data=None, save=True, parent_resources=None):
@@ -913,7 +1071,7 @@ def update_object(self, obj, data=None, save=True, parent_resources=None):
         if subresource:
             return subresource.update_object(obj, data=data, save=save, parent_resources=parent_resources)

-        update_dict = self.get_object_dict(data, update=True)
+        update_dict = self.get_object_dict(data, update=True) if save else data

         self._dirty_fields = []

@@ -928,19 +1086,31 @@ def update_object(self, obj, data=None, save=True, parent_resources=None):
             id_from_data = value and getattr(value, 'pk', value)
             if id_from_obj != id_from_data:
                 update = True
-            elif not equal(getattr(obj, field), value):
+            elif getattr(obj, '_fields', None) is not None:
+                if isinstance(obj._fields.get(field), DictField):
+                    if value is None:
+                        update = True
+                    else:
+                        if obj[field] is None:
+                            obj[field] = {}
+                        self.update_object(obj[field], data=value, save=False)
+                elif obj._fields[field].primary_key:
+                    raise ValidationError({'error': f'`{field}` is primary key and cannot be updated'})
+                elif not equal(getattr(obj, field), value):
+                    update = True
+            elif not equal(obj.get(field), value):
                 update = True

             if update:
-                setattr(obj, field, value)
+                if ModelSchema is not None:
+                    set_value(obj, field, value)
+                elif isinstance(obj, dict):
+                    obj[field] = value  # set_value is unavailable without marshmallow
+                else:
+                    setattr(obj, field, value)
                 self._dirty_fields.append(field)

         if save:
             self.save_object(obj)

         return obj

-    def delete_object(self, obj, parent_resources=None):
-        obj.delete()
+    def delete_object(self, obj, parent_resources=None, skip_post_delete=False):
+        obj.delete(signal_kwargs={"skip": skip_post_delete})


 # Py2/3 compatible way to do metaclasses (or six.add_metaclass)
diff --git a/flask_mongorest/utils.py b/flask_mongorest/utils.py
index 272473d2..1a7a65cb 100644
--- a/flask_mongorest/utils.py
+++ b/flask_mongorest/utils.py
@@ -3,6 +3,7 @@
 import datetime
 from bson.dbref import DBRef
 from bson.objectid import ObjectId
+from bson.decimal128 import Decimal128
 import mongoengine

 isbound = lambda m: getattr(m, 'im_self', None) is not None
@@ -18,14 +19,17 @@ class MongoEncoder(json.JSONEncoder):
     def default(self, value, **kwargs):
         if isinstance(value, ObjectId):
             return str(value)
-        if isinstance(value, DBRef):
+        elif isinstance(value, DBRef):
             return value.id
-        if isinstance(value, datetime.datetime):
+        elif isinstance(value, datetime.datetime):
             return value.isoformat()
-        if isinstance(value, datetime.date):
+        elif isinstance(value, datetime.date):
             return value.strftime("%Y-%m-%d")
-        if isinstance(value, decimal.Decimal):
+        elif isinstance(value, decimal.Decimal):
             return str(value)
+        elif isinstance(value, Decimal128):
+            return str(value.to_decimal())
+
         return super(MongoEncoder, self).default(value, **kwargs)
diff --git a/flask_mongorest/views.py b/flask_mongorest/views.py
index 93463b18..b26e28d3 100644
--- a/flask_mongorest/views.py
+++ b/flask_mongorest/views.py
@@ -1,18 +1,73 @@
+import os
+import sys
+import time
 import json
-
-import mimerender
+import boto3
+import hashlib
+import traceback
 import mongoengine
+from gzip import GzipFile
+from io import BytesIO
+from fdict import fdict
+from collections import deque
 from flask import render_template, request
 from flask.views import MethodView
+from flask_sse import sse
 from flask_mongorest import methods
 from flask_mongorest.exceptions import ValidationError
 from flask_mongorest.utils import MongoEncoder
from werkzeug.exceptions import NotFound, Unauthorized +from mimerender import register_mime, FlaskMimeRender +from botocore.errorfactory import ClientError +from urllib.parse import unquote + +BUCKET = os.environ.get('S3_DOWNLOADS_BUCKET', 'mongorest-downloads') +CNAME = os.environ.get('PORTAL_CNAME') +S3_DOWNLOAD_URL = f"https://{BUCKET}.s3.amazonaws.com" + +s3_client = boto3.client('s3') +mimerender = FlaskMimeRender(global_override_input_key='short_mime') +register_mime('gz', ('application/gzip',)) + +def render_json(**payload): + return json.dumps(payload, allow_nan=True, cls=MongoEncoder) + +def render_html(**payload): + d = json.dumps(payload, cls=MongoEncoder, sort_keys=True, indent=4) + return render_template('mongorest/debug.html', data=d) + +def render_gz(**payload): + s3 = payload.get("s3") + + if s3 and s3["update"]: + fmt = request.args.get('format') + content_type = 'text/csv' if fmt == 'csv' else 'application/json' + if fmt == 'json': + contents = json.dumps(payload['data'], allow_nan=True, cls=MongoEncoder) + elif fmt == 'csv': + from pandas import DataFrame + from cherrypicker import CherryPicker + records = [CherryPicker(d).flatten().get() for d in payload['data']] + contents = DataFrame.from_records(records).to_csv() + + gzip_buffer = BytesIO() + with GzipFile(mode='wb', fileobj=gzip_buffer) as gzip_file: + gzip_file.write(contents.encode('utf-8')) # need to give full contents to compression + + body = gzip_buffer.getvalue() + s3_client.put_object( + Bucket=BUCKET, + Key=s3["key"], + ContentType=content_type, + ContentEncoding='gzip', + Body=body + ) + return body + + retr = s3_client.get_object(Bucket=BUCKET, Key=s3["key"]) + gzip_buffer = BytesIO(retr['Body'].read()) + return gzip_buffer.getvalue() -mimerender = mimerender.FlaskMimeRender() - -render_json = lambda **payload: json.dumps(payload, allow_nan=False, cls=MongoEncoder) -render_html = lambda **payload: render_template('mongorest/debug.html', data=json.dumps(payload, cls=MongoEncoder, sort_keys=True, indent=4)) try: text_type = unicode # Python 2 @@ -57,7 +112,7 @@ class ResourceView(MethodView): def __init__(self): assert(self.resource and self.methods) - @mimerender(default='json', json=render_json, html=render_html) + @mimerender(default='json', json=render_json, html=render_html, gz=render_gz) def dispatch_request(self, *args, **kwargs): # keep all the logic in a helper method (_dispatch_request) so that # it's easy for subclasses to override this method (when they don't want to use @@ -76,14 +131,18 @@ def _dispatch_request(self, *args, **kwargs): try: self._resource = self.requested_resource(request) return super(ResourceView, self).dispatch_request(*args, **kwargs) - except mongoengine.queryset.DoesNotExist as e: - return {'error': 'Empty query: ' + str(e)}, '404 Not Found' - except ValidationError as e: - return e.args[0], '400 Bad Request' - except Unauthorized: - return {'error': 'Unauthorized'}, '401 Unauthorized' - except NotFound as e: + except (ValueError, ValidationError, mongoengine.errors.ValidationError) as e: + return {'error': str(e)}, '400 Bad Request' + except (Unauthorized, mongoengine.errors.NotUniqueError) as e: + return {'error': str(e)}, '401 Unauthorized' + except (NotFound, mongoengine.queryset.DoesNotExist) as e: return {'error': str(e)}, '404 Not Found' + except Exception as e: + exc_type, exc_value, exc_tb = sys.exc_info() + tb = traceback.format_exception(exc_type, exc_value, exc_tb) + err = ''.join(tb) + print(err) + return {'error': err}, '500 Internal Server Error' def 
handle_validation_error(self, e): if isinstance(e, ValidationError): @@ -103,16 +162,23 @@ def requested_resource(self, request): def get(self, **kwargs): pk = kwargs.pop('pk', None) + short_mime = kwargs.pop('short_mime', None) + fmt = self._resource.params.get('format') # Set the view_method on a resource instance if pk: self._resource.view_method = methods.Fetch + elif short_mime: + if short_mime != 'gz': + raise ValueError(f'{short_mime} not supported') + self._resource.view_method = methods.Download else: - self._resource.view_method = methods.List + self._resource.view_method = methods.BulkFetch # Create a queryset filter to control read access to the # underlying objects qfilter = lambda qs: self.has_read_permission(request, qs.clone()) + if pk is None: result = self._resource.get_objects(qfilter=qfilter) @@ -126,19 +192,50 @@ def get(self, **kwargs): else: raise ValueError('Unsupported value of resource.get_objects') - data = [] - for obj in objs: + # generate hash/etag and S3 object name + if self._resource.view_method == methods.Download: + primary_keys = [str(obj.pk) for obj in objs] + last_modified = max(obj.last_modified for obj in objs) + dct = {"ids": primary_keys, "params": self._resource.params} + sha1 = hashlib.sha1( + json.dumps(dct, sort_keys=True).encode('utf-8') + ).hexdigest() + filename = f"{sha1}.{fmt}" + key = f"{CNAME}/{filename}" if CNAME else filename + extra["s3"] = {"key": key, "update": False} try: - data.append(self._resource.serialize(obj, params=request.args)) - except Exception as e: - fixed_obj = self._resource.handle_serialization_error(e, obj) - if fixed_obj is not None: - data.append(fixed_obj) + s3_client.head_object( + Bucket=BUCKET, Key=key, IfModifiedSince=last_modified + ) + except ClientError: + extra["s3"]["update"] = True # Serialize the objects one by one - ret = { - 'data': data - } + data = [] + url = unquote(request.url).encode('utf-8') + channel = hashlib.sha1(url).hexdigest() + if "s3" not in extra or extra["s3"]["update"]: + print(f"serializing {channel}...") + tic = time.perf_counter() + batch_size, total_count = 1000, extra["total_count"] + for idx, obj in enumerate(objs): + if idx > 0 and (not idx % batch_size or idx == total_count - 1): + toc = time.perf_counter() + nobjs = batch_size + if idx == total_count - 1: + nobjs = total_count - batch_size * int(idx/batch_size) - 1 + print(f"{idx} Took {toc - tic:0.4f}s to serialize {nobjs} objects.") + if self._resource.view_method == methods.Download: + sse.publish({"message": idx + 1}, type="download", channel=channel) + tic = time.perf_counter() + try: + data.append(self._resource.serialize(obj, params=request.args)) + except Exception as e: + fixed_obj = self._resource.handle_serialization_error(e, obj) + if fixed_obj is not None: + data.append(fixed_obj) + + ret = {'data': data} if has_more is not None: ret['has_more'] = has_more @@ -148,18 +245,57 @@ def get(self, **kwargs): else: obj = self._resource.get_object(pk, qfilter=qfilter) ret = self._resource.serialize(obj, params=request.args) - return ret + + if self._resource.view_method == methods.Download: + sse.publish({"message": 0}, type="download", channel=channel) + return ret, '200 OK', { + 'Content-Disposition': f'attachment; filename="{filename}.{short_mime}"' + } + else: + return ret def post(self, **kwargs): - if 'pk' in kwargs: + if kwargs.pop('pk'): raise NotFound("Did you mean to use PUT?") - # Set the view_method on a resource instance - self._resource.view_method = methods.Create + raw_data = self._resource.raw_data + 
if isinstance(raw_data, dict):
+            # create a single object
+            self._resource.view_method = methods.Create
+            return self.create_object()
+        elif isinstance(raw_data, list):
+            limit = self._resource.bulk_update_limit
+            if len(raw_data) > limit:
+                raise ValidationError({
+                    'errors': [f"Can only create {limit} documents at once"]
+                })
+            raw_data_deque = deque(raw_data)
+            self._resource.view_method = methods.BulkCreate
+            data = []
+            tic = time.perf_counter()
+            while len(raw_data_deque):
+                self._resource._raw_data = raw_data_deque.popleft()
+                data.append(self.create_object())
+                if time.perf_counter() - tic > 50:
+                    break  # avoid server timeout
+
+            dt = time.perf_counter() - tic
+            count = len(data)
+            msg = f"Created {count} objects in {dt:0.1f}s ({count/dt:0.3f}/s)."
+            print(msg)
+            ret = {'data': data, 'count': count}
+            if raw_data_deque:
+                remain = len(raw_data_deque)
+                msg += f" Remaining {remain} objects skipped to avoid Server Timeout."
+                ret['warning'] = msg
+            return ret, '201 Created'
+        else:
+            raise ValidationError({'error': 'wrong payload type'})

+    def create_object(self):
         self._resource.validate_request()
         try:
-            obj = self._resource.create_object()
+            obj = self._resource.create_object(save=False)
         except Exception as e:
             self.handle_validation_error(e)

@@ -167,11 +303,8 @@ def post(self, **kwargs):
         if not self.has_add_permission(request, obj):
             raise Unauthorized

-        ret = self._resource.serialize(obj, params=request.args)
-        if isinstance(obj, mongoengine.Document) and self._resource.uri_prefix:
-            return ret, "201 Created", {"Location": self._resource._url(str(obj.id))}
-        else:
-            return ret
+        self._resource.save_object(obj, force_insert=True)
+        return self._resource.serialize(obj, params=request.args)

     def process_object(self, obj):
         """Validate and update an object"""
@@ -233,27 +366,79 @@ def put(self, **kwargs):
             objs, has_more, extra = result

             # Update all the objects and return their count
-            return self.process_objects(objs)
+            ret = self.process_objects(objs)
+            ret['has_more'] = has_more
+            ret.update(extra)
+            return ret
         else:
             obj = self._resource.get_object(pk)
             self.process_object(obj)

-            ret = self._resource.serialize(obj, params=request.args)
+            raw_data = fdict(self._resource.raw_data, delimiter='.')
+            fields = ','.join(raw_data.keys())
+            return self._resource.serialize(obj, params={'_fields': fields})
+
+    def delete_object(self, obj, skip_post_delete=False):
+        """Delete an object"""
+        # Check if we have permission to delete this object
+        if not self.has_delete_permission(request, obj):
+            raise Unauthorized
+
+        try:
+            self._resource.delete_object(obj, skip_post_delete=skip_post_delete)
+        except Exception as e:
+            self.handle_validation_error(e)
+
+    def delete_objects(self, objs):
+        """Delete each object in the list one by one, and return the total count."""
+        tic = time.perf_counter()
+        nobjs, count = len(objs), 0
+        try:
+            # skip post-delete signals for all but the last object, so the
+            # signal handler fires only once at the end of the bulk delete
+            for iobj, obj in enumerate(objs):
+                skip = iobj < nobjs - 1
+                self.delete_object(obj, skip_post_delete=skip)
+                count += 1
+                if time.perf_counter() - tic > 50:
+                    break  # avoid server timeout
+        except ValidationError as e:
+            e.args[0]['count'] = count
+            raise e
+        else:
+            dt = time.perf_counter() - tic
+            msg = f"Deleted {count} objects in {dt:0.1f}s ({count/dt:0.3f}/s)."
+            print(msg)
+            ret = {'count': count}
+            remain = nobjs - count
+            if remain:
+                msg += f" Remaining {remain} objects skipped to avoid Server Timeout."
+                ret['warning'] = msg
         return ret

     def delete(self, **kwargs):
         pk = kwargs.pop('pk', None)

         # Set the view_method on a resource instance
-        self._resource.view_method = methods.Delete
-
-        obj = self._resource.get_object(pk)
+        if pk:
+            self._resource.view_method = methods.Delete
+        else:
+            self._resource.view_method = methods.BulkDelete

-        # Check if we have permission to delete this object
-        if not self.has_delete_permission(request, obj):
-            raise Unauthorized
+        if pk is None:
+            result = self._resource.get_objects()
+            has_more, extra = None, {}
+            if len(result) == 2:
+                objs, has_more = result
+            elif len(result) == 3:
+                objs, has_more, extra = result

-        self._resource.delete_object(obj)
-        return {}
+            # Delete all the objects and return their count
+            ret = self.delete_objects(objs)
+            ret['has_more'] = has_more
+            ret.update(extra)
+            return ret
+        else:
+            obj = self._resource.get_object(pk)
+            self.delete_object(obj)
+            return {'count': 1}

     # This takes a QuerySet as an argument and then
     # returns a query set that this request can read
diff --git a/requirements.txt b/requirements.txt
index d25a6a9a..6e0330c9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,8 @@
--e git://github.com/closeio/mongoengine.git#egg=mongoengine-dev
--e git://github.com/closeio/flask-mongoengine.git#egg=flask-mongoengine
-mimerender
-python-dateutil
-sphinx
-cleancat>=0.3
-Flask>=0.9
-pymongo>=3.4
-flake8
+mongoengine==0.21.0
+flask-mongoengine==1.0.0
+mimerender @ git+https://github.com/tschaume/mimerender@mpcontribs#egg=mimerender-0.6.1
+python-dateutil==2.8.1
+Flask==1.1.2
+pymongo==3.11.1
+unflatten==0.1
+fastnumbers==3.0.0
diff --git a/setup.py b/setup.py
index 4890a0f7..6469c99a 100644
--- a/setup.py
+++ b/setup.py
@@ -27,12 +27,12 @@
     test_suite='nose.collector',
     zip_safe=False,
     platforms='any',
-    setup_requires=[
+    install_requires=[
         'Flask-MongoEngine',
-        'mimerender',
+        'mimerender @ git+https://github.com/tschaume/mimerender@mpcontribs#egg=mimerender',
         'nose',
         'python-dateutil',
-        'cleancat'
+        'unflatten'
     ],
     classifiers=[
         'Development Status :: 4 - Beta',
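
Reviewer note: to see the new routing end to end, here is a minimal sketch of wiring a view through the module-level `register_class` helper added in `__init__.py`. It assumes the hypothetical `PostResource` from the earlier sketch; `app` is a plain Flask app and the route listing reflects the restored URL converters:

```python
from flask import Flask
from flask_mongorest import register_class
from flask_mongorest.methods import BulkFetch, Create, Delete, Download, Fetch, Update
from flask_mongorest.views import ResourceView

app = Flask(__name__)

class PostView(ResourceView):
    resource = PostResource  # hypothetical resource from the sketch above
    methods = [Fetch, BulkFetch, Create, Update, Delete, Download]

register_class(app, PostView, url_prefix='/api')
# Resulting rules:
#   GET    /api/post/<pk>/                   -> Fetch
#   PUT    /api/post/<pk>/                   -> Update
#   DELETE /api/post/<pk>/                   -> Delete
#   GET    /api/post/                        -> BulkFetch
#   POST   /api/post/                        -> Create
#   GET    /api/post/download/<short_mime>/  -> Download (?format=json|csv)
```

One packaging observation: the new code imports `boto3`, `flask_sse`, and `fdict` at module level in `views.py` and `glom` in `resources.py` (plus `pandas`/`cherrypicker` lazily for CSV downloads), but only `unflatten` and `fastnumbers` were added to `requirements.txt`/`setup.py`; the remaining runtime dependencies still need to be declared.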