diff --git a/.changes/unreleased/Added-20250401-013109.yaml b/.changes/unreleased/Added-20250401-013109.yaml new file mode 100644 index 0000000..2381bf0 --- /dev/null +++ b/.changes/unreleased/Added-20250401-013109.yaml @@ -0,0 +1,6 @@ +kind: Added +body: Added support for driver `sqlite3` +time: 2025-04-01T01:31:09.0893687+02:00 +custom: + Author: rayakame + PR: "17" diff --git a/internal/codegen/common.go b/internal/codegen/common.go index 6c8c365..7bca5af 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -26,10 +26,13 @@ func NewDriver(conf *core.Config) (*Driver, error) { var connType string switch conf.SqlDriver { case core.SQLDriverAioSQLite: - buildPyQueryFunc = drivers.BuildPyQueryFunc - acceptedDriverCMDs = drivers.AcceptedDriverCMDs + buildPyQueryFunc = drivers.AioSQLiteBuildPyQueryFunc + acceptedDriverCMDs = drivers.AioSQLiteAcceptedDriverCMDs connType = drivers.AioSQLiteConn - + case core.SQLDriverSQLite: + buildPyQueryFunc = drivers.SQLite3BuildPyQueryFunc + acceptedDriverCMDs = drivers.SQLite3AcceptedDriverCMDs + connType = drivers.SQLite3Conn default: return nil, fmt.Errorf("unsupported driver: %s", conf.SqlDriver.String()) } diff --git a/internal/codegen/drivers/aiosqlite.go b/internal/codegen/drivers/aiosqlite.go index 72a68d9..f80529b 100644 --- a/internal/codegen/drivers/aiosqlite.go +++ b/internal/codegen/drivers/aiosqlite.go @@ -10,7 +10,7 @@ import ( const AioSQLiteConn = "aiosqlite.Connection" -func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error { +func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error { indentLevel := 0 params := fmt.Sprintf("conn: %s", AioSQLiteConn) conn := "conn" @@ -132,7 +132,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg return nil } -func AcceptedDriverCMDs() []string { +func AioSQLiteAcceptedDriverCMDs() []string { return []string{ metadata.CmdExec, metadata.CmdExecResult, diff --git a/internal/codegen/drivers/sqlite3.go b/internal/codegen/drivers/sqlite3.go new file mode 100644 index 0000000..d2352c4 --- /dev/null +++ b/internal/codegen/drivers/sqlite3.go @@ -0,0 +1,111 @@ +package drivers + +import ( + "fmt" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" + "github.com/rayakame/sqlc-gen-better-python/internal/core" + "github.com/sqlc-dev/plugin-sdk-go/metadata" + "strconv" +) + +const SQLite3Conn = "sqlite3.Connection" + +func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error { + indentLevel := 0 + params := fmt.Sprintf("conn: %s", SQLite3Conn) + conn := "conn" + if isClass { + params = "self" + conn = "self._conn" + indentLevel = 1 + } + body.WriteIndentedString(indentLevel, fmt.Sprintf("def %s(%s", query.FuncName, params)) + if argType != "" { + body.WriteString(fmt.Sprintf(", %s: %s", query.Arg.Name, argType)) + } + if query.Cmd == metadata.CmdExec { + body.WriteLine(fmt.Sprintf(") -> %s:", retType)) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(")") + } else if query.Cmd == metadata.CmdExecResult { + body.WriteLine(fmt.Sprintf(") -> %s:", "sqlite3.Cursor")) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(")") + } else if query.Cmd == metadata.CmdExecRows { + body.WriteLine(fmt.Sprintf(") -> %s:", retType)) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(").rowcount") + } else if query.Cmd == metadata.CmdExecLastId { + body.WriteLine(fmt.Sprintf(") -> %s:", retType)) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(").lastrowid") + } else if query.Cmd == metadata.CmdOne { + body.WriteLine(fmt.Sprintf(") -> typing.Optional[%s]:", retType)) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = %s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(").fetchone()") + body.WriteIndentedLine(indentLevel+1, "if row is None:") + body.WriteIndentedLine(indentLevel+2, "return None") + if query.Ret.IsStruct() { + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType)) + for i, col := range query.Ret.Table.Columns { + if i != 0 { + body.WriteString(", ") + } + body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) + } + body.WriteLine(")") + } else { + body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("return %s(row[0])", retType)) + } + } else if query.Cmd == metadata.CmdMany { + body.WriteLine(fmt.Sprintf(") -> typing.List[%s]:", retType)) + body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("rows: typing.List[%s] = []", retType)) + body.WriteIndentedString(indentLevel+1, fmt.Sprintf("for row in %s.execute(%s", conn, query.ConstantName)) + writeParams(query, body, argType) + body.WriteLine(").fetchall():") + if query.Ret.IsStruct() { + body.WriteIndentedString(indentLevel+2, fmt.Sprintf("rows.append(%s(", retType)) + for i, col := range query.Ret.Table.Columns { + if i != 0 { + body.WriteString(", ") + } + body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) + } + body.WriteLine("))") + } else { + body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("rows.append(%s(row[0]))", retType)) + } + body.WriteIndentedLine(indentLevel+1, "return rows") + } + return nil +} + +func SQLite3AcceptedDriverCMDs() []string { + return []string{ + metadata.CmdExec, + metadata.CmdExecResult, + metadata.CmdExecLastId, + metadata.CmdExecRows, + metadata.CmdOne, + metadata.CmdMany, + } +} + +func writeParams(query *core.Query, body *builders.IndentStringBuilder, argType string) { + if argType != "" { + params := "(" + if query.Arg.IsStruct() { + for _, col := range query.Arg.Table.Columns { + params += fmt.Sprintf("%s.%s, ", query.Arg.Name, col.Name) + } + } else { + params += fmt.Sprintf("%s, ", query.Arg.Name) + } + body.WriteString("," + params + ")") + } +} diff --git a/sqlc.yaml b/sqlc.yaml index 24cba0a..ec97390 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: 67389a6e3bfdaf7e78ff7e85a4c497beb849aff7de9d4e283feda44ffe3f22a3 + sha256: be2dfe3e1b9afa91b81212ed983fc1e2b507c58c6ed9c757e902756dd1de1de9 sql: - schema: test/schema.sql queries: test/queries.sql @@ -13,7 +13,7 @@ sql: plugin: python options: package: test - sql_driver: aiosqlite + sql_driver: sqlite3 model_type: dataclass - emit_classes: true + emit_classes: false diff --git a/test/__init__.py b/test/__init__.py index 1d38848..9fe5f05 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,4 +1,4 @@ # Code generated by sqlc. DO NOT EDIT. # versions: # sqlc v1.28.0 -# sqlc-gen-better-python 0.0.1 +# sqlc-gen-better-python v0.0.1 diff --git a/test/models.py b/test/models.py index 514cb53..bd46a01 100644 --- a/test/models.py +++ b/test/models.py @@ -1,7 +1,7 @@ # Code generated by sqlc. DO NOT EDIT. # versions: # sqlc v1.28.0 -# sqlc-gen-better-python 0.0.1 +# sqlc-gen-better-python v0.0.1 from __future__ import annotations __all__: typing.Sequence[str] = ( diff --git a/test/queries.py b/test/queries.py index fae824a..c6eaa4f 100644 --- a/test/queries.py +++ b/test/queries.py @@ -1,22 +1,28 @@ # Code generated by sqlc. DO NOT EDIT. # versions: # sqlc v1.28.0 -# sqlc-gen-better-python 0.0.1 +# sqlc-gen-better-python v0.0.1 from __future__ import annotations __all__: typing.Sequence[str] = ( "CreateAuthorParams", "GetAuthorRow", - "Queries", "UpdateAuthorParams", "UpdateAuthorTParams", "UpsertAuthorNameParams", + "create_author", + "delete_author", + "get_author", + "list_authors", + "update_author", + "update_author_t", + "upsert_author_name", ) import dataclasses import typing -import aiosqlite +import sqlite3 from test import models @@ -101,43 +107,44 @@ class UpsertAuthorNameParams: """ -class Queries: - __slots__ = ("_conn",) +def create_author(conn: sqlite3.Connection, arg: CreateAuthorParams) -> typing.Optional[models.Author]: + row = conn.execute(CREATE_AUTHOR,(arg.name, arg.bio, )).fetchone() + if row is None: + return None + return models.Author(id=row[0], name=row[1], bio=row[2]) - def __init__(self, conn: aiosqlite.Connection): - self._conn = conn - async def create_author(self, arg: CreateAuthorParams) -> typing.Optional[models.Author]: - row = await (await self._conn.execute(CREATE_AUTHOR, arg.name, arg.bio)).fetchone() - if row is None: - return None - return models.Author(id=row[0], name=row[1], bio=row[2]) +def delete_author(conn: sqlite3.Connection, id: int) -> None: + conn.execute(DELETE_AUTHOR,(id, )) - async def delete_author(self, id: int) -> None: - await self._conn.execute(DELETE_AUTHOR, id) - async def get_author(self, id: int) -> typing.Optional[GetAuthorRow]: - row = await (await self._conn.execute(GET_AUTHOR, id)).fetchone() - if row is None: - return None - return GetAuthorRow(id=row[0], name=row[1]) +def get_author(conn: sqlite3.Connection, id: int) -> typing.Optional[GetAuthorRow]: + row = conn.execute(GET_AUTHOR,(id, )).fetchone() + if row is None: + return None + return GetAuthorRow(id=row[0], name=row[1]) - async def list_authors(self, ids: typing.Sequence[int]) -> typing.AsyncIterator[models.Author]: - stream = await self._conn.execute(LIST_AUTHORS, ids) - async for row in stream: - yield models.Author(id=row[0], name=row[1], bio=row[2]) - async def update_author(self, arg: UpdateAuthorParams) -> None: - await self._conn.execute(UPDATE_AUTHOR, arg.name, arg.bio, arg.id) +def list_authors(conn: sqlite3.Connection, ids: typing.Sequence[int]) -> typing.List[models.Author]: + rows: typing.List[models.Author] = [] + for row in conn.execute(LIST_AUTHORS,(ids, )).fetchall(): + rows.append(models.Author(id=row[0], name=row[1], bio=row[2])) + return rows - async def update_author_t(self, arg: UpdateAuthorTParams) -> typing.Optional[models.Author]: - row = await (await self._conn.execute(UPDATE_AUTHOR_T, arg.name, arg.bio, arg.id)).fetchone() - if row is None: - return None - return models.Author(id=row[0], name=row[1], bio=row[2]) - async def upsert_author_name(self, arg: UpsertAuthorNameParams) -> typing.Optional[models.Author]: - row = await (await self._conn.execute(UPSERT_AUTHOR_NAME, arg.set_name, arg.name)).fetchone() - if row is None: - return None - return models.Author(id=row[0], name=row[1], bio=row[2]) +def update_author(conn: sqlite3.Connection, arg: UpdateAuthorParams) -> None: + conn.execute(UPDATE_AUTHOR,(arg.name, arg.bio, arg.id, )) + + +def update_author_t(conn: sqlite3.Connection, arg: UpdateAuthorTParams) -> typing.Optional[models.Author]: + row = conn.execute(UPDATE_AUTHOR_T,(arg.name, arg.bio, arg.id, )).fetchone() + if row is None: + return None + return models.Author(id=row[0], name=row[1], bio=row[2]) + + +def upsert_author_name(conn: sqlite3.Connection, arg: UpsertAuthorNameParams) -> typing.Optional[models.Author]: + row = conn.execute(UPSERT_AUTHOR_NAME,(arg.set_name, arg.name, )).fetchone() + if row is None: + return None + return models.Author(id=row[0], name=row[1], bio=row[2])