Browse Source

ratelimiter: use a fake clock in tests and style cleanups

The existing test would occasionally flake out with:

	--- FAIL: TestRatelimiter (0.12s)
	    ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
	FAIL
	FAIL    golang.zx2c4.com/wireguard/ratelimiter  0.171s

The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.

While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.

Passes `go test -race -count=10000 ./ratelimiter`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
David Crawshaw 5 months ago
parent
commit
9cd8909df2
2 changed files with 88 additions and 65 deletions
  1. 53
    46
      ratelimiter/ratelimiter.go
  2. 35
    19
      ratelimiter/ratelimiter_test.go

+ 53
- 46
ratelimiter/ratelimiter.go View File

@@ -20,21 +20,23 @@ const (
20 20
 )
21 21
 
22 22
 type RatelimiterEntry struct {
23
-	sync.Mutex
23
+	mu       sync.Mutex
24 24
 	lastTime time.Time
25 25
 	tokens   int64
26 26
 }
27 27
 
28 28
 type Ratelimiter struct {
29
-	sync.RWMutex
30
-	stopReset chan struct{}
29
+	mu      sync.RWMutex
30
+	timeNow func() time.Time
31
+
32
+	stopReset chan struct{} // send to reset, close to stop
31 33
 	tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
32 34
 	tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
33 35
 }
34 36
 
35 37
 func (rate *Ratelimiter) Close() {
36
-	rate.Lock()
37
-	defer rate.Unlock()
38
+	rate.mu.Lock()
39
+	defer rate.mu.Unlock()
38 40
 
39 41
 	if rate.stopReset != nil {
40 42
 		close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
42 44
 }
43 45
 
44 46
 func (rate *Ratelimiter) Init() {
45
-	rate.Lock()
46
-	defer rate.Unlock()
47
+	rate.mu.Lock()
48
+	defer rate.mu.Unlock()
47 49
 
48
-	// stop any ongoing garbage collection routine
50
+	if rate.timeNow == nil {
51
+		rate.timeNow = time.Now
52
+	}
49 53
 
54
+	// stop any ongoing garbage collection routine
50 55
 	if rate.stopReset != nil {
51 56
 		close(rate.stopReset)
52 57
 	}
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
55 60
 	rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
56 61
 	rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
57 62
 
58
-	// start garbage collection routine
63
+	stopReset := rate.stopReset // store in case Init is called again.
59 64
 
65
+	// Start garbage collection routine.
60 66
 	go func() {
61 67
 		ticker := time.NewTicker(time.Second)
62 68
 		ticker.Stop()
63 69
 		for {
64 70
 			select {
65
-			case _, ok := <-rate.stopReset:
71
+			case _, ok := <-stopReset:
66 72
 				ticker.Stop()
67
-				if ok {
68
-					ticker = time.NewTicker(time.Second)
69
-				} else {
73
+				if !ok {
70 74
 					return
71 75
 				}
76
+				ticker = time.NewTicker(time.Second)
72 77
 			case <-ticker.C:
73
-				func() {
74
-					rate.Lock()
75
-					defer rate.Unlock()
76
-
77
-					for key, entry := range rate.tableIPv4 {
78
-						entry.Lock()
79
-						if time.Since(entry.lastTime) > garbageCollectTime {
80
-							delete(rate.tableIPv4, key)
81
-						}
82
-						entry.Unlock()
83
-					}
84
-
85
-					for key, entry := range rate.tableIPv6 {
86
-						entry.Lock()
87
-						if time.Since(entry.lastTime) > garbageCollectTime {
88
-							delete(rate.tableIPv6, key)
89
-						}
90
-						entry.Unlock()
91
-					}
92
-
93
-					if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
94
-						ticker.Stop()
95
-					}
96
-				}()
78
+				if rate.cleanup() {
79
+					ticker.Stop()
80
+				}
97 81
 			}
98 82
 		}
99 83
 	}()
100 84
 }
101 85
 
86
+func (rate *Ratelimiter) cleanup() (empty bool) {
87
+	rate.mu.Lock()
88
+	defer rate.mu.Unlock()
89
+
90
+	for key, entry := range rate.tableIPv4 {
91
+		entry.mu.Lock()
92
+		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
93
+			delete(rate.tableIPv4, key)
94
+		}
95
+		entry.mu.Unlock()
96
+	}
97
+
98
+	for key, entry := range rate.tableIPv6 {
99
+		entry.mu.Lock()
100
+		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
101
+			delete(rate.tableIPv6, key)
102
+		}
103
+		entry.mu.Unlock()
104
+	}
105
+
106
+	return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
107
+}
108
+
102 109
 func (rate *Ratelimiter) Allow(ip net.IP) bool {
103 110
 	var entry *RatelimiterEntry
104 111
 	var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
109 116
 	IPv4 := ip.To4()
110 117
 	IPv6 := ip.To16()
111 118
 
112
-	rate.RLock()
119
+	rate.mu.RLock()
113 120
 
114 121
 	if IPv4 != nil {
115 122
 		copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
119 126
 		entry = rate.tableIPv6[keyIPv6]
120 127
 	}
121 128
 
122
-	rate.RUnlock()
129
+	rate.mu.RUnlock()
123 130
 
124 131
 	// make new entry if not found
125 132
 
126 133
 	if entry == nil {
127 134
 		entry = new(RatelimiterEntry)
128 135
 		entry.tokens = maxTokens - packetCost
129
-		entry.lastTime = time.Now()
130
-		rate.Lock()
136
+		entry.lastTime = rate.timeNow()
137
+		rate.mu.Lock()
131 138
 		if IPv4 != nil {
132 139
 			rate.tableIPv4[keyIPv4] = entry
133 140
 			if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
139 146
 				rate.stopReset <- struct{}{}
140 147
 			}
141 148
 		}
142
-		rate.Unlock()
149
+		rate.mu.Unlock()
143 150
 		return true
144 151
 	}
145 152
 
146 153
 	// add tokens to entry
147 154
 
148
-	entry.Lock()
149
-	now := time.Now()
155
+	entry.mu.Lock()
156
+	now := rate.timeNow()
150 157
 	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
151 158
 	entry.lastTime = now
152 159
 	if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
157 164
 
158 165
 	if entry.tokens > packetCost {
159 166
 		entry.tokens -= packetCost
160
-		entry.Unlock()
167
+		entry.mu.Unlock()
161 168
 		return true
162 169
 	}
163
-	entry.Unlock()
170
+	entry.mu.Unlock()
164 171
 	return false
165 172
 }

+ 35
- 19
ratelimiter/ratelimiter_test.go View File

@@ -11,22 +11,21 @@ import (
11 11
 	"time"
12 12
 )
13 13
 
14
-type RatelimiterResult struct {
14
+type result struct {
15 15
 	allowed bool
16 16
 	text    string
17 17
 	wait    time.Duration
18 18
 }
19 19
 
20 20
 func TestRatelimiter(t *testing.T) {
21
+	var rate Ratelimiter
22
+	var expectedResults []result
21 23
 
22
-	var ratelimiter Ratelimiter
23
-	var expectedResults []RatelimiterResult
24
-
25
-	Nano := func(nano int64) time.Duration {
24
+	nano := func(nano int64) time.Duration {
26 25
 		return time.Nanosecond * time.Duration(nano)
27 26
 	}
28 27
 
29
-	Add := func(res RatelimiterResult) {
28
+	add := func(res result) {
30 29
 		expectedResults = append(
31 30
 			expectedResults,
32 31
 			res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
34 33
 	}
35 34
 
36 35
 	for i := 0; i < packetsBurstable; i++ {
37
-		Add(RatelimiterResult{
36
+		add(result{
38 37
 			allowed: true,
39 38
 			text:    "initial burst",
40 39
 		})
41 40
 	}
42 41
 
43
-	Add(RatelimiterResult{
42
+	add(result{
44 43
 		allowed: false,
45 44
 		text:    "after burst",
46 45
 	})
47 46
 
48
-	Add(RatelimiterResult{
47
+	add(result{
49 48
 		allowed: true,
50
-		wait:    Nano(time.Second.Nanoseconds() / packetsPerSecond),
49
+		wait:    nano(time.Second.Nanoseconds() / packetsPerSecond),
51 50
 		text:    "filling tokens for single packet",
52 51
 	})
53 52
 
54
-	Add(RatelimiterResult{
53
+	add(result{
55 54
 		allowed: false,
56 55
 		text:    "not having refilled enough",
57 56
 	})
58 57
 
59
-	Add(RatelimiterResult{
58
+	add(result{
60 59
 		allowed: true,
61
-		wait:    2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
60
+		wait:    2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
62 61
 		text:    "filling tokens for two packet burst",
63 62
 	})
64 63
 
65
-	Add(RatelimiterResult{
64
+	add(result{
66 65
 		allowed: true,
67 66
 		text:    "second packet in 2 packet burst",
68 67
 	})
69 68
 
70
-	Add(RatelimiterResult{
69
+	add(result{
71 70
 		allowed: false,
72 71
 		text:    "packet following 2 packet burst",
73 72
 	})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
89 88
 		net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
90 89
 	}
91 90
 
92
-	ratelimiter.Init()
91
+	now := time.Now()
92
+	rate.timeNow = func() time.Time {
93
+		return now
94
+	}
95
+	defer func() {
96
+		// Lock to avoid data race with cleanup goroutine from Init.
97
+		rate.mu.Lock()
98
+		defer rate.mu.Unlock()
99
+
100
+		rate.timeNow = time.Now
101
+	}()
102
+	timeSleep := func(d time.Duration) {
103
+		now = now.Add(d + 1)
104
+		rate.cleanup()
105
+	}
106
+
107
+	rate.Init()
108
+	defer rate.Close()
93 109
 
94 110
 	for i, res := range expectedResults {
95
-		time.Sleep(res.wait)
111
+		timeSleep(res.wait)
96 112
 		for _, ip := range ips {
97
-			allowed := ratelimiter.Allow(ip)
113
+			allowed := rate.Allow(ip)
98 114
 			if allowed != res.allowed {
99
-				t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
115
+				t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
100 116
 			}
101 117
 		}
102 118
 	}

Loading…
Cancel
Save