Skip to content

Commit 46ba668

Browse files
committed
fix(policy): synchronize policy_map updates in add, update, and remove operations
1 parent 963115b commit 46ba668

File tree

1 file changed

+53
-36
lines changed

1 file changed

+53
-36
lines changed

casbin/model/policy.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,8 @@ def items(self):
3939

4040
def build_role_links(self, rm_map):
4141
"""initializes the roles in RBAC."""
42-
4342
if "g" not in self.keys():
4443
return
45-
4644
for ptype, ast in self["g"].items():
4745
rm = rm_map.get(ptype)
4846
if rm:
@@ -68,28 +66,23 @@ def build_conditional_role_links(self, cond_rm_map):
6866

6967
def print_policy(self):
7068
"""Log using info"""
71-
7269
self.logger.info("Policy:")
7370
for sec in ["p", "g"]:
7471
if sec not in self.keys():
7572
continue
76-
7773
for key, ast in self[sec].items():
7874
self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy))
7975

8076
def clear_policy(self):
8177
"""clears all current policy."""
82-
8378
for sec in ["p", "g"]:
8479
if sec not in self.keys():
8580
continue
86-
8781
for key in self[sec].keys():
8882
self[sec][key].policy = []
8983

9084
def get_policy(self, sec, ptype):
9185
"""gets all rules in a policy."""
92-
9386
return self[sec][ptype].policy
9487

9588
def get_filtered_policy(self, sec, ptype, field_index, *field_values):
@@ -109,7 +102,6 @@ def has_policy(self, sec, ptype, rule):
109102
return False
110103
if ptype not in self[sec]:
111104
return False
112-
113105
return rule in self[sec][ptype].policy
114106

115107
def add_policy(self, sec, ptype, rule):
@@ -123,23 +115,19 @@ def add_policy(self, sec, ptype, rule):
123115
if sec == "p" and assertion.priority_index >= 0:
124116
try:
125117
idx_insert = int(rule[assertion.priority_index])
126-
127118
i = len(assertion.policy) - 1
128119
for i in range(i, 0, -1):
129120
try:
130121
idx = int(assertion.policy[i - 1][assertion.priority_index])
131122
except Exception as e:
132123
print(e)
133-
134124
if idx > idx_insert:
135125
tmp = assertion.policy[i]
136126
assertion.policy[i] = assertion.policy[i - 1]
137127
assertion.policy[i - 1] = tmp
138128
else:
139129
break
140-
141130
assertion.policy_map[DEFAULT_SEP.join(rule)] = i
142-
143131
except Exception as e:
144132
print(e)
145133

@@ -148,19 +136,16 @@ def add_policy(self, sec, ptype, rule):
148136

149137
def add_policies(self, sec, ptype, rules):
150138
"""adds policy rules to the model."""
151-
152139
for rule in rules:
153140
if self.has_policy(sec, ptype, rule):
154141
return False
155-
156142
for rule in rules:
157-
self[sec][ptype].policy.append(rule)
158-
143+
if not self.add_policy(sec, ptype, rule):
144+
return False
159145
return True
160146

161147
def update_policy(self, sec, ptype, old_rule, new_rule):
162148
"""update a policy rule from the model."""
163-
164149
if sec not in self.keys():
165150
return False
166151
if ptype not in self[sec]:
@@ -175,18 +160,21 @@ def update_policy(self, sec, ptype, old_rule, new_rule):
175160

176161
if "p_priority" in ast.tokens:
177162
priority_index = ast.tokens.index("p_priority")
178-
if old_rule[priority_index] == new_rule[priority_index]:
179-
ast.policy[rule_index] = new_rule
180-
else:
163+
if old_rule[priority_index] != new_rule[priority_index]:
181164
raise Exception("New rule should have the same priority with old rule.")
182-
else:
183-
ast.policy[rule_index] = new_rule
165+
# 替换列表中的规则
166+
ast.policy[rule_index] = new_rule
167+
# 更新映射:删除旧键,添加新键
168+
old_key = DEFAULT_SEP.join(old_rule)
169+
new_key = DEFAULT_SEP.join(new_rule)
170+
if old_key in ast.policy_map:
171+
del ast.policy_map[old_key]
172+
ast.policy_map[new_key] = rule_index
184173

185174
return True
186175

187176
def update_policies(self, sec, ptype, old_rules, new_rules):
188177
"""update policy rules from the model."""
189-
190178
if sec not in self.keys():
191179
return False
192180
if ptype not in self[sec]:
@@ -206,13 +194,22 @@ def update_policies(self, sec, ptype, old_rules, new_rules):
206194
if "p_priority" in ast.tokens:
207195
priority_index = ast.tokens.index("p_priority")
208196
for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules):
209-
if old_rule[priority_index] == new_rule[priority_index]:
210-
ast.policy[idx] = new_rule
211-
else:
197+
if old_rule[priority_index] != new_rule[priority_index]:
212198
raise Exception("New rule should have the same priority with old rule.")
199+
ast.policy[idx] = new_rule
200+
old_key = DEFAULT_SEP.join(old_rule)
201+
new_key = DEFAULT_SEP.join(new_rule)
202+
if old_key in ast.policy_map:
203+
del ast.policy_map[old_key]
204+
ast.policy_map[new_key] = idx
213205
else:
214206
for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules):
215207
ast.policy[idx] = new_rule
208+
old_key = DEFAULT_SEP.join(old_rule)
209+
new_key = DEFAULT_SEP.join(new_rule)
210+
if old_key in ast.policy_map:
211+
del ast.policy_map[old_key]
212+
ast.policy_map[new_key] = idx
216213

217214
return True
218215

@@ -221,19 +218,30 @@ def remove_policy(self, sec, ptype, rule):
221218
if not self.has_policy(sec, ptype, rule):
222219
return False
223220

224-
self[sec][ptype].policy.remove(rule)
221+
assertion = self[sec][ptype]
222+
assertion.policy.remove(rule)
223+
# 重新构建映射
224+
new_map = {}
225+
for idx, r in enumerate(assertion.policy):
226+
new_map[DEFAULT_SEP.join(r)] = idx
227+
assertion.policy_map = new_map
225228

226-
return rule not in self[sec][ptype].policy
229+
return rule not in assertion.policy
227230

228231
def remove_policies(self, sec, ptype, rules):
229232
"""RemovePolicies removes policy rules from the model."""
230-
233+
assertion = self[sec][ptype]
231234
for rule in rules:
232235
if not self.has_policy(sec, ptype, rule):
233236
return False
234-
self[sec][ptype].policy.remove(rule)
235-
if rule in self[sec][ptype].policy:
237+
assertion.policy.remove(rule)
238+
if rule in assertion.policy:
236239
return False
240+
# 重新构建映射
241+
new_map = {}
242+
for idx, r in enumerate(assertion.policy):
243+
new_map[DEFAULT_SEP.join(r)] = idx
244+
assertion.policy_map = new_map
237245

238246
return True
239247

@@ -243,7 +251,6 @@ def remove_policies_with_effected(self, sec, ptype, rules):
243251
if self.has_policy(sec, ptype, rule):
244252
effected.append(rule)
245253
self.remove_policy(sec, ptype, rule)
246-
247254
return effected
248255

249256
def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values):
@@ -266,7 +273,13 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field
266273
else:
267274
tmp.append(rule)
268275

269-
self[sec][ptype].policy = tmp
276+
assertion = self[sec][ptype]
277+
assertion.policy = tmp
278+
# 重新构建映射
279+
new_map = {}
280+
for idx, r in enumerate(assertion.policy):
281+
new_map[DEFAULT_SEP.join(r)] = idx
282+
assertion.policy_map = new_map
270283

271284
return effects
272285

@@ -286,7 +299,13 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
286299
else:
287300
tmp.append(rule)
288301

289-
self[sec][ptype].policy = tmp
302+
assertion = self[sec][ptype]
303+
assertion.policy = tmp
304+
# 重新构建映射
305+
new_map = {}
306+
for idx, r in enumerate(assertion.policy):
307+
new_map[DEFAULT_SEP.join(r)] = idx
308+
assertion.policy_map = new_map
290309

291310
return res
292311

@@ -297,10 +316,8 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index):
297316
return values
298317
if ptype not in self[sec]:
299318
return values
300-
301319
for rule in self[sec][ptype].policy:
302320
value = rule[field_index]
303321
if value not in values:
304322
values.append(value)
305-
306323
return values

0 commit comments

Comments
 (0)