Browse Source

conn: new package that splits out the Bind and Endpoint types

The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
David Crawshaw 6 months ago
parent
commit
c4a8eab3dd

device/boundif_windows.go → conn/boundif_windows.go View File

@@ -3,11 +3,10 @@
3 3
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4 4
  */
5 5
 
6
-package device
6
+package conn
7 7
 
8 8
 import (
9 9
 	"encoding/binary"
10
-	"errors"
11 10
 	"unsafe"
12 11
 
13 12
 	"golang.org/x/sys/windows"
@@ -18,17 +17,13 @@ const (
18 17
 	sockoptIPV6_UNICAST_IF = 31
19 18
 )
20 19
 
21
-func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
20
+func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
22 21
 	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
23 22
 	bytes := make([]byte, 4)
24 23
 	binary.BigEndian.PutUint32(bytes, interfaceIndex)
25 24
 	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
26 25
 
27
-	if device.net.bind == nil {
28
-		return errors.New("Bind is not yet initialized")
29
-	}
30
-
31
-	sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
26
+	sysconn, err := bind.ipv4.SyscallConn()
32 27
 	if err != nil {
33 28
 		return err
34 29
 	}
@@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
41 36
 	if err != nil {
42 37
 		return err
43 38
 	}
44
-	device.net.bind.(*nativeBind).blackhole4 = blackhole
39
+	bind.blackhole4 = blackhole
45 40
 	return nil
46 41
 }
47 42
 
48
-func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
49
-	sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
43
+func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
44
+	sysconn, err := bind.ipv6.SyscallConn()
50 45
 	if err != nil {
51 46
 		return err
52 47
 	}
@@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
59 54
 	if err != nil {
60 55
 		return err
61 56
 	}
62
-	device.net.bind.(*nativeBind).blackhole6 = blackhole
57
+	bind.blackhole6 = blackhole
63 58
 	return nil
64 59
 }

+ 101
- 0
conn/conn.go View File

@@ -0,0 +1,101 @@
1
+/* SPDX-License-Identifier: MIT
2
+ *
3
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4
+ */
5
+
6
+// Package conn implements WireGuard's network connections.
7
+package conn
8
+
9
+import (
10
+	"errors"
11
+	"net"
12
+	"strings"
13
+)
14
+
15
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
16
+type Bind interface {
17
+	// LastMark reports the last mark set for this Bind.
18
+	LastMark() uint32
19
+
20
+	// SetMark sets the mark for each packet sent through this Bind.
21
+	// This mark is passed to the kernel as the socket option SO_MARK.
22
+	SetMark(mark uint32) error
23
+
24
+	// ReceiveIPv6 reads an IPv6 UDP packet into b.
25
+	//
26
+	// It reports the number of bytes read, n,
27
+	// the packet source address ep,
28
+	// and any error.
29
+	ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
30
+
31
+	// ReceiveIPv4 reads an IPv4 UDP packet into b.
32
+	//
33
+	// It reports the number of bytes read, n,
34
+	// the packet source address ep,
35
+	// and any error.
36
+	ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
37
+
38
+	// Send writes a packet b to address ep.
39
+	Send(b []byte, ep Endpoint) error
40
+
41
+	// Close closes the Bind connection.
42
+	Close() error
43
+}
44
+
45
+// CreateBind creates a Bind bound to a port.
46
+//
47
+// The value actualPort reports the actual port number the Bind
48
+// object gets bound to.
49
+func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
50
+	return createBind(port)
51
+}
52
+
53
+// BindToInterface is implemented by Bind objects that support being
54
+// tied to a single network interface.
55
+type BindToInterface interface {
56
+	BindToInterface4(interfaceIndex uint32, blackhole bool) error
57
+	BindToInterface6(interfaceIndex uint32, blackhole bool) error
58
+}
59
+
60
+// An Endpoint maintains the source/destination caching for a peer.
61
+//
62
+//	dst : the remote address of a peer ("endpoint" in uapi terminology)
63
+//	src : the local address from which datagrams originate going to the peer
64
+type Endpoint interface {
65
+	ClearSrc()           // clears the source address
66
+	SrcToString() string // returns the local source address (ip:port)
67
+	DstToString() string // returns the destination address (ip:port)
68
+	DstToBytes() []byte  // used for mac2 cookie calculations
69
+	DstIP() net.IP
70
+	SrcIP() net.IP
71
+}
72
+
73
+func parseEndpoint(s string) (*net.UDPAddr, error) {
74
+	// ensure that the host is an IP address
75
+
76
+	host, _, err := net.SplitHostPort(s)
77
+	if err != nil {
78
+		return nil, err
79
+	}
80
+	if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
81
+		// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
82
+		// trying to make sure with a small sanity test that this is a real IP address and
83
+		// not something that's likely to incur DNS lookups.
84
+		host = host[:i]
85
+	}
86
+	if ip := net.ParseIP(host); ip == nil {
87
+		return nil, errors.New("Failed to parse IP address: " + host)
88
+	}
89
+
90
+	// parse address and port
91
+
92
+	addr, err := net.ResolveUDPAddr("udp", s)
93
+	if err != nil {
94
+		return nil, err
95
+	}
96
+	ip4 := addr.IP.To4()
97
+	if ip4 != nil {
98
+		addr.IP = ip4
99
+	}
100
+	return addr, err
101
+}

device/conn_default.go → conn/conn_default.go View File

@@ -5,7 +5,7 @@
5 5
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
6 6
  */
7 7
 
8
-package device
8
+package conn
9 9
 
10 10
 import (
11 11
 	"net"
@@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string {
67 67
 }
68 68
 
69 69
 func listenNet(network string, port int) (*net.UDPConn, int, error) {
70
-
71
-	// listen
72
-
73 70
 	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
74 71
 	if err != nil {
75 72
 		return nil, 0, err
76 73
 	}
77 74
 
78
-	// retrieve port
79
-
75
+	// Retrieve port.
76
+	// TODO(crawshaw): under what circumstances is this necessary?
80 77
 	laddr := conn.LocalAddr()
81 78
 	uaddr, err := net.ResolveUDPAddr(
82 79
 		laddr.Network(),
@@ -100,7 +97,7 @@ func extractErrno(err error) error {
100 97
 	return syscallErr.Err
101 98
 }
102 99
 
103
-func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
100
+func createBind(uport uint16) (Bind, uint16, error) {
104 101
 	var err error
105 102
 	var bind nativeBind
106 103
 
@@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error {
135 132
 	return err2
136 133
 }
137 134
 
135
+func (bind *nativeBind) LastMark() uint32 { return 0 }
136
+
138 137
 func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
139 138
 	if bind.ipv4 == nil {
140 139
 		return 0, nil, syscall.EAFNOSUPPORT

device/conn_linux.go → conn/conn_linux.go View File

@@ -3,18 +3,9 @@
3 3
 /* SPDX-License-Identifier: MIT
4 4
  *
5 5
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
6
- *
7
- * This implements userspace semantics of "sticky sockets", modeled after
8
- * WireGuard's kernelspace implementation. This is more or less a straight port
9
- * of the sticky-sockets.c example code:
10
- * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
11
- *
12
- * Currently there is no way to achieve this within the net package:
13
- * See e.g. https://github.com/golang/go/issues/17930
14
- * So this code is remains platform dependent.
15 6
  */
16 7
 
17
-package device
8
+package conn
18 9
 
19 10
 import (
20 11
 	"errors"
@@ -25,7 +16,6 @@ import (
25 16
 	"unsafe"
26 17
 
27 18
 	"golang.org/x/sys/unix"
28
-	"golang.zx2c4.com/wireguard/rwcancel"
29 19
 )
30 20
 
31 21
 const (
@@ -33,8 +23,8 @@ const (
33 23
 )
34 24
 
35 25
 type IPv4Source struct {
36
-	src     [4]byte
37
-	ifindex int32
26
+	Src     [4]byte
27
+	Ifindex int32
38 28
 }
39 29
 
40 30
 type IPv6Source struct {
@@ -49,6 +39,10 @@ type NativeEndpoint struct {
49 39
 	isV6 bool
50 40
 }
51 41
 
42
+func (endpoint *NativeEndpoint) Src4() *IPv4Source         { return endpoint.src4() }
43
+func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
44
+func (endpoint *NativeEndpoint) IsV6() bool                { return endpoint.isV6 }
45
+
52 46
 func (endpoint *NativeEndpoint) src4() *IPv4Source {
53 47
 	return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
54 48
 }
@@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
66 60
 }
67 61
 
68 62
 type nativeBind struct {
69
-	sock4         int
70
-	sock6         int
71
-	netlinkSock   int
72
-	netlinkCancel *rwcancel.RWCancel
73
-	lastMark      uint32
63
+	sock4    int
64
+	sock6    int
65
+	lastMark uint32
74 66
 }
75 67
 
76 68
 var _ Endpoint = (*NativeEndpoint)(nil)
@@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
111 103
 	return nil, errors.New("Invalid IP address")
112 104
 }
113 105
 
114
-func createNetlinkRouteSocket() (int, error) {
115
-	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
116
-	if err != nil {
117
-		return -1, err
118
-	}
119
-	saddr := &unix.SockaddrNetlink{
120
-		Family: unix.AF_NETLINK,
121
-		Groups: unix.RTMGRP_IPV4_ROUTE,
122
-	}
123
-	err = unix.Bind(sock, saddr)
124
-	if err != nil {
125
-		unix.Close(sock)
126
-		return -1, err
127
-	}
128
-	return sock, nil
129
-
130
-}
131
-
132
-func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
106
+func createBind(port uint16) (Bind, uint16, error) {
133 107
 	var err error
134 108
 	var bind nativeBind
135 109
 	var newPort uint16
136 110
 
137
-	bind.netlinkSock, err = createNetlinkRouteSocket()
138
-	if err != nil {
139
-		return nil, 0, err
140
-	}
141
-	bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
142
-	if err != nil {
143
-		unix.Close(bind.netlinkSock)
144
-		return nil, 0, err
145
-	}
146
-
147
-	go bind.routineRouteListener(device)
148
-
149
-	// attempt ipv6 bind, update port if successful
150
-
111
+	// Attempt ipv6 bind, update port if successful.
151 112
 	bind.sock6, newPort, err = create6(port)
152 113
 	if err != nil {
153 114
 		if err != syscall.EAFNOSUPPORT {
154
-			bind.netlinkCancel.Cancel()
155 115
 			return nil, 0, err
156 116
 		}
157 117
 	} else {
158 118
 		port = newPort
159 119
 	}
160 120
 
161
-	// attempt ipv4 bind, update port if successful
162
-
121
+	// Attempt ipv4 bind, update port if successful.
163 122
 	bind.sock4, newPort, err = create4(port)
164 123
 	if err != nil {
165 124
 		if err != syscall.EAFNOSUPPORT {
166
-			bind.netlinkCancel.Cancel()
167 125
 			unix.Close(bind.sock6)
168 126
 			return nil, 0, err
169 127
 		}
@@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
178 136
 	return &bind, port, nil
179 137
 }
180 138
 
139
+func (bind *nativeBind) LastMark() uint32 {
140
+	return bind.lastMark
141
+}
142
+
181 143
 func (bind *nativeBind) SetMark(value uint32) error {
182 144
 	if bind.sock6 != -1 {
183 145
 		err := unix.SetsockoptInt(
@@ -216,22 +178,18 @@ func closeUnblock(fd int) error {
216 178
 }
217 179
 
218 180
 func (bind *nativeBind) Close() error {
219
-	var err1, err2, err3 error
181
+	var err1, err2 error
220 182
 	if bind.sock6 != -1 {
221 183
 		err1 = closeUnblock(bind.sock6)
222 184
 	}
223 185
 	if bind.sock4 != -1 {
224 186
 		err2 = closeUnblock(bind.sock4)
225 187
 	}
226
-	err3 = bind.netlinkCancel.Cancel()
227 188
 
228 189
 	if err1 != nil {
229 190
 		return err1
230 191
 	}
231
-	if err2 != nil {
232
-		return err2
233
-	}
234
-	return err3
192
+	return err2
235 193
 }
236 194
 
237 195
 func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
@@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
278 236
 func (end *NativeEndpoint) SrcIP() net.IP {
279 237
 	if !end.isV6 {
280 238
 		return net.IPv4(
281
-			end.src4().src[0],
282
-			end.src4().src[1],
283
-			end.src4().src[2],
284
-			end.src4().src[3],
239
+			end.src4().Src[0],
240
+			end.src4().Src[1],
241
+			end.src4().Src[2],
242
+			end.src4().Src[3],
285 243
 		)
286 244
 	} else {
287 245
 		return end.src6().src[:]
@@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
478 436
 			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
479 437
 		},
480 438
 		unix.Inet4Pktinfo{
481
-			Spec_dst: end.src4().src,
482
-			Ifindex:  end.src4().ifindex,
439
+			Spec_dst: end.src4().Src,
440
+			Ifindex:  end.src4().Ifindex,
483 441
 		},
484 442
 	}
485 443
 
@@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
573 531
 	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
574 532
 		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
575 533
 		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
576
-		end.src4().src = cmsg.pktinfo.Spec_dst
577
-		end.src4().ifindex = cmsg.pktinfo.Ifindex
534
+		end.src4().Src = cmsg.pktinfo.Spec_dst
535
+		end.src4().Ifindex = cmsg.pktinfo.Ifindex
578 536
 	}
579 537
 
580 538
 	return size, nil
@@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
611 569
 
612 570
 	return size, nil
613 571
 }
614
-
615
-func (bind *nativeBind) routineRouteListener(device *Device) {
616
-	type peerEndpointPtr struct {
617
-		peer     *Peer
618
-		endpoint *Endpoint
619
-	}
620
-	var reqPeer map[uint32]peerEndpointPtr
621
-	var reqPeerLock sync.Mutex
622
-
623
-	defer unix.Close(bind.netlinkSock)
624
-
625
-	for msg := make([]byte, 1<<16); ; {
626
-		var err error
627
-		var msgn int
628
-		for {
629
-			msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
630
-			if err == nil || !rwcancel.RetryAfterError(err) {
631
-				break
632
-			}
633
-			if !bind.netlinkCancel.ReadyRead() {
634
-				return
635
-			}
636
-		}
637
-		if err != nil {
638
-			return
639
-		}
640
-
641
-		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
642
-
643
-			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
644
-
645
-			if uint(hdr.Len) > uint(len(remain)) {
646
-				break
647
-			}
648
-
649
-			switch hdr.Type {
650
-			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
651
-				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
652
-					if uint(len(remain)) < uint(hdr.Len) {
653
-						break
654
-					}
655
-					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
656
-						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
657
-						for {
658
-							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
659
-								break
660
-							}
661
-							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
662
-							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
663
-								break
664
-							}
665
-							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
666
-								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
667
-								reqPeerLock.Lock()
668
-								if reqPeer == nil {
669
-									reqPeerLock.Unlock()
670
-									break
671
-								}
672
-								pePtr, ok := reqPeer[hdr.Seq]
673
-								reqPeerLock.Unlock()
674
-								if !ok {
675
-									break
676
-								}
677
-								pePtr.peer.Lock()
678
-								if &pePtr.peer.endpoint != pePtr.endpoint {
679
-									pePtr.peer.Unlock()
680
-									break
681
-								}
682
-								if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
683
-									pePtr.peer.Unlock()
684
-									break
685
-								}
686
-								pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
687
-								pePtr.peer.Unlock()
688
-							}
689
-							attr = attr[attrhdr.Len:]
690
-						}
691
-					}
692
-					break
693
-				}
694
-				reqPeerLock.Lock()
695
-				reqPeer = make(map[uint32]peerEndpointPtr)
696
-				reqPeerLock.Unlock()
697
-				go func() {
698
-					device.peers.RLock()
699
-					i := uint32(1)
700
-					for _, peer := range device.peers.keyMap {
701
-						peer.RLock()
702
-						if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
703
-							peer.RUnlock()
704
-							continue
705
-						}
706
-						if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
707
-							peer.RUnlock()
708
-							break
709
-						}
710
-						nlmsg := struct {
711
-							hdr     unix.NlMsghdr
712
-							msg     unix.RtMsg
713
-							dsthdr  unix.RtAttr
714
-							dst     [4]byte
715
-							srchdr  unix.RtAttr
716
-							src     [4]byte
717
-							markhdr unix.RtAttr
718
-							mark    uint32
719
-						}{
720
-							unix.NlMsghdr{
721
-								Type:  uint16(unix.RTM_GETROUTE),
722
-								Flags: unix.NLM_F_REQUEST,
723
-								Seq:   i,
724
-							},
725
-							unix.RtMsg{
726
-								Family:  unix.AF_INET,
727
-								Dst_len: 32,
728
-								Src_len: 32,
729
-							},
730
-							unix.RtAttr{
731
-								Len:  8,
732
-								Type: unix.RTA_DST,
733
-							},
734
-							peer.endpoint.(*NativeEndpoint).dst4().Addr,
735
-							unix.RtAttr{
736
-								Len:  8,
737
-								Type: unix.RTA_SRC,
738
-							},
739
-							peer.endpoint.(*NativeEndpoint).src4().src,
740
-							unix.RtAttr{
741
-								Len:  8,
742
-								Type: unix.RTA_MARK,
743
-							},
744
-							uint32(bind.lastMark),
745
-						}
746
-						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
747
-						reqPeerLock.Lock()
748
-						reqPeer[i] = peerEndpointPtr{
749
-							peer:     peer,
750
-							endpoint: &peer.endpoint,
751
-						}
752
-						reqPeerLock.Unlock()
753
-						peer.RUnlock()
754
-						i++
755
-						_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
756
-						if err != nil {
757
-							break
758
-						}
759
-					}
760
-					device.peers.RUnlock()
761
-				}()
762
-			}
763
-			remain = remain[hdr.Len:]
764
-		}
765
-	}
766
-}

device/mark_default.go → conn/mark_default.go View File

@@ -5,7 +5,7 @@
5 5
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
6 6
  */
7 7
 
8
-package device
8
+package conn
9 9
 
10 10
 func (bind *nativeBind) SetMark(mark uint32) error {
11 11
 	return nil

device/mark_unix.go → conn/mark_unix.go View File

@@ -5,7 +5,7 @@
5 5
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
6 6
  */
7 7
 
8
-package device
8
+package conn
9 9
 
10 10
 import (
11 11
 	"runtime"

+ 9
- 5
device/bind_test.go View File

@@ -5,11 +5,15 @@
5 5
 
6 6
 package device
7 7
 
8
-import "errors"
8
+import (
9
+	"errors"
10
+
11
+	"golang.zx2c4.com/wireguard/conn"
12
+)
9 13
 
10 14
 type DummyDatagram struct {
11 15
 	msg      []byte
12
-	endpoint Endpoint
16
+	endpoint conn.Endpoint
13 17
 	world    bool // better type
14 18
 }
15 19
 
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
25 29
 	return nil
26 30
 }
27 31
 
28
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
32
+func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
29 33
 	datagram, ok := <-b.in6
30 34
 	if !ok {
31 35
 		return 0, nil, errors.New("closed")
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
34 38
 	return len(datagram.msg), datagram.endpoint, nil
35 39
 }
36 40
 
37
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
41
+func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
38 42
 	datagram, ok := <-b.in4
39 43
 	if !ok {
40 44
 		return 0, nil, errors.New("closed")
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
50 54
 	return nil
51 55
 }
52 56
 
53
-func (b *DummyBind) Send(buff []byte, end Endpoint) error {
57
+func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
54 58
 	return nil
55 59
 }

+ 36
- 0
device/bindsocketshim.go View File

@@ -0,0 +1,36 @@
1
+/* SPDX-License-Identifier: MIT
2
+ *
3
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4
+ */
5
+
6
+package device
7
+
8
+import (
9
+	"errors"
10
+
11
+	"golang.zx2c4.com/wireguard/conn"
12
+)
13
+
14
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
15
+func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
16
+	if device.net.bind == nil {
17
+		return errors.New("Bind is not yet initialized")
18
+	}
19
+
20
+	if iface, ok := device.net.bind.(conn.BindToInterface); ok {
21
+		return iface.BindToInterface4(interfaceIndex, blackhole)
22
+	}
23
+	return nil
24
+}
25
+
26
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
27
+func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
28
+	if device.net.bind == nil {
29
+		return errors.New("Bind is not yet initialized")
30
+	}
31
+
32
+	if iface, ok := device.net.bind.(conn.BindToInterface); ok {
33
+		return iface.BindToInterface6(interfaceIndex, blackhole)
34
+	}
35
+	return nil
36
+}

+ 0
- 187
device/conn.go View File

@@ -1,187 +0,0 @@
1
-/* SPDX-License-Identifier: MIT
2
- *
3
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4
- */
5
-
6
-package device
7
-
8
-import (
9
-	"errors"
10
-	"net"
11
-	"strings"
12
-
13
-	"golang.org/x/net/ipv4"
14
-	"golang.org/x/net/ipv6"
15
-)
16
-
17
-const (
18
-	ConnRoutineNumber = 2
19
-)
20
-
21
-/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
22
- */
23
-type Bind interface {
24
-	SetMark(value uint32) error
25
-	ReceiveIPv6(buff []byte) (int, Endpoint, error)
26
-	ReceiveIPv4(buff []byte) (int, Endpoint, error)
27
-	Send(buff []byte, end Endpoint) error
28
-	Close() error
29
-}
30
-
31
-/* An Endpoint maintains the source/destination caching for a peer
32
- *
33
- * dst : the remote address of a peer ("endpoint" in uapi terminology)
34
- * src : the local address from which datagrams originate going to the peer
35
- */
36
-type Endpoint interface {
37
-	ClearSrc()           // clears the source address
38
-	SrcToString() string // returns the local source address (ip:port)
39
-	DstToString() string // returns the destination address (ip:port)
40
-	DstToBytes() []byte  // used for mac2 cookie calculations
41
-	DstIP() net.IP
42
-	SrcIP() net.IP
43
-}
44
-
45
-func parseEndpoint(s string) (*net.UDPAddr, error) {
46
-	// ensure that the host is an IP address
47
-
48
-	host, _, err := net.SplitHostPort(s)
49
-	if err != nil {
50
-		return nil, err
51
-	}
52
-	if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
53
-		// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
54
-		// trying to make sure with a small sanity test that this is a real IP address and
55
-		// not something that's likely to incur DNS lookups.
56
-		host = host[:i]
57
-	}
58
-	if ip := net.ParseIP(host); ip == nil {
59
-		return nil, errors.New("Failed to parse IP address: " + host)
60
-	}
61
-
62
-	// parse address and port
63
-
64
-	addr, err := net.ResolveUDPAddr("udp", s)
65
-	if err != nil {
66
-		return nil, err
67
-	}
68
-	ip4 := addr.IP.To4()
69
-	if ip4 != nil {
70
-		addr.IP = ip4
71
-	}
72
-	return addr, err
73
-}
74
-
75
-func unsafeCloseBind(device *Device) error {
76
-	var err error
77
-	netc := &device.net
78
-	if netc.bind != nil {
79
-		err = netc.bind.Close()
80
-		netc.bind = nil
81
-	}
82
-	netc.stopping.Wait()
83
-	return err
84
-}
85
-
86
-func (device *Device) BindSetMark(mark uint32) error {
87
-
88
-	device.net.Lock()
89
-	defer device.net.Unlock()
90
-
91
-	// check if modified
92
-
93
-	if device.net.fwmark == mark {
94
-		return nil
95
-	}
96
-
97
-	// update fwmark on existing bind
98
-
99
-	device.net.fwmark = mark
100
-	if device.isUp.Get() && device.net.bind != nil {
101
-		if err := device.net.bind.SetMark(mark); err != nil {
102
-			return err
103
-		}
104
-	}
105
-
106
-	// clear cached source addresses
107
-
108
-	device.peers.RLock()
109
-	for _, peer := range device.peers.keyMap {
110
-		peer.Lock()
111
-		defer peer.Unlock()
112
-		if peer.endpoint != nil {
113
-			peer.endpoint.ClearSrc()
114
-		}
115
-	}
116
-	device.peers.RUnlock()
117
-
118
-	return nil
119
-}
120
-
121
-func (device *Device) BindUpdate() error {
122
-
123
-	device.net.Lock()
124
-	defer device.net.Unlock()
125
-
126
-	// close existing sockets
127
-
128
-	if err := unsafeCloseBind(device); err != nil {
129
-		return err
130
-	}
131
-
132
-	// open new sockets
133
-
134
-	if device.isUp.Get() {
135
-
136
-		// bind to new port
137
-
138
-		var err error
139
-		netc := &device.net
140
-		netc.bind, netc.port, err = CreateBind(netc.port, device)
141
-		if err != nil {
142
-			netc.bind = nil
143
-			netc.port = 0
144
-			return err
145
-		}
146
-
147
-		// set fwmark
148
-
149
-		if netc.fwmark != 0 {
150
-			err = netc.bind.SetMark(netc.fwmark)
151
-			if err != nil {
152
-				return err
153
-			}
154
-		}
155
-
156
-		// clear cached source addresses
157
-
158
-		device.peers.RLock()
159
-		for _, peer := range device.peers.keyMap {
160
-			peer.Lock()
161
-			defer peer.Unlock()
162
-			if peer.endpoint != nil {
163
-				peer.endpoint.ClearSrc()
164
-			}
165
-		}
166
-		device.peers.RUnlock()
167
-
168
-		// start receiving routines
169
-
170
-		device.net.starting.Add(ConnRoutineNumber)
171
-		device.net.stopping.Add(ConnRoutineNumber)
172
-		go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
173
-		go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
174
-		device.net.starting.Wait()
175
-
176
-		device.log.Debug.Println("UDP bind has been updated")
177
-	}
178
-
179
-	return nil
180
-}
181
-
182
-func (device *Device) BindClose() error {
183
-	device.net.Lock()
184
-	err := unsafeCloseBind(device)
185
-	device.net.Unlock()
186
-	return err
187
-}

+ 136
- 10
device/device.go View File

@@ -11,15 +11,14 @@ import (
11 11
 	"sync/atomic"
12 12
 	"time"
13 13
 
14
+	"golang.org/x/net/ipv4"
15
+	"golang.org/x/net/ipv6"
16
+	"golang.zx2c4.com/wireguard/conn"
14 17
 	"golang.zx2c4.com/wireguard/ratelimiter"
18
+	"golang.zx2c4.com/wireguard/rwcancel"
15 19
 	"golang.zx2c4.com/wireguard/tun"
16 20
 )
17 21
 
18
-const (
19
-	DeviceRoutineNumberPerCPU     = 3
20
-	DeviceRoutineNumberAdditional = 2
21
-)
22
-
23 22
 type Device struct {
24 23
 	isUp     AtomicBool // device is (going) up
25 24
 	isClosed AtomicBool // device is closed? (acting as guard)
@@ -39,9 +38,10 @@ type Device struct {
39 38
 		starting sync.WaitGroup
40 39
 		stopping sync.WaitGroup
41 40
 		sync.RWMutex
42
-		bind   Bind   // bind interface
43
-		port   uint16 // listening port
44
-		fwmark uint32 // mark value (0 = disabled)
41
+		bind          conn.Bind // bind interface
42
+		netlinkCancel *rwcancel.RWCancel
43
+		port          uint16 // listening port
44
+		fwmark        uint32 // mark value (0 = disabled)
45 45
 	}
46 46
 
47 47
 	staticIdentity struct {
@@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
299 299
 	cpus := runtime.NumCPU()
300 300
 	device.state.starting.Wait()
301 301
 	device.state.stopping.Wait()
302
-	device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
303
-	device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
304 302
 	for i := 0; i < cpus; i += 1 {
303
+		device.state.starting.Add(3)
304
+		device.state.stopping.Add(3)
305 305
 		go device.RoutineEncryption()
306 306
 		go device.RoutineDecryption()
307 307
 		go device.RoutineHandshake()
308 308
 	}
309 309
 
310
+	device.state.starting.Add(2)
311
+	device.state.stopping.Add(2)
310 312
 	go device.RoutineReadFromTUN()
311 313
 	go device.RoutineTUNEventReader()
312 314
 
@@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
413 415
 	}
414 416
 	device.peers.RUnlock()
415 417
 }
418
+
419
+func unsafeCloseBind(device *Device) error {
420
+	var err error
421
+	netc := &device.net
422
+	if netc.netlinkCancel != nil {
423
+		netc.netlinkCancel.Cancel()
424
+	}
425
+	if netc.bind != nil {
426
+		err = netc.bind.Close()
427
+		netc.bind = nil
428
+	}
429
+	netc.stopping.Wait()
430
+	return err
431
+}
432
+
433
+func (device *Device) BindSetMark(mark uint32) error {
434
+
435
+	device.net.Lock()
436
+	defer device.net.Unlock()
437
+
438
+	// check if modified
439
+
440
+	if device.net.fwmark == mark {
441
+		return nil
442
+	}
443
+
444
+	// update fwmark on existing bind
445
+
446
+	device.net.fwmark = mark
447
+	if device.isUp.Get() && device.net.bind != nil {
448
+		if err := device.net.bind.SetMark(mark); err != nil {
449
+			return err
450
+		}
451
+	}
452
+
453
+	// clear cached source addresses
454
+
455
+	device.peers.RLock()
456
+	for _, peer := range device.peers.keyMap {
457
+		peer.Lock()
458
+		defer peer.Unlock()
459
+		if peer.endpoint != nil {
460
+			peer.endpoint.ClearSrc()
461
+		}
462
+	}
463
+	device.peers.RUnlock()
464
+
465
+	return nil
466
+}
467
+
468
+func (device *Device) BindUpdate() error {
469
+
470
+	device.net.Lock()
471
+	defer device.net.Unlock()
472
+
473
+	// close existing sockets
474
+
475
+	if err := unsafeCloseBind(device); err != nil {
476
+		return err
477
+	}
478
+
479
+	// open new sockets
480
+
481
+	if device.isUp.Get() {
482
+
483
+		// bind to new port
484
+
485
+		var err error
486
+		netc := &device.net
487
+		netc.bind, netc.port, err = conn.CreateBind(netc.port)
488
+		if err != nil {
489
+			netc.bind = nil
490
+			netc.port = 0
491
+			return err
492
+		}
493
+		netc.netlinkCancel, err = device.startRouteListener(netc.bind)
494
+		if err != nil {
495
+			netc.bind.Close()
496
+			netc.bind = nil
497
+			netc.port = 0
498
+			return err
499
+		}
500
+
501
+		// set fwmark
502
+
503
+		if netc.fwmark != 0 {
504
+			err = netc.bind.SetMark(netc.fwmark)
505
+			if err != nil {
506
+				return err
507
+			}
508
+		}
509
+
510
+		// clear cached source addresses
511
+
512
+		device.peers.RLock()
513
+		for _, peer := range device.peers.keyMap {
514
+			peer.Lock()
515
+			defer peer.Unlock()
516
+			if peer.endpoint != nil {
517
+				peer.endpoint.ClearSrc()
518
+			}
519
+		}
520
+		device.peers.RUnlock()
521
+
522
+		// start receiving routines
523
+
524
+		device.net.starting.Add(2)
525
+		device.net.stopping.Add(2)
526
+		go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
527
+		go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
528
+		device.net.starting.Wait()
529
+
530
+		device.log.Debug.Println("UDP bind has been updated")
531
+	}
532
+
533
+	return nil
534
+}
535
+
536
+func (device *Device) BindClose() error {
537
+	device.net.Lock()
538
+	err := unsafeCloseBind(device)
539
+	device.net.Unlock()
540
+	return err
541
+}

+ 4
- 2
device/peer.go View File

@@ -12,6 +12,8 @@ import (
12 12
 	"sync"
13 13
 	"sync/atomic"
14 14
 	"time"
15
+
16
+	"golang.zx2c4.com/wireguard/conn"
15 17
 )
16 18
 
17 19
 const (
@@ -38,7 +40,7 @@ type Peer struct {
38 40
 	keypairs                    Keypairs
39 41
 	handshake                   Handshake
40 42
 	device                      *Device
41
-	endpoint                    Endpoint
43
+	endpoint                    conn.Endpoint
42 44
 	persistentKeepaliveInterval uint16
43 45
 
44 46
 	timers struct {
@@ -293,7 +295,7 @@ func (peer *Peer) Stop() {
293 295
 
294 296
 var RoamingDisabled bool
295 297
 
296
-func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
298
+func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
297 299
 	if RoamingDisabled {
298 300
 		return
299 301
 	}

+ 5
- 4
device/receive.go View File

@@ -17,12 +17,13 @@ import (
17 17
 	"golang.org/x/crypto/chacha20poly1305"
18 18
 	"golang.org/x/net/ipv4"
19 19
 	"golang.org/x/net/ipv6"
20
+	"golang.zx2c4.com/wireguard/conn"
20 21
 )
21 22
 
22 23
 type QueueHandshakeElement struct {
23 24
 	msgType  uint32
24 25
 	packet   []byte
25
-	endpoint Endpoint
26
+	endpoint conn.Endpoint
26 27
 	buffer   *[MaxMessageSize]byte
27 28
 }
28 29
 
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
33 34
 	packet   []byte
34 35
 	counter  uint64
35 36
 	keypair  *Keypair
36
-	endpoint Endpoint
37
+	endpoint conn.Endpoint
37 38
 }
38 39
 
39 40
 func (elem *QueueInboundElement) Drop() {
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
90 91
  * Every time the bind is updated a new routine is started for
91 92
  * IPv4 and IPv6 (separately)
92 93
  */
93
-func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
94
+func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
94 95
 
95 96
 	logDebug := device.log.Debug
96 97
 	defer func() {
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
108 109
 	var (
109 110
 		err      error
110 111
 		size     int
111
-		endpoint Endpoint
112
+		endpoint conn.Endpoint
112 113
 	)
113 114
 
114 115
 	for {

+ 12
- 0
device/sticky_default.go View File

@@ -0,0 +1,12 @@
1
+// +build !linux
2
+
3
+package device
4
+
5
+import (
6
+	"golang.zx2c4.com/wireguard/conn"
7
+	"golang.zx2c4.com/wireguard/rwcancel"
8
+)
9
+
10
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
11
+	return nil, nil
12
+}

+ 215
- 0
device/sticky_linux.go View File

@@ -0,0 +1,215 @@
1
+/* SPDX-License-Identifier: MIT
2
+ *
3
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4
+ *
5
+ * This implements userspace semantics of "sticky sockets", modeled after
6
+ * WireGuard's kernelspace implementation. This is more or less a straight port
7
+ * of the sticky-sockets.c example code:
8
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
9
+ *
10
+ * Currently there is no way to achieve this within the net package:
11
+ * See e.g. https://github.com/golang/go/issues/17930
12
+ * So this code is remains platform dependent.
13
+ */
14
+
15
+package device
16
+
17
+import (
18
+	"sync"
19
+	"unsafe"
20
+
21
+	"golang.org/x/sys/unix"
22
+	"golang.zx2c4.com/wireguard/conn"
23
+	"golang.zx2c4.com/wireguard/rwcancel"
24
+)
25
+
26
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
27
+	netlinkSock, err := createNetlinkRouteSocket()
28
+	if err != nil {
29
+		return nil, err
30
+	}
31
+	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
32
+	if err != nil {
33
+		unix.Close(netlinkSock)
34
+		return nil, err
35
+	}
36
+
37
+	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
38
+
39
+	return netlinkCancel, nil
40
+}
41
+
42
+func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
43
+	type peerEndpointPtr struct {
44
+		peer     *Peer
45
+		endpoint *conn.Endpoint
46
+	}
47
+	var reqPeer map[uint32]peerEndpointPtr
48
+	var reqPeerLock sync.Mutex
49
+
50
+	defer unix.Close(netlinkSock)
51
+
52
+	for msg := make([]byte, 1<<16); ; {
53
+		var err error
54
+		var msgn int
55
+		for {
56
+			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
57
+			if err == nil || !rwcancel.RetryAfterError(err) {
58
+				break
59
+			}
60
+			if !netlinkCancel.ReadyRead() {
61
+				return
62
+			}
63
+		}
64
+		if err != nil {
65
+			return
66
+		}
67
+
68
+		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
69
+
70
+			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
71
+
72
+			if uint(hdr.Len) > uint(len(remain)) {
73
+				break
74
+			}
75
+
76
+			switch hdr.Type {
77
+			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
78
+				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
79
+					if uint(len(remain)) < uint(hdr.Len) {
80
+						break
81
+					}
82
+					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
83
+						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
84
+						for {
85
+							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
86
+								break
87
+							}
88
+							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
89
+							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
90
+								break
91
+							}
92
+							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
93
+								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
94
+								reqPeerLock.Lock()
95
+								if reqPeer == nil {
96
+									reqPeerLock.Unlock()
97
+									break
98
+								}
99
+								pePtr, ok := reqPeer[hdr.Seq]
100
+								reqPeerLock.Unlock()
101
+								if !ok {
102
+									break
103
+								}
104
+								pePtr.peer.Lock()
105
+								if &pePtr.peer.endpoint != pePtr.endpoint {
106
+									pePtr.peer.Unlock()
107
+									break
108
+								}
109
+								if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
110
+									pePtr.peer.Unlock()
111
+									break
112
+								}
113
+								pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
114
+								pePtr.peer.Unlock()
115
+							}
116
+							attr = attr[attrhdr.Len:]
117
+						}
118
+					}
119
+					break
120
+				}
121
+				reqPeerLock.Lock()
122
+				reqPeer = make(map[uint32]peerEndpointPtr)
123
+				reqPeerLock.Unlock()
124
+				go func() {
125
+					device.peers.RLock()
126
+					i := uint32(1)
127
+					for _, peer := range device.peers.keyMap {
128
+						peer.RLock()
129
+						if peer.endpoint == nil {
130
+							peer.RUnlock()
131
+							continue
132
+						}
133
+						nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
134
+						if nativeEP == nil {
135
+							peer.RUnlock()
136
+							continue
137
+						}
138
+						if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
139
+							peer.RUnlock()
140
+							break
141
+						}
142
+						nlmsg := struct {
143
+							hdr     unix.NlMsghdr
144
+							msg     unix.RtMsg
145
+							dsthdr  unix.RtAttr
146
+							dst     [4]byte
147
+							srchdr  unix.RtAttr
148
+							src     [4]byte
149
+							markhdr unix.RtAttr
150
+							mark    uint32
151
+						}{
152
+							unix.NlMsghdr{
153
+								Type:  uint16(unix.RTM_GETROUTE),
154
+								Flags: unix.NLM_F_REQUEST,
155
+								Seq:   i,
156
+							},
157
+							unix.RtMsg{
158
+								Family:  unix.AF_INET,
159
+								Dst_len: 32,
160
+								Src_len: 32,
161
+							},
162
+							unix.RtAttr{
163
+								Len:  8,
164
+								Type: unix.RTA_DST,
165
+							},
166
+							nativeEP.Dst4().Addr,
167
+							unix.RtAttr{
168
+								Len:  8,
169
+								Type: unix.RTA_SRC,
170
+							},
171
+							nativeEP.Src4().Src,
172
+							unix.RtAttr{
173
+								Len:  8,
174
+								Type: unix.RTA_MARK,
175
+							},
176
+							uint32(bind.LastMark()),
177
+						}
178
+						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
179
+						reqPeerLock.Lock()
180
+						reqPeer[i] = peerEndpointPtr{
181
+							peer:     peer,
182
+							endpoint: &peer.endpoint,
183
+						}
184
+						reqPeerLock.Unlock()
185
+						peer.RUnlock()
186
+						i++
187
+						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
188
+						if err != nil {
189
+							break
190
+						}
191
+					}
192
+					device.peers.RUnlock()
193
+				}()
194
+			}
195
+			remain = remain[hdr.Len:]
196
+		}
197
+	}
198
+}
199
+
200
+func createNetlinkRouteSocket() (int, error) {
201
+	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
202
+	if err != nil {
203
+		return -1, err
204
+	}
205
+	saddr := &unix.SockaddrNetlink{
206
+		Family: unix.AF_NETLINK,
207
+		Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
208
+	}
209
+	err = unix.Bind(sock, saddr)
210
+	if err != nil {
211
+		unix.Close(sock)
212
+		return -1, err
213
+	}
214
+	return sock, nil
215
+}

+ 2
- 1
device/uapi.go View File

@@ -15,6 +15,7 @@ import (
15 15
 	"sync/atomic"
16 16
 	"time"
17 17
 
18
+	"golang.zx2c4.com/wireguard/conn"
18 19
 	"golang.zx2c4.com/wireguard/ipc"
19 20
 )
20 21
 
@@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
306 307
 				err := func() error {
307 308
 					peer.Lock()
308 309
 					defer peer.Unlock()
309
-					endpoint, err := CreateEndpoint(value)
310
+					endpoint, err := conn.CreateEndpoint(value)
310 311
 					if err != nil {
311 312
 						return err
312 313
 					}

Loading…
Cancel
Save