Skip to content

Commit 256bcbd

Browse files
committed
device: add support for removing allowedips individually
This pairs with the recent change in wireguard-tools. Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 1571e0f commit 256bcbd

File tree

3 files changed

+125
-34
lines changed

3 files changed

+125
-34
lines changed

device/allowedips.go

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -223,45 +223,68 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
223223
}
224224
}
225225

226+
func (node *trieEntry) remove() {
227+
node.removeFromPeerEntries()
228+
node.peer = nil
229+
if node.child[0] != nil && node.child[1] != nil {
230+
return
231+
}
232+
bit := 0
233+
if node.child[0] == nil {
234+
bit = 1
235+
}
236+
child := node.child[bit]
237+
if child != nil {
238+
child.parent = node.parent
239+
}
240+
*node.parent.parentBit = child
241+
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
242+
node.zeroizePointers()
243+
return
244+
}
245+
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
246+
if parent.peer != nil {
247+
node.zeroizePointers()
248+
return
249+
}
250+
child = parent.child[node.parent.parentBitType^1]
251+
if child != nil {
252+
child.parent = parent.parent
253+
}
254+
*parent.parent.parentBit = child
255+
node.zeroizePointers()
256+
parent.zeroizePointers()
257+
}
258+
259+
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
260+
table.mutex.Lock()
261+
defer table.mutex.Unlock()
262+
var node *trieEntry
263+
var exact bool
264+
265+
if prefix.Addr().Is6() {
266+
ip := prefix.Addr().As16()
267+
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
268+
} else if prefix.Addr().Is4() {
269+
ip := prefix.Addr().As4()
270+
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
271+
} else {
272+
panic(errors.New("removing unknown address type"))
273+
}
274+
if !exact || node == nil || peer != node.peer {
275+
return
276+
}
277+
node.remove()
278+
}
279+
226280
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
227281
table.mutex.Lock()
228282
defer table.mutex.Unlock()
229283

230284
var next *list.Element
231285
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
232286
next = elem.Next()
233-
node := elem.Value.(*trieEntry)
234-
235-
node.removeFromPeerEntries()
236-
node.peer = nil
237-
if node.child[0] != nil && node.child[1] != nil {
238-
continue
239-
}
240-
bit := 0
241-
if node.child[0] == nil {
242-
bit = 1
243-
}
244-
child := node.child[bit]
245-
if child != nil {
246-
child.parent = node.parent
247-
}
248-
*node.parent.parentBit = child
249-
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
250-
node.zeroizePointers()
251-
continue
252-
}
253-
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
254-
if parent.peer != nil {
255-
node.zeroizePointers()
256-
continue
257-
}
258-
child = parent.child[node.parent.parentBitType^1]
259-
if child != nil {
260-
child.parent = parent.parent
261-
}
262-
*parent.parent.parentBit = child
263-
node.zeroizePointers()
264-
parent.zeroizePointers()
287+
elem.Value.(*trieEntry).remove()
265288
}
266289
}
267290

device/allowedips_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
101101
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
102102
}
103103

104+
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
105+
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
106+
}
107+
104108
assertEQ := func(peer *Peer, a, b, c, d byte) {
105109
p := allowedIPs.Lookup([]byte{a, b, c, d})
106110
if p != peer {
@@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
176180
allowedIPs.RemoveByPeer(a)
177181

178182
assertNEQ(a, 192, 168, 0, 1)
183+
184+
insert(a, 1, 0, 0, 0, 32)
185+
insert(a, 192, 0, 0, 0, 24)
186+
assertEQ(a, 1, 0, 0, 0)
187+
assertEQ(a, 192, 0, 0, 1)
188+
remove(a, 192, 0, 0, 0, 32)
189+
assertEQ(a, 192, 0, 0, 1)
190+
remove(nil, 192, 0, 0, 0, 24)
191+
assertEQ(a, 192, 0, 0, 1)
192+
remove(b, 192, 0, 0, 0, 24)
193+
assertEQ(a, 192, 0, 0, 1)
194+
remove(a, 192, 0, 0, 0, 24)
195+
assertNEQ(a, 192, 0, 0, 1)
196+
remove(a, 1, 0, 0, 0, 32)
197+
assertNEQ(a, 1, 0, 0, 0)
179198
}
180199

181200
/* Test ported from kernel implementation:
@@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
211230
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
212231
}
213232

233+
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
234+
var addr []byte
235+
addr = append(addr, expand(a)...)
236+
addr = append(addr, expand(b)...)
237+
addr = append(addr, expand(c)...)
238+
addr = append(addr, expand(d)...)
239+
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
240+
}
241+
214242
assertEQ := func(peer *Peer, a, b, c, d uint32) {
215243
var addr []byte
216244
addr = append(addr, expand(a)...)
@@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
223251
}
224252
}
225253

254+
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
255+
var addr []byte
256+
addr = append(addr, expand(a)...)
257+
addr = append(addr, expand(b)...)
258+
addr = append(addr, expand(c)...)
259+
addr = append(addr, expand(d)...)
260+
p := allowedIPs.Lookup(addr)
261+
if p == peer {
262+
t.Error("Assert NEQ failed")
263+
}
264+
}
265+
226266
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
227267
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
228268
insert(e, 0, 0, 0, 0, 0)
@@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
244284
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
245285
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
246286
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
287+
288+
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
289+
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
290+
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
291+
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
292+
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
293+
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
294+
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
295+
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
296+
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
297+
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
298+
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
299+
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
300+
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
301+
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
302+
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
303+
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
247304
}

device/uapi.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,15 +371,26 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
371371
device.allowedips.RemoveByPeer(peer.Peer)
372372

373373
case "allowed_ip":
374-
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
374+
add := true
375+
verb := "Adding"
376+
if len(value) > 0 && value[0] == '-' {
377+
add = false
378+
verb = "Removing"
379+
value = value[1:]
380+
}
381+
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
375382
prefix, err := netip.ParsePrefix(value)
376383
if err != nil {
377384
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
378385
}
379386
if peer.dummy {
380387
return nil
381388
}
382-
device.allowedips.Insert(prefix, peer.Peer)
389+
if add {
390+
device.allowedips.Insert(prefix, peer.Peer)
391+
} else {
392+
device.allowedips.Remove(prefix, peer.Peer)
393+
}
383394

384395
case "protocol_version":
385396
if value != "1" {

0 commit comments

Comments
 (0)