aboutsummaryrefslogtreecommitdiff
path: root/internal/httputil/rate_limiting.go
blob: c4f47c7b55bd10ced790020499264264f9464c43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package httputil

import (
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/matrix-org/dendrite/clientapi/jsonerror"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/util"
)

// RateLimits tracks a per-caller allowance of requests. Each caller is given a
// buffered channel whose capacity is requestThreshold. Every admitted request
// occupies one slot of that channel until its cool-off period has elapsed.
type RateLimits struct {
	// limits maps a caller key to that caller's slot channel.
	limits map[string]chan struct{}

	// limitsMutex guards reads and writes of the limits map.
	limitsMutex sync.RWMutex

	// cleanMutex is read-locked by every Limit call and write-locked by the
	// background cleaner, so idle channels are only removed between requests.
	cleanMutex sync.RWMutex

	// enabled turns rate limiting on; when false, Limit admits everything.
	enabled bool

	// requestThreshold is the slot capacity given to each caller's channel.
	requestThreshold int64

	// cooloffDuration is how long an admitted request keeps its slot.
	cooloffDuration time.Duration
}

// NewRateLimits builds a RateLimits from cfg. cfg.Threshold becomes the
// per-caller slot capacity, and cfg.CooloffMS is interpreted as milliseconds.
// When cfg.Enabled is set, the background cleaner goroutine is started as
// well. There is no mechanism to stop it.
func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
	limiter := new(RateLimits)
	limiter.limits = map[string]chan struct{}{}
	limiter.enabled = cfg.Enabled
	limiter.requestThreshold = cfg.Threshold
	limiter.cooloffDuration = time.Millisecond * time.Duration(cfg.CooloffMS)

	if !limiter.enabled {
		return limiter
	}
	go limiter.clean()
	return limiter
}

// clean loops forever; it has no stop signal. Every 30 seconds it takes
// cleanMutex exclusively, which waits until no Limit call holds a read lock.
// It then closes and deletes every caller channel that currently holds no
// slots, so entries for idle callers do not accumulate in the map.
//
// An empty channel has no outstanding drain: each admitted request queues
// exactly one receive, and an empty buffer means all of them have already
// happened. Closing it here therefore cannot interfere with a drain.
func (l *RateLimits) clean() {
	const sweepEvery = 30 * time.Second
	for {
		time.Sleep(sweepEvery)
		func() {
			l.cleanMutex.Lock()
			defer l.cleanMutex.Unlock()
			l.limitsMutex.Lock()
			defer l.limitsMutex.Unlock()

			for caller, slots := range l.limits {
				if len(slots) != 0 {
					// Still has requests inside their cool-off window.
					continue
				}
				close(slots)
				delete(l.limits, caller)
			}
		}()
	}
}

// Limit applies per-caller rate limiting to req.
//
// It returns nil when the request may proceed. When the caller has no free
// slot, it returns a 429 Too Many Requests response built by
// jsonerror.LimitExceeded, with the cool-off duration in milliseconds as the
// retry hint. An admitted request holds one slot of the caller's channel for
// cooloffDuration.
func (l *RateLimits) Limit(req *http.Request) *util.JSONResponse {
	// If rate limiting is disabled then do nothing.
	if !l.enabled {
		return nil
	}

	// Hold the cleaner off for the duration of this call. The cleaner needs the
	// write lock on cleanMutex, which cannot be granted while readers hold it,
	// so the channel looked up below cannot be closed and removed under us.
	l.cleanMutex.RLock()
	defer l.cleanMutex.RUnlock()

	// Identify the caller. Prefer X-Forwarded-For when it was sent; otherwise
	// use the host part of RemoteAddr. RemoteAddr is normally "IP:port", and
	// keying on the full string would give every new connection (which gets a
	// fresh ephemeral port) a fresh allowance. If it cannot be split, e.g. for
	// a non-TCP listener, fall back to the raw value.
	// NOTE(review): X-Forwarded-For is trusted as-is. Unless a reverse proxy
	// always overwrites it, clients can choose their own rate-limit key.
	caller := req.Header.Get("X-Forwarded-For")
	if caller == "" {
		caller = req.RemoteAddr
		if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			caller = host
		}
	}

	// Fast path: look up the caller's channel under the read lock.
	l.limitsMutex.RLock()
	rateLimit, ok := l.limits[caller]
	l.limitsMutex.RUnlock()

	if !ok {
		l.limitsMutex.Lock()
		// Re-check under the write lock. Another request from the same caller
		// may have inserted a channel between our RUnlock and Lock. Blindly
		// overwriting it would orphan that channel, and the slots held in it
		// would stop counting against the caller.
		if rateLimit, ok = l.limits[caller]; !ok {
			rateLimit = make(chan struct{}, l.requestThreshold)
			l.limits[caller] = rateLimit
		}
		l.limitsMutex.Unlock()
	}

	// Try to take a free slot without blocking. If none is free, refuse.
	select {
	case rateLimit <- struct{}{}:
	default:
		// We hit the rate limit. Tell the client to back off.
		return &util.JSONResponse{
			Code: http.StatusTooManyRequests,
			JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()),
		}
	}

	// Release the slot once the cool-off has elapsed. time.AfterFunc runs the
	// drain in its own goroutine only when the timer fires, rather than
	// parking a goroutine for the whole cool-off window.
	time.AfterFunc(l.cooloffDuration, func() {
		<-rateLimit
	})
	return nil
}