Skip to content

Commit 0d2b660

Browse files
authored
feat: add custom table names for CasbinRule class and update tests
1 parent 9bb1a87 commit 0d2b660

File tree

3 files changed

+91
-25
lines changed

3 files changed

+91
-25
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ pip install sqlalchemy_adapter
2929

3030
## Simple Example
3131

32+
You can save and load policy to database.
33+
3234
```python
3335
import sqlalchemy_adapter
3436
import casbin
@@ -49,6 +51,22 @@ else:
4951
pass
5052
```
5153

54+
By default, policies are stored in the `casbin_rule` table.
55+
You can custom the table where the policy is stored by using the `table_name` parameter.
56+
57+
```python
58+
59+
import sqlalchemy_adapter
60+
import casbin
61+
62+
custom_table_name = "<custom_table_name>"
63+
64+
# create adapter with custom table name.
65+
adapter = sqlalchemy_adapter.Adapter('sqlite:///test.db', table_name=custom_table_name)
66+
67+
e = casbin.Enforcer('path/to/model.conf', adapter)
68+
```
69+
5270

5371
### Getting Help
5472

sqlalchemy_adapter/adapter.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,46 @@ class Base(DeclarativeBase):
1919
pass
2020

2121

22-
class CasbinRule(Base):
23-
__tablename__ = "casbin_rule"
24-
25-
id = Column(Integer, primary_key=True)
26-
ptype = Column(String(255))
27-
v0 = Column(String(255))
28-
v1 = Column(String(255))
29-
v2 = Column(String(255))
30-
v3 = Column(String(255))
31-
v4 = Column(String(255))
32-
v5 = Column(String(255))
33-
34-
def __str__(self):
35-
arr = [self.ptype]
36-
for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5):
37-
if v is None:
38-
break
39-
arr.append(v)
40-
return ", ".join(arr)
22+
def create_casbin_rule_class(table_name):
23+
"""
24+
Factory function to create a CasbinRule class with a custom table name.
25+
26+
Args:
27+
table_name (str): Table name for the CasbinRule class.
28+
29+
Returns:
30+
db_class (CasbinRule): The CasbinRule class.
31+
"""
32+
33+
class CasbinRule(Base):
34+
__tablename__ = table_name
35+
__table_args__ = {"extend_existing": True}
36+
37+
id = Column(Integer, primary_key=True)
38+
ptype = Column(String(255))
39+
v0 = Column(String(255))
40+
v1 = Column(String(255))
41+
v2 = Column(String(255))
42+
v3 = Column(String(255))
43+
v4 = Column(String(255))
44+
v5 = Column(String(255))
45+
46+
def __str__(self):
47+
arr = [self.ptype]
48+
for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5):
49+
if v is None:
50+
break
51+
arr.append(v)
52+
return ", ".join(arr)
4153

42-
def __repr__(self):
43-
return '<CasbinRule {}: "{}">'.format(self.id, str(self))
54+
def __repr__(self):
55+
return '<CasbinRule {}: "{}">'.format(self.id, str(self))
56+
57+
return CasbinRule
58+
59+
60+
# Export the default CasbinRule class with table name 'casbin_rule'.
61+
CasbinRule = create_casbin_rule_class("casbin_rule")
4462

4563

4664
class Filter:
@@ -56,14 +74,20 @@ class Filter:
5674
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
5775
"""the interface for Casbin adapters."""
5876

59-
def __init__(self, engine, db_class=None, filtered=False):
77+
def __init__(
78+
self,
79+
engine,
80+
db_class=None,
81+
table_name="casbin_rule",
82+
filtered=False,
83+
):
6084
if isinstance(engine, str):
6185
self._engine = create_engine(engine)
6286
else:
6387
self._engine = engine
6488

6589
if db_class is None:
66-
db_class = CasbinRule
90+
db_class = create_casbin_rule_class(table_name=table_name)
6791
else:
6892
for attr in (
6993
"id",
@@ -281,7 +305,6 @@ def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
281305
"""_update_filtered_policies updates all the policies on the basis of the filter."""
282306

283307
with self._session_scope() as session:
284-
285308
# Load old policies
286309

287310
query = session.query(self._db_class).filter(

tests/test_adapter.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlalchemy_adapter import Adapter
99
from sqlalchemy_adapter import Base
1010
from sqlalchemy_adapter import CasbinRule
11-
from sqlalchemy_adapter.adapter import Filter
11+
from sqlalchemy_adapter.adapter import Filter, create_casbin_rule_class
1212

1313

1414
def get_fixture(path):
@@ -36,6 +36,25 @@ def get_enforcer():
3636
return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)
3737

3838

39+
def get_custom_table_name_enforcer():
40+
engine = create_engine("sqlite://")
41+
table_name = "custom_casbin_rule_table"
42+
adapter = Adapter(engine, table_name=table_name)
43+
44+
session = sessionmaker(bind=engine)
45+
Base.metadata.create_all(engine)
46+
s = session()
47+
48+
CustomTableCasbinRule = create_casbin_rule_class(table_name)
49+
50+
s.query(CustomTableCasbinRule).delete()
51+
s.add(CustomTableCasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
52+
s.commit()
53+
s.close()
54+
55+
return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)
56+
57+
3958
class TestConfig(TestCase):
4059
def test_custom_db_class(self):
4160
class CustomRule(Base):
@@ -61,6 +80,12 @@ class CustomRule(Base):
6180
s.commit()
6281
self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone")
6382

83+
def test_custom_table_name(self):
84+
e = get_custom_table_name_enforcer()
85+
86+
self.assertTrue(e.enforce("alice", "data1", "read"))
87+
self.assertFalse(e.enforce("bob", "data2", "write"))
88+
6489
def test_enforcer_basic(self):
6590
e = get_enforcer()
6691

0 commit comments

Comments
 (0)