aboutsummaryrefslogtreecommitdiff
path: root/internal/httputil/rate_limiting.go
blob: dab36481e71e5ca81ddc9dde280f977f546cca29 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package httputil

import (
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/matrix-org/dendrite/clientapi/jsonerror"
	"github.com/matrix-org/dendrite/setup/config"
	userapi "github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/util"
)

// RateLimits tracks, per caller, how many request "slots" are currently in
// use and rejects requests once a caller has used all of them.
type RateLimits struct {
	// Configuration, fixed when the value is constructed.
	enabled          bool
	requestThreshold int64
	cooloffDuration  time.Duration
	exemptUserIDs    map[string]struct{}

	// One buffered channel per caller key; each element buffered in a
	// channel is one slot in use. The map is guarded by limitsMutex.
	limits      map[string]chan struct{}
	limitsMutex sync.RWMutex

	// Taken for writing by the cleaner and for reading by request
	// handling, so the cleaner never touches a channel mid-request.
	cleanMutex sync.RWMutex
}

// NewRateLimits builds a RateLimits from cfg. The cool-off is taken from
// cfg.CooloffMS (milliseconds), and every user ID in cfg.ExemptUserIDs is
// recorded as exempt. When rate limiting is enabled, this also starts the
// background cleaner goroutine.
func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
	exempt := make(map[string]struct{}, len(cfg.ExemptUserIDs))
	for _, id := range cfg.ExemptUserIDs {
		exempt[id] = struct{}{}
	}

	rl := &RateLimits{
		enabled:          cfg.Enabled,
		requestThreshold: cfg.Threshold,
		cooloffDuration:  time.Millisecond * time.Duration(cfg.CooloffMS),
		exemptUserIDs:    exempt,
		limits:           map[string]chan struct{}{},
	}

	if rl.enabled {
		// The cleaner loops forever; it has no stop mechanism and lives
		// for the rest of the process.
		go rl.clean()
	}
	return rl
}

// clean runs forever. Every 30 seconds it closes and removes the channels
// of callers that currently hold no slots, so the map does not keep
// entries for callers that have gone quiet.
func (l *RateLimits) clean() {
	const interval = 30 * time.Second
	for {
		time.Sleep(interval)

		// Taking cleanMutex exclusively waits for in-flight Limit calls,
		// which hold it for reading, to finish. An empty channel therefore
		// has no pending slot release and no request using it, so it is
		// safe to close here.
		l.cleanMutex.Lock()
		l.limitsMutex.Lock()
		for caller, slots := range l.limits {
			if len(slots) != 0 {
				continue
			}
			close(slots)
			delete(l.limits, caller)
		}
		l.limitsMutex.Unlock()
		l.cleanMutex.Unlock()
	}
}

// Limit decides whether req may proceed under the per-caller rate limit.
//
// It returns nil when the request is allowed, or when rate limiting does
// not apply to it. When the caller has no free slots, it returns a
// 429 Too Many Requests response built with jsonerror.LimitExceeded, with
// the cool-off duration passed in milliseconds.
//
// Callers are identified as follows:
//   - Authenticated requests (device != nil) use user ID plus device ID.
//     Admin accounts, appservice accounts, and configured exempt user IDs
//     are never limited.
//   - Unauthenticated requests use the X-Forwarded-For header if present.
//     Otherwise they use the IP address from req.RemoteAddr.
//
// Each allowed request occupies one of requestThreshold slots for
// cooloffDuration, after which the slot is released.
func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSONResponse {
	// If rate limiting is disabled then do nothing.
	if !l.enabled {
		return nil
	}

	// Hold a read lock on the cleaner mutex for the whole call. The cleaner
	// needs the write lock, so it cannot close or delete the channel this
	// request is using until we return.
	l.cleanMutex.RLock()
	defer l.cleanMutex.RUnlock()

	// Work out who the caller is.
	var caller string
	if device != nil {
		switch device.AccountType {
		case userapi.AccountTypeAdmin:
			return nil // don't rate-limit server administrators
		case userapi.AccountTypeAppService:
			return nil // don't rate-limit appservice users
		default:
			if _, ok := l.exemptUserIDs[device.UserID]; ok {
				// If the user is exempt from rate limiting then do nothing.
				return nil
			}
			caller = device.UserID + device.ID
		}
	} else if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		// NOTE(review): whoever connects to us supplies this header. It can
		// only be trusted behind a reverse proxy that overwrites it.
		caller = forwardedFor
	} else {
		// The net/http server sets RemoteAddr to "IP:port". Key on the IP
		// alone so a client cannot get fresh slots just by reconnecting
		// from a new source port. Keep the raw value if it isn't host:port.
		caller = req.RemoteAddr
		if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			caller = host
		}
	}

	// Look up the caller's channel, creating it if needed. The existence
	// check is repeated under the write lock. Otherwise two concurrent first
	// requests from one caller could each create a channel, one overwriting
	// the other, and the slot taken on the discarded channel would never
	// count against the caller.
	l.limitsMutex.RLock()
	rateLimit, ok := l.limits[caller]
	l.limitsMutex.RUnlock()
	if !ok {
		l.limitsMutex.Lock()
		if rateLimit, ok = l.limits[caller]; !ok {
			rateLimit = make(chan struct{}, l.requestThreshold)
			l.limits[caller] = rateLimit
		}
		l.limitsMutex.Unlock()
	}

	// Try to take a free slot without blocking. If the channel is full,
	// the caller has hit the limit.
	select {
	case rateLimit <- struct{}{}:
	default:
		// We hit the rate limit. Tell the client to back off.
		return &util.JSONResponse{
			Code: http.StatusTooManyRequests,
			JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()),
		}
	}

	// Release the slot once the cool-off has elapsed. time.AfterFunc starts
	// a goroutine only when the timer fires, rather than parking one per
	// request for the whole cool-off period.
	time.AfterFunc(l.cooloffDuration, func() {
		<-rateLimit
	})
	return nil
}