diff --git a/CHANGES.md b/CHANGES.md index 1abfc15..fee2433 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,7 @@ +## Version 0.1.1 + +Adds a `get_recorded_queries` helper, in line with [Flask-SQLAlchemy's implementation](https://github.com/pallets-eco/flask-sqlalchemy/blob/3e3e92ba557649ab5251eda860a67656cc8c10af/src/flask_sqlalchemy/record_queries.py), which can be used to track SQL queries emitted by SQLAlchemy. Useful for supporting integrations with other Flask extensions or for aiding development and testing. + ## Version 0.1.0 Released 2024-06-07 diff --git a/src/flask_sqlalchemy_lite/_extension.py b/src/flask_sqlalchemy_lite/_extension.py index 530c648..024cab5 100644 --- a/src/flask_sqlalchemy_lite/_extension.py +++ b/src/flask_sqlalchemy_lite/_extension.py @@ -83,6 +83,12 @@ def init_app(self, app: App) -> None: app.teardown_appcontext(_close_async_sessions) app.shell_context_processor(add_models_to_shell) + if app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", False): + from . import record_queries + + for engine in engines.values(): + record_queries._listen(engine) + def _get_state(self) -> _State: app = current_app._get_current_object() # type: ignore[attr-defined] diff --git a/src/flask_sqlalchemy_lite/record_queries.py b/src/flask_sqlalchemy_lite/record_queries.py new file mode 100644 index 0000000..e8273be --- /dev/null +++ b/src/flask_sqlalchemy_lite/record_queries.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import dataclasses +import inspect +import typing as t +from time import perf_counter + +import sqlalchemy as sa +import sqlalchemy.event as sa_event +from flask import current_app +from flask import g +from flask import has_app_context + + +def get_recorded_queries() -> list[_QueryInfo]: + """Get the list of recorded query information for the current session. Queries are + recorded if the config :data:`.SQLALCHEMY_RECORD_QUERIES` is enabled. + + Each query info object has the following attributes: + + ``statement`` + The string of SQL generated by SQLAlchemy with parameter placeholders. + ``parameters`` + The parameters sent with the SQL statement. + ``start_time`` / ``end_time`` + Timing info about when the query started execution and when the results where + returned. Accuracy and value depends on the operating system. + ``duration`` + The time the query took in seconds. + ``location`` + A string description of where in your application code the query was executed. + This may not be possible to calculate, and the format is not stable. + + .. versionchanged:: 3.0 + Renamed from ``get_debug_queries``. + + .. versionchanged:: 3.0 + The info object is a dataclass instead of a tuple. + + .. versionchanged:: 3.0 + The info object attribute ``context`` is renamed to ``location``. + + .. versionchanged:: 3.0 + Not enabled automatically in debug or testing mode. + """ + return g.get("_sqlalchemy_queries", []) # type: ignore[no-any-return] + + +@dataclasses.dataclass +class _QueryInfo: + """Information about an executed query. Returned by :func:`get_recorded_queries`. + + .. versionchanged:: 3.0 + Renamed from ``_DebugQueryTuple``. + + .. versionchanged:: 3.0 + Changed to a dataclass instead of a tuple. + + .. versionchanged:: 3.0 + ``context`` is renamed to ``location``. + """ + + statement: str | None + parameters: t.Any + start_time: float + end_time: float + location: str + + @property + def duration(self) -> float: + return self.end_time - self.start_time + + +def _listen(engine: sa.engine.Engine) -> None: + sa_event.listen(engine, "before_cursor_execute", _record_start, named=True) + sa_event.listen(engine, "after_cursor_execute", _record_end, named=True) + + +def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + context._fsa_start_time = perf_counter() # type: ignore[attr-defined] + + +def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + if "_sqlalchemy_queries" not in g: + g._sqlalchemy_queries = [] + + import_top = current_app.import_name.partition(".")[0] + import_dot = f"{import_top}." + frame = inspect.currentframe() + + while frame: + name = frame.f_globals.get("__name__") + + if name and (name == import_top or name.startswith(import_dot)): + code = frame.f_code + location = f"{code.co_filename}:{frame.f_lineno} ({code.co_name})" + break + + frame = frame.f_back + else: + location = "" + + g._sqlalchemy_queries.append( + _QueryInfo( + statement=context.statement, + parameters=context.parameters, + start_time=context._fsa_start_time, # type: ignore[attr-defined] + end_time=perf_counter(), + location=location, + ) + ) diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py new file mode 100644 index 0000000..e2da348 --- /dev/null +++ b/tests/test_record_queries.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import os + +import pytest +import sqlalchemy as sa +import sqlalchemy.orm as sa_orm +from flask import Flask + +from flask_sqlalchemy_lite import SQLAlchemy +from flask_sqlalchemy_lite.record_queries import get_recorded_queries + + +class Base(sa_orm.DeclarativeBase): + pass + + +class Todo(Base): + __tablename__ = "todo" + id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True) + + +@pytest.mark.usefixtures("app_ctx") +def test_query_info(app: Flask) -> None: + app.config["SQLALCHEMY_RECORD_QUERIES"] = True + db = SQLAlchemy(app) + Base.metadata.create_all(db.engine) + db.session.execute(sa.select(Todo).filter(Todo.id < 5)).scalars() + info = get_recorded_queries()[-1] + assert info.statement is not None + assert "SELECT" in info.statement + assert "FROM todo" in info.statement + assert info.parameters[0][0] == 5 + assert info.duration == info.end_time - info.start_time + assert os.path.join("tests", "test_record_queries.py:") in info.location + assert "(test_query_info)" in info.location