Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ dependencies = [
"jsonpath-ng>=1.6.1, <2",
"Mako>=1.2.2",
"markdown>=3.0",
# marshmallow>=4 has issues: https://github.com/apache/superset/issues/33162
"marshmallow>=3.0, <4",
"marshmallow>=4",
"marshmallow-union>=0.1",
"msgpack>=1.0.0, <1.1",
"nh3>=0.2.11, <0.3",
Expand Down
3 changes: 2 additions & 1 deletion requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ apispec>=6.0.0,<6.7.0
# causing CI to fail. 1.4.0 is the last version that works.
# https://marshmallow-sqlalchemy.readthedocs.io/en/latest/changelog.html#id3
# Opened this issue https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/665
marshmallow-sqlalchemy>=1.3.0,<1.4.1
# Update: Upgrading to 1.4.2+ for marshmallow 4.x compatibility
marshmallow-sqlalchemy>=1.4.2

# needed for python 3.12 support
openapi-schema-validator>=0.6.3
4 changes: 2 additions & 2 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,13 @@ markupsafe==3.0.2
# mako
# werkzeug
# wtforms
marshmallow==3.26.1
marshmallow==4.0.0
# via
# apache-superset (pyproject.toml)
# flask-appbuilder
# marshmallow-sqlalchemy
# marshmallow-union
marshmallow-sqlalchemy==1.4.0
marshmallow-sqlalchemy==1.4.2
# via
# -r requirements/base.in
# flask-appbuilder
Expand Down
4 changes: 2 additions & 2 deletions requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,14 @@ markupsafe==3.0.2
# mako
# werkzeug
# wtforms
marshmallow==3.26.1
marshmallow==4.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# flask-appbuilder
# marshmallow-sqlalchemy
# marshmallow-union
marshmallow-sqlalchemy==1.4.0
marshmallow-sqlalchemy==1.4.2
# via
# -c requirements/base-constraint.txt
# flask-appbuilder
Expand Down
10 changes: 10 additions & 0 deletions superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@

logger = logging.getLogger(__name__)

# Apply marshmallow 4.x compatibility patch for Flask-AppBuilder
try:
from superset.marshmallow_compatibility import (
patch_marshmallow_for_flask_appbuilder,
)

patch_marshmallow_for_flask_appbuilder()
except ImportError:
logger.debug("marshmallow_compatibility module not found, skipping patch")


def create_app(
superset_config_module: Optional[str] = None,
Expand Down
10 changes: 6 additions & 4 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSch
allow_none=False,
metadata={"description": "Columns by which to group by"},
),
minLength=1,
validate=Length(min=1),
required=True,
),
)
Expand Down Expand Up @@ -657,7 +657,9 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
"the future",
"example": 7,
},
min=0,
validate=[
Range(min=0, error=_("`periods` must be greater than or equal to 0"))
],
required=True,
)
confidence_interval = fields.Float(
Expand Down Expand Up @@ -791,7 +793,7 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
fields.List(
fields.String(allow_none=False),
metadata={"description": "Columns to group by on the table index (=rows)"},
minLength=1,
validate=Length(min=1),
required=True,
),
)
Expand Down Expand Up @@ -1643,7 +1645,7 @@ class DashboardSchema(Schema):


class ChartGetResponseSchema(Schema):
id = fields.Int(description=id_description)
id = fields.Int(metadata={"description": id_description})
url = fields.String()
cache_timeout = fields.String()
certified_by = fields.String()
Expand Down
2 changes: 1 addition & 1 deletion superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def _deserialize(

class BaseUploadFilePostSchemaMixin(Schema):
@validates("file")
def validate_file_extension(self, file: FileStorage) -> None:
def validate_file_extension(self, file: FileStorage, **kwargs: Any) -> None:
allowed_extensions = current_app.config["ALLOWED_EXTENSIONS"]
file_suffix = Path(file.filename).suffix
if not file_suffix:
Expand Down
21 changes: 14 additions & 7 deletions superset/db_engine_specs/databend.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,27 @@ def get_function_names(cls, database: Database) -> list[str]:


class DatabendParametersSchema(Schema):
username = fields.String(allow_none=True, description=__("Username"))
password = fields.String(allow_none=True, description=__("Password"))
host = fields.String(required=True, description=__("Hostname or IP address"))
username = fields.String(allow_none=True, metadata={"description": __("Username")})
password = fields.String(allow_none=True, metadata={"description": __("Password")})
host = fields.String(
required=True, metadata={"description": __("Hostname or IP address")}
)
port = fields.Integer(
allow_none=True,
description=__("Database port"),
metadata={"description": __("Database port")},
validate=Range(min=0, max=65535),
)
database = fields.String(allow_none=True, description=__("Database name"))
database = fields.String(
allow_none=True, metadata={"description": __("Database name")}
)
encryption = fields.Boolean(
default=True, description=__("Use an encrypted connection to the database")
dump_default=True,
metadata={"description": __("Use an encrypted connection to the database")},
)
query = fields.Dict(
keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters")
keys=fields.Str(),
values=fields.Raw(),
metadata={"description": __("Additional parameters")},
)


Expand Down
92 changes: 92 additions & 0 deletions superset/marshmallow_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Marshmallow 4.x Compatibility Module for Flask-AppBuilder 5.0.0

This module provides compatibility between Flask-AppBuilder 5.0.0 and
marshmallow 4.x, specifically handling missing auto-generated fields
during schema initialization.
"""

import logging
from typing import Any, TYPE_CHECKING

from marshmallow import fields

if TYPE_CHECKING:
import marshmallow

logger = logging.getLogger(__name__)


def patch_marshmallow_for_flask_appbuilder() -> None:
"""
Patches marshmallow Schema._init_fields to handle Flask-AppBuilder 5.0.0
compatibility with marshmallow 4.x.

Flask-AppBuilder 5.0.0 automatically generates schema fields that reference
SQL relationship fields that may not exist in marshmallow 4.x's stricter
field validation. This patch dynamically adds missing fields as Raw fields
to prevent KeyError exceptions during schema initialization.
"""
import marshmallow

# Store the original method
original_init_fields = marshmallow.Schema._init_fields

def patched_init_fields(self: "marshmallow.Schema") -> Any:
"""Patched version that handles missing declared fields."""
max_retries = 10 # Prevent infinite loops in case of unexpected errors
retries = 0

while retries < max_retries:
try:
return original_init_fields(self)
except KeyError as e:
# Extract the missing field name from the KeyError
missing_field = str(e).strip("'\"")

# Initialize declared_fields if it doesn't exist
if not hasattr(self, "declared_fields"):
self.declared_fields = {}

# Only add if it doesn't already exist
if missing_field not in self.declared_fields:
# Use Raw field as a safe fallback for unknown auto-generated
# fields. Allow both load and dump to support both input
# validation and serialization
self.declared_fields[missing_field] = fields.Raw(
allow_none=True,
load_default=None, # Optional field (defaults to None)
)

logger.debug(
"Marshmallow compatibility: Added missing field "
"'%s' as Raw field",
missing_field,
)

retries += 1
# Continue the loop to retry initialization

# If we've exhausted retries, something is seriously wrong
raise RuntimeError(
f"Marshmallow field initialization failed after {max_retries} retries"
)

# Apply the patch
marshmallow.Schema._init_fields = patched_init_fields
1 change: 1 addition & 0 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ def ensure_extra_json_is_not_none(
self,
_: str,
value: Optional[dict[str, Any]],
**kwargs: Any,
) -> Any:
if value is None:
return "{}"
Expand Down
6 changes: 4 additions & 2 deletions superset/reports/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,14 @@ class ReportSchedulePostSchema(Schema):
},
allow_none=True,
required=False,
default=None,
load_default=None,
)

@validates("custom_width")
def validate_custom_width(
self,
value: Optional[int],
**kwargs: Any,
) -> None:
if value is None:
return
Expand Down Expand Up @@ -378,13 +379,14 @@ class ReportSchedulePutSchema(Schema):
},
allow_none=True,
required=False,
default=None,
load_default=None,
)

@validates("custom_width")
def validate_custom_width(
self,
value: Optional[int],
**kwargs: Any,
) -> None:
if value is None:
return
Expand Down
22 changes: 14 additions & 8 deletions superset/themes/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from contextvars import ContextVar
from typing import Any

from marshmallow import fields, Schema, validates, ValidationError

from superset.themes.utils import is_valid_theme, sanitize_theme_tokens
from superset.utils import json

# Context variable for storing sanitized JSON data during validation
sanitized_json_context: ContextVar[str | None] = ContextVar(
"sanitized_json_data", default=None
)


class ImportV1ThemeSchema(Schema):
theme_name = fields.String(required=True)
Expand All @@ -29,7 +35,7 @@ class ImportV1ThemeSchema(Schema):
version = fields.String(required=True)

@validates("json_data")
def validate_json_data(self, value: dict[str, Any]) -> None:
def validate_json_data(self, value: dict[str, Any], **kwargs: Any) -> None:
# Convert dict to JSON string for validation
if isinstance(value, dict):
json_str = json.dumps(value)
Expand All @@ -56,20 +62,20 @@ def validate_json_data(self, value: dict[str, Any]) -> None:
value.clear()
value.update(sanitized_config)
else:
self.context["sanitized_json_data"] = json.dumps(sanitized_config)
sanitized_json_context.set(json.dumps(sanitized_config))


class ThemePostSchema(Schema):
theme_name = fields.String(required=True, allow_none=False)
json_data = fields.String(required=True, allow_none=False)

@validates("theme_name")
def validate_theme_name(self, value: str) -> None:
def validate_theme_name(self, value: str, **kwargs: Any) -> None:
if not value or not value.strip():
raise ValidationError("Theme name cannot be empty.")

@validates("json_data")
def validate_and_sanitize_json_data(self, value: str) -> None:
def validate_and_sanitize_json_data(self, value: str, **kwargs: Any) -> None:
# Parse JSON
try:
theme_config = json.loads(value) if isinstance(value, str) else value
Expand All @@ -87,20 +93,20 @@ def validate_and_sanitize_json_data(self, value: str) -> None:
# Note: This modifies the input data to ensure sanitized content is stored
if sanitized_config != theme_config:
# Re-serialize the sanitized config
self.context["sanitized_json_data"] = json.dumps(sanitized_config)
sanitized_json_context.set(json.dumps(sanitized_config))


class ThemePutSchema(Schema):
theme_name = fields.String(required=True, allow_none=False)
json_data = fields.String(required=True, allow_none=False)

@validates("theme_name")
def validate_theme_name(self, value: str) -> None:
def validate_theme_name(self, value: str, **kwargs: Any) -> None:
if not value or not value.strip():
raise ValidationError("Theme name cannot be empty.")

@validates("json_data")
def validate_and_sanitize_json_data(self, value: str) -> None:
def validate_and_sanitize_json_data(self, value: str, **kwargs: Any) -> None:
# Parse JSON
try:
theme_config = json.loads(value) if isinstance(value, str) else value
Expand All @@ -118,7 +124,7 @@ def validate_and_sanitize_json_data(self, value: str) -> None:
# Note: This modifies the input data to ensure sanitized content is stored
if sanitized_config != theme_config:
# Re-serialize the sanitized config
self.context["sanitized_json_data"] = json.dumps(sanitized_config)
sanitized_json_context.set(json.dumps(sanitized_config))


openapi_spec_methods_override = {
Expand Down
Loading
Loading