diff --git a/casbin_async_sqlalchemy_adapter/__init__.py b/casbin_async_sqlalchemy_adapter/__init__.py index e66fe6e..e5b3515 100644 --- a/casbin_async_sqlalchemy_adapter/__init__.py +++ b/casbin_async_sqlalchemy_adapter/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .adapter import CasbinRule, Adapter, Base +from .adapter import CasbinRule, Adapter, Base, create_casbin_rule_model diff --git a/casbin_async_sqlalchemy_adapter/adapter.py b/casbin_async_sqlalchemy_adapter/adapter.py index 4af3e7c..66a8286 100644 --- a/casbin_async_sqlalchemy_adapter/adapter.py +++ b/casbin_async_sqlalchemy_adapter/adapter.py @@ -49,6 +49,36 @@ def __repr__(self): return ''.format(self.id, str(self)) +def create_casbin_rule_model(base, table_name="casbin_rule"): + """Create a CasbinRule model using the given declarative base for Alembic integration.""" + + class CasbinRuleModel(base): + __tablename__ = table_name + __table_args__ = {"extend_existing": True} + + id = Column(Integer, primary_key=True) + ptype = Column(String(255)) + v0 = Column(String(255)) + v1 = Column(String(255)) + v2 = Column(String(255)) + v3 = Column(String(255)) + v4 = Column(String(255)) + v5 = Column(String(255)) + + def __str__(self): + arr = [self.ptype] + for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5): + if v is None: + break + arr.append(v) + return ", ".join(arr) + + def __repr__(self): + return ''.format(self.id, str(self)) + + return CasbinRuleModel + + class Filter: ptype = [] v0 = []