1616
1717from casbin import persist
1818from casbin .persist .adapters .asyncio import AsyncAdapter
19- from sqlalchemy import Column , Integer , String , delete , insert
20- from sqlalchemy import or_
19+ from sqlalchemy import Column , Integer , String , Boolean , delete , insert
20+ from sqlalchemy import or_ , not_
2121from sqlalchemy .ext .asyncio import create_async_engine , AsyncSession
2222from sqlalchemy .future import select
2323from sqlalchemy .orm import declarative_base , sessionmaker
@@ -66,6 +66,7 @@ def __init__(
6666 self ,
6767 engine ,
6868 db_class = None ,
69+ db_class_softdelete_attribute = None ,
6970 filtered = False ,
7071 db_session : Optional [AsyncSession ] = None ,
7172 ):
@@ -74,9 +75,18 @@ def __init__(
7475 else :
7576 self ._engine = engine
7677
78+ self .softdelete_attribute = None
79+
7780 if db_class is None :
7881 db_class = CasbinRule
7982 else :
83+ if db_class_softdelete_attribute is not None and not isinstance (db_class_softdelete_attribute .type , Boolean ):
84+ msg = f"The type of db_class_softdelete_attribute needs to be { str (Boolean )!r} . "
85+ msg += f"An attribute of type { str (type (db_class_softdelete_attribute .type ))!r} was given."
86+ raise ValueError (msg )
87+ # Softdelete is only supported when using custom class
88+ self .softdelete_attribute = db_class_softdelete_attribute
89+
8090 for attr in (
8191 "id" ,
8292 "ptype" ,
@@ -121,7 +131,9 @@ async def create_table(self):
121131 async def load_policy (self , model ):
122132 """loads all policy rules from the storage."""
123133 async with self ._session_scope () as session :
124- lines = await session .execute (select (self ._db_class ))
134+ stmt = select (self ._db_class )
135+ stmt = self ._softdelete_query (stmt )
136+ lines = await session .execute (stmt )
125137 for line in lines .scalars ():
126138 persist .load_policy_line (str (line ), model )
127139
@@ -132,6 +144,7 @@ async def load_filtered_policy(self, model, filter) -> None:
132144 """loads all policy rules from the storage."""
133145 async with self ._session_scope () as session :
134146 stmt = select (self ._db_class )
147+ stmt = self ._softdelete_query (stmt )
135148 stmt = self .filter_query (stmt , filter )
136149 result = await session .execute (stmt )
137150 for line in result .scalars ():
@@ -144,6 +157,12 @@ def filter_query(self, stmt, filter):
144157 stmt = stmt .where (getattr (self ._db_class , attr ).in_ (getattr (filter , attr )))
145158 return stmt .order_by (self ._db_class .id )
146159
160+ def _softdelete_query (self , stmt ):
161+ """Filter out soft-deleted records if soft delete is enabled."""
162+ if self .softdelete_attribute is not None :
163+ stmt = stmt .where (not_ (self .softdelete_attribute ))
164+ return stmt
165+
147166 async def _save_policy_line (self , ptype , rule , session = None ):
148167 if session is not None :
149168 # Use provided session
@@ -161,15 +180,62 @@ async def _save_policy_line(self, ptype, rule, session=None):
161180
162181 async def save_policy (self , model ):
163182 """saves all policy rules to the storage."""
183+ # Use the default strategy when soft delete is not enabled
184+ if self .softdelete_attribute is None :
185+ async with self ._session_scope () as session :
186+ stmt = delete (self ._db_class )
187+ await session .execute (stmt )
188+ for sec in ["p" , "g" ]:
189+ if sec not in model .model .keys ():
190+ continue
191+ for ptype , ast in model .model [sec ].items ():
192+ for rule in ast .policy :
193+ await self ._save_policy_line (ptype , rule , session )
194+ return True
195+
196+ # Custom strategy for softdelete since it does not make sense to recreate all of the
197+ # entries when using soft delete
164198 async with self ._session_scope () as session :
165- stmt = delete (self ._db_class )
166- await session .execute (stmt )
199+ stmt = select (self ._db_class )
200+ stmt = self ._softdelete_query (stmt )
201+
202+ # Get entries that are not part of the model anymore
203+ result = await session .execute (stmt )
204+ lines_before_changes = result .scalars ().all ()
205+
206+ # Create new entries in the database
167207 for sec in ["p" , "g" ]:
168208 if sec not in model .model .keys ():
169209 continue
170210 for ptype , ast in model .model [sec ].items ():
171211 for rule in ast .policy :
172- await self ._save_policy_line (ptype , rule , session )
212+ # Filter for rule in the database
213+ filter_stmt = select (self ._db_class ).where (self ._db_class .ptype == ptype )
214+ filter_stmt = self ._softdelete_query (filter_stmt )
215+ for index , value in enumerate (rule ):
216+ v_value = getattr (self ._db_class , "v{}" .format (index ))
217+ filter_stmt = filter_stmt .where (v_value == value )
218+ # If the rule is not present, create an entry in the database
219+ result = await session .execute (filter_stmt )
220+ if result .scalar_one_or_none () is None :
221+ await self ._save_policy_line (ptype , rule , session = session )
222+
223+ for line in lines_before_changes :
224+ ptype = line .ptype
225+ sec = ptype [0 ] # derived from persist.load_policy_line function
226+ fields_with_None = [
227+ line .v0 ,
228+ line .v1 ,
229+ line .v2 ,
230+ line .v3 ,
231+ line .v4 ,
232+ line .v5 ,
233+ ]
234+ rule = [element for element in fields_with_None if element is not None ]
235+ # If the rule is not part of the model, set the deletion flag to True
236+ if not model .has_policy (sec , ptype , rule ):
237+ setattr (line , self .softdelete_attribute .name , True )
238+
173239 return True
174240
175241 async def add_policy (self , sec , ptype , rule ):
@@ -196,42 +262,75 @@ async def add_policies(self, sec, ptype, rules):
196262 async def remove_policy (self , sec , ptype , rule ):
197263 """removes a policy rule from the storage."""
198264 async with self ._session_scope () as session :
199- stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
200- for i , v in enumerate (rule ):
201- stmt = stmt .where (getattr (self ._db_class , "v{}" .format (i )) == v )
202- r = await session .execute (stmt )
203-
204- return True if r .rowcount > 0 else False
265+ if self .softdelete_attribute is None :
266+ stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
267+ for i , v in enumerate (rule ):
268+ stmt = stmt .where (getattr (self ._db_class , "v{}" .format (i )) == v )
269+ r = await session .execute (stmt )
270+ return True if r .rowcount > 0 else False
271+ else :
272+ stmt = select (self ._db_class ).where (self ._db_class .ptype == ptype )
273+ stmt = self ._softdelete_query (stmt )
274+ for i , v in enumerate (rule ):
275+ stmt = stmt .where (getattr (self ._db_class , "v{}" .format (i )) == v )
276+ result = await session .execute (stmt )
277+ lines = result .scalars ().all ()
278+ for line in lines :
279+ setattr (line , self .softdelete_attribute .name , True )
280+ return True if len (lines ) > 0 else False
205281
206282 async def remove_policies (self , sec , ptype , rules ):
207283 """remove policy rules from the storage."""
208284 if not rules :
209285 return
210286 async with self ._session_scope () as session :
211- stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
212- rules = zip (* rules )
213- for i , rule in enumerate (rules ):
214- stmt = stmt .where (or_ (getattr (self ._db_class , "v{}" .format (i )) == v for v in rule ))
215- await session .execute (stmt )
287+ if self .softdelete_attribute is None :
288+ stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
289+ rules_zipped = zip (* rules )
290+ for i , rule in enumerate (rules_zipped ):
291+ stmt = stmt .where (or_ (getattr (self ._db_class , "v{}" .format (i )) == v for v in rule ))
292+ await session .execute (stmt )
293+ else :
294+ stmt = select (self ._db_class ).where (self ._db_class .ptype == ptype )
295+ stmt = self ._softdelete_query (stmt )
296+ rules_zipped = zip (* rules )
297+ for i , rule in enumerate (rules_zipped ):
298+ stmt = stmt .where (or_ (getattr (self ._db_class , "v{}" .format (i )) == v for v in rule ))
299+ result = await session .execute (stmt )
300+ lines = result .scalars ().all ()
301+ for line in lines :
302+ setattr (line , self .softdelete_attribute .name , True )
216303
217304 async def remove_filtered_policy (self , sec , ptype , field_index , * field_values ):
218305 """removes policy rules that match the filter from the storage.
219306 This is part of the Auto-Save feature.
220307 """
221308 async with self ._session_scope () as session :
222- stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
223-
224309 if not (0 <= field_index <= 5 ):
225310 return False
226311 if not (1 <= field_index + len (field_values ) <= 6 ):
227312 return False
228- for i , v in enumerate (field_values ):
229- if v != "" :
230- v_value = getattr (self ._db_class , "v{}" .format (field_index + i ))
231- stmt = stmt .where (v_value == v )
232- r = await session .execute (stmt )
233313
234- return True if r .rowcount > 0 else False
314+ if self .softdelete_attribute is None :
315+ stmt = delete (self ._db_class ).where (self ._db_class .ptype == ptype )
316+ for i , v in enumerate (field_values ):
317+ if v != "" :
318+ v_value = getattr (self ._db_class , "v{}" .format (field_index + i ))
319+ stmt = stmt .where (v_value == v )
320+ r = await session .execute (stmt )
321+ return True if r .rowcount > 0 else False
322+ else :
323+ stmt = select (self ._db_class ).where (self ._db_class .ptype == ptype )
324+ stmt = self ._softdelete_query (stmt )
325+ for i , v in enumerate (field_values ):
326+ if v != "" :
327+ v_value = getattr (self ._db_class , "v{}" .format (field_index + i ))
328+ stmt = stmt .where (v_value == v )
329+ result = await session .execute (stmt )
330+ lines = result .scalars ().all ()
331+ for line in lines :
332+ setattr (line , self .softdelete_attribute .name , True )
333+ return True if len (lines ) > 0 else False
235334
236335 async def update_policy (self , sec : str , ptype : str , old_rule : List [str ], new_rule : List [str ]) -> None :
237336 """
@@ -247,6 +346,7 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul
247346
248347 async with self ._session_scope () as session :
249348 stmt = select (self ._db_class ).where (self ._db_class .ptype == ptype )
349+ stmt = self ._softdelete_query (stmt )
250350
251351 # locate the old rule
252352 for index , value in enumerate (old_rule ):
@@ -307,9 +407,24 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
307407 # Load old policies
308408
309409 stmt = select (self ._db_class ).where (self ._db_class .ptype == filter .ptype )
410+ stmt = self ._softdelete_query (stmt )
310411 filtered_stmt = self .filter_query (stmt , filter )
311412 result = await session .execute (filtered_stmt )
312- old_rules = result .scalars ().all ()
413+ old_rules_db = result .scalars ().all ()
414+
415+ # Convert database objects to rule lists
416+ old_rules = []
417+ for line in old_rules_db :
418+ fields_with_None = [
419+ line .v0 ,
420+ line .v1 ,
421+ line .v2 ,
422+ line .v3 ,
423+ line .v4 ,
424+ line .v5 ,
425+ ]
426+ rule = [element for element in fields_with_None if element is not None ]
427+ old_rules .append (rule )
313428
314429 # Delete old policies
315430
0 commit comments