Skip to content

Commit 8c105ef

Browse files
authored
feat: add soft deletion support (#20)
1 parent d44ae77 commit 8c105ef

File tree

3 files changed

+554
-26
lines changed

3 files changed

+554
-26
lines changed

README.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,75 @@ async with async_session() as session:
132132
await session.commit()
133133
```
134134

135+
## Soft Deletion Support
136+
137+
The adapter supports soft deletion, which marks records as deleted instead of physically removing them from the database. This is useful for:
138+
139+
- Maintaining audit trails
140+
- Implementing undo functionality
141+
- Preserving historical data
142+
- Debugging and compliance requirements
143+
144+
### Basic Usage with Soft Deletion
145+
146+
To enable soft deletion, you need to:
147+
148+
1. Create a custom database model with a boolean `is_deleted` column
149+
2. Pass the soft delete attribute to the adapter
150+
151+
```python
152+
import casbin_async_sqlalchemy_adapter
153+
import casbin
154+
from sqlalchemy import Column, Boolean, Integer, String
155+
from sqlalchemy.ext.asyncio import create_async_engine
156+
157+
# Define a custom model with soft delete support
158+
class CasbinRuleSoftDelete(casbin_async_sqlalchemy_adapter.Base):
159+
__tablename__ = "casbin_rule"
160+
161+
id = Column(Integer, primary_key=True)
162+
ptype = Column(String(255))
163+
v0 = Column(String(255))
164+
v1 = Column(String(255))
165+
v2 = Column(String(255))
166+
v3 = Column(String(255))
167+
v4 = Column(String(255))
168+
v5 = Column(String(255))
169+
170+
# Add the soft delete column
171+
is_deleted = Column(Boolean, default=False, index=True, nullable=False)
172+
173+
# Create adapter with soft delete support
174+
engine = create_async_engine('sqlite+aiosqlite:///test.db')
175+
adapter = casbin_async_sqlalchemy_adapter.Adapter(
176+
engine,
177+
db_class=CasbinRuleSoftDelete,
178+
db_class_softdelete_attribute=CasbinRuleSoftDelete.is_deleted
179+
)
180+
181+
# Create the table
182+
await adapter.create_table()
183+
184+
e = casbin.AsyncEnforcer('path/to/model.conf', adapter)
185+
186+
# When you delete a policy, it will be soft-deleted (marked as deleted)
187+
await e.delete_permission_for_user("alice", "data1", "read")
188+
189+
# The record remains in the database with is_deleted=True
190+
# Load policy will automatically filter out soft-deleted records
191+
await e.load_policy()
192+
```
193+
194+
### How Soft Deletion Works
195+
196+
When soft deletion is enabled:
197+
198+
- **Delete operations** set the `is_deleted` flag to `True` instead of removing records
199+
- **Load operations** automatically filter out records where `is_deleted=True`
200+
- **Save policy** marks removed rules as deleted while preserving the records
201+
- **Update operations** only affect non-deleted records
202+
203+
This feature maintains full backward compatibility - when `db_class_softdelete_attribute` is not provided, the adapter functions with hard deletion as before.
135204

136205
### Getting Help
137206

casbin_async_sqlalchemy_adapter/adapter.py

Lines changed: 141 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
from casbin import persist
1818
from casbin.persist.adapters.asyncio import AsyncAdapter
19-
from sqlalchemy import Column, Integer, String, delete, insert
20-
from sqlalchemy import or_
19+
from sqlalchemy import Column, Integer, String, Boolean, delete, insert
20+
from sqlalchemy import or_, not_
2121
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
2222
from sqlalchemy.future import select
2323
from sqlalchemy.orm import declarative_base, sessionmaker
@@ -66,6 +66,7 @@ def __init__(
6666
self,
6767
engine,
6868
db_class=None,
69+
db_class_softdelete_attribute=None,
6970
filtered=False,
7071
db_session: Optional[AsyncSession] = None,
7172
):
@@ -74,9 +75,18 @@ def __init__(
7475
else:
7576
self._engine = engine
7677

78+
self.softdelete_attribute = None
79+
7780
if db_class is None:
7881
db_class = CasbinRule
7982
else:
83+
if db_class_softdelete_attribute is not None and not isinstance(db_class_softdelete_attribute.type, Boolean):
84+
msg = f"The type of db_class_softdelete_attribute needs to be {str(Boolean)!r}. "
85+
msg += f"An attribute of type {str(type(db_class_softdelete_attribute.type))!r} was given."
86+
raise ValueError(msg)
87+
# Softdelete is only supported when using custom class
88+
self.softdelete_attribute = db_class_softdelete_attribute
89+
8090
for attr in (
8191
"id",
8292
"ptype",
@@ -121,7 +131,9 @@ async def create_table(self):
121131
async def load_policy(self, model):
122132
"""loads all policy rules from the storage."""
123133
async with self._session_scope() as session:
124-
lines = await session.execute(select(self._db_class))
134+
stmt = select(self._db_class)
135+
stmt = self._softdelete_query(stmt)
136+
lines = await session.execute(stmt)
125137
for line in lines.scalars():
126138
persist.load_policy_line(str(line), model)
127139

@@ -132,6 +144,7 @@ async def load_filtered_policy(self, model, filter) -> None:
132144
"""loads all policy rules from the storage."""
133145
async with self._session_scope() as session:
134146
stmt = select(self._db_class)
147+
stmt = self._softdelete_query(stmt)
135148
stmt = self.filter_query(stmt, filter)
136149
result = await session.execute(stmt)
137150
for line in result.scalars():
@@ -144,6 +157,12 @@ def filter_query(self, stmt, filter):
144157
stmt = stmt.where(getattr(self._db_class, attr).in_(getattr(filter, attr)))
145158
return stmt.order_by(self._db_class.id)
146159

160+
def _softdelete_query(self, stmt):
161+
"""Filter out soft-deleted records if soft delete is enabled."""
162+
if self.softdelete_attribute is not None:
163+
stmt = stmt.where(not_(self.softdelete_attribute))
164+
return stmt
165+
147166
async def _save_policy_line(self, ptype, rule, session=None):
148167
if session is not None:
149168
# Use provided session
@@ -161,15 +180,62 @@ async def _save_policy_line(self, ptype, rule, session=None):
161180

162181
async def save_policy(self, model):
163182
"""saves all policy rules to the storage."""
183+
# Use the default strategy when soft delete is not enabled
184+
if self.softdelete_attribute is None:
185+
async with self._session_scope() as session:
186+
stmt = delete(self._db_class)
187+
await session.execute(stmt)
188+
for sec in ["p", "g"]:
189+
if sec not in model.model.keys():
190+
continue
191+
for ptype, ast in model.model[sec].items():
192+
for rule in ast.policy:
193+
await self._save_policy_line(ptype, rule, session)
194+
return True
195+
196+
# Custom strategy for softdelete since it does not make sense to recreate all of the
197+
# entries when using soft delete
164198
async with self._session_scope() as session:
165-
stmt = delete(self._db_class)
166-
await session.execute(stmt)
199+
stmt = select(self._db_class)
200+
stmt = self._softdelete_query(stmt)
201+
202+
# Get entries that are not part of the model anymore
203+
result = await session.execute(stmt)
204+
lines_before_changes = result.scalars().all()
205+
206+
# Create new entries in the database
167207
for sec in ["p", "g"]:
168208
if sec not in model.model.keys():
169209
continue
170210
for ptype, ast in model.model[sec].items():
171211
for rule in ast.policy:
172-
await self._save_policy_line(ptype, rule, session)
212+
# Filter for rule in the database
213+
filter_stmt = select(self._db_class).where(self._db_class.ptype == ptype)
214+
filter_stmt = self._softdelete_query(filter_stmt)
215+
for index, value in enumerate(rule):
216+
v_value = getattr(self._db_class, "v{}".format(index))
217+
filter_stmt = filter_stmt.where(v_value == value)
218+
# If the rule is not present, create an entry in the database
219+
result = await session.execute(filter_stmt)
220+
if result.scalar_one_or_none() is None:
221+
await self._save_policy_line(ptype, rule, session=session)
222+
223+
for line in lines_before_changes:
224+
ptype = line.ptype
225+
sec = ptype[0] # derived from persist.load_policy_line function
226+
fields_with_None = [
227+
line.v0,
228+
line.v1,
229+
line.v2,
230+
line.v3,
231+
line.v4,
232+
line.v5,
233+
]
234+
rule = [element for element in fields_with_None if element is not None]
235+
# If the rule is not part of the model, set the deletion flag to True
236+
if not model.has_policy(sec, ptype, rule):
237+
setattr(line, self.softdelete_attribute.name, True)
238+
173239
return True
174240

175241
async def add_policy(self, sec, ptype, rule):
@@ -196,42 +262,75 @@ async def add_policies(self, sec, ptype, rules):
196262
async def remove_policy(self, sec, ptype, rule):
197263
"""removes a policy rule from the storage."""
198264
async with self._session_scope() as session:
199-
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
200-
for i, v in enumerate(rule):
201-
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
202-
r = await session.execute(stmt)
203-
204-
return True if r.rowcount > 0 else False
265+
if self.softdelete_attribute is None:
266+
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
267+
for i, v in enumerate(rule):
268+
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
269+
r = await session.execute(stmt)
270+
return True if r.rowcount > 0 else False
271+
else:
272+
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
273+
stmt = self._softdelete_query(stmt)
274+
for i, v in enumerate(rule):
275+
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
276+
result = await session.execute(stmt)
277+
lines = result.scalars().all()
278+
for line in lines:
279+
setattr(line, self.softdelete_attribute.name, True)
280+
return True if len(lines) > 0 else False
205281

206282
async def remove_policies(self, sec, ptype, rules):
207283
"""remove policy rules from the storage."""
208284
if not rules:
209285
return
210286
async with self._session_scope() as session:
211-
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
212-
rules = zip(*rules)
213-
for i, rule in enumerate(rules):
214-
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
215-
await session.execute(stmt)
287+
if self.softdelete_attribute is None:
288+
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
289+
rules_zipped = zip(*rules)
290+
for i, rule in enumerate(rules_zipped):
291+
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
292+
await session.execute(stmt)
293+
else:
294+
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
295+
stmt = self._softdelete_query(stmt)
296+
rules_zipped = zip(*rules)
297+
for i, rule in enumerate(rules_zipped):
298+
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
299+
result = await session.execute(stmt)
300+
lines = result.scalars().all()
301+
for line in lines:
302+
setattr(line, self.softdelete_attribute.name, True)
216303

217304
async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
218305
"""removes policy rules that match the filter from the storage.
219306
This is part of the Auto-Save feature.
220307
"""
221308
async with self._session_scope() as session:
222-
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
223-
224309
if not (0 <= field_index <= 5):
225310
return False
226311
if not (1 <= field_index + len(field_values) <= 6):
227312
return False
228-
for i, v in enumerate(field_values):
229-
if v != "":
230-
v_value = getattr(self._db_class, "v{}".format(field_index + i))
231-
stmt = stmt.where(v_value == v)
232-
r = await session.execute(stmt)
233313

234-
return True if r.rowcount > 0 else False
314+
if self.softdelete_attribute is None:
315+
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
316+
for i, v in enumerate(field_values):
317+
if v != "":
318+
v_value = getattr(self._db_class, "v{}".format(field_index + i))
319+
stmt = stmt.where(v_value == v)
320+
r = await session.execute(stmt)
321+
return True if r.rowcount > 0 else False
322+
else:
323+
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
324+
stmt = self._softdelete_query(stmt)
325+
for i, v in enumerate(field_values):
326+
if v != "":
327+
v_value = getattr(self._db_class, "v{}".format(field_index + i))
328+
stmt = stmt.where(v_value == v)
329+
result = await session.execute(stmt)
330+
lines = result.scalars().all()
331+
for line in lines:
332+
setattr(line, self.softdelete_attribute.name, True)
333+
return True if len(lines) > 0 else False
235334

236335
async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]) -> None:
237336
"""
@@ -247,6 +346,7 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul
247346

248347
async with self._session_scope() as session:
249348
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
349+
stmt = self._softdelete_query(stmt)
250350

251351
# locate the old rule
252352
for index, value in enumerate(old_rule):
@@ -307,9 +407,24 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
307407
# Load old policies
308408

309409
stmt = select(self._db_class).where(self._db_class.ptype == filter.ptype)
410+
stmt = self._softdelete_query(stmt)
310411
filtered_stmt = self.filter_query(stmt, filter)
311412
result = await session.execute(filtered_stmt)
312-
old_rules = result.scalars().all()
413+
old_rules_db = result.scalars().all()
414+
415+
# Convert database objects to rule lists
416+
old_rules = []
417+
for line in old_rules_db:
418+
fields_with_None = [
419+
line.v0,
420+
line.v1,
421+
line.v2,
422+
line.v3,
423+
line.v4,
424+
line.v5,
425+
]
426+
rule = [element for element in fields_with_None if element is not None]
427+
old_rules.append(rule)
313428

314429
# Delete old policies
315430

0 commit comments

Comments
 (0)