@@ -3,6 +3,7 @@ package main
33import (
44 "bufio"
55 "context"
6+ "fmt"
67 "io"
78 "net"
89 "net/http"
@@ -12,6 +13,7 @@ import (
1213 "path/filepath"
1314 "strconv"
1415 "strings"
16+ "sync"
1517 "syscall"
1618 "time"
1719
5355 blacklistPath string
5456 bandwidthLimit int
5557 denyWebPage bool
58+ requestLimit int
5659)
5760
61+ type AccessRecord struct {
62+ count int
63+ }
64+
65+ type IPLimiter struct {
66+ records map [string ]* AccessRecord
67+ limit int
68+
69+ sync.RWMutex
70+ }
71+
72+ func NewIPLimiter () * IPLimiter {
73+ return & IPLimiter {
74+ records : make (map [string ]* AccessRecord ),
75+ }
76+ }
77+
78+ func (ir * IPLimiter ) GetAccess (address string ) bool {
79+ ir .RLock ()
80+ record , exist := ir .records [address ]
81+ ir .RUnlock ()
82+ if exist {
83+ if record .count < ir .limit {
84+ record .count = record .count + 1
85+ return true
86+ } else {
87+ return false
88+ }
89+ } else {
90+ ir .Lock ()
91+ ir .records [address ] = & AccessRecord {1 }
92+ ir .Unlock ()
93+ return true
94+ }
95+ }
96+
97+ func (ir * IPLimiter ) Leave (address string ) {
98+ ir .RLock ()
99+ record , exist := ir .records [address ]
100+ ir .RUnlock ()
101+ if exist {
102+ if record .count == 1 {
103+ ir .Lock ()
104+ delete (ir .records , address )
105+ ir .Unlock ()
106+ } else {
107+ record .count = record .count - 1
108+ }
109+ }
110+ }
111+
112+ var RequestLimiter * IPLimiter
113+
58114var BandwidthLimiter * R.Bucket
59115
60116var Blacklist []RepoInfo
@@ -82,6 +138,7 @@ func init() {
82138 command .PersistentFlags ().StringVarP (& domainListPath , "domain-list-path" , "d" , "domainlist.txt" , "set accept domain" )
83139 command .PersistentFlags ().StringVarP (& blacklistPath , "blacklist-path" , "b" , "blacklist.txt" , "set repository blacklist" )
84140 command .PersistentFlags ().IntVarP (& bandwidthLimit , "bandwidth-limit" , "l" , 0 , "set total bandwidth limit (MB/s), 0 as no limit" )
141+ command .PersistentFlags ().IntVarP (& requestLimit , "request-limit" , "r" , 0 , "set request limit by ip, 0 as no limit" )
85142 command .PersistentFlags ().BoolVarP (& denyWebPage , "deny-web-page" , "" , false , "deny web page requests" )
86143}
87144
@@ -112,6 +169,10 @@ func run(*cobra.Command, []string) {
112169 BandwidthLimiter = R .NewBucketWithRate (float64 (bandwidthLimit * 1024 * 1024 ), int64 (bandwidthLimit * 1024 * 1024 ))
113170 log .Info ("Bandwidth limit is set as " , bandwidthLimit , "MB/s" )
114171 }
172+ if requestLimit > 0 {
173+ log .Info ("Request limit is set as " , requestLimit , " each IP" )
174+ RequestLimiter = NewIPLimiter ()
175+ }
115176 if denyWebPage {
116177 log .Info ("Denying web page requests" )
117178 }
@@ -142,6 +203,7 @@ func run(*cobra.Command, []string) {
142203 r .Use (middleware .RealIP )
143204 r .Use (setContext )
144205 r .Use (commonLog )
206+ r .Use (requestLimitHandle )
145207 r .Get ("/" , hello )
146208 r .Mount ("/" , finalHandle ())
147209 })
@@ -385,6 +447,29 @@ func commonLog(next http.Handler) http.Handler {
385447 })
386448}
387449
450+ func requestLimitHandle (next http.Handler ) http.Handler {
451+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
452+ if RequestLimiter == nil {
453+ next .ServeHTTP (w , r )
454+ return
455+ }
456+ ip := M .ParseSocksaddr (r .RemoteAddr ).Addr .String ()
457+ access := RequestLimiter .GetAccess (ip )
458+ if ! access {
459+ log .WarnContext (r .Context (), "Match request limit" )
460+ w .WriteHeader (http .StatusTooManyRequests )
461+ if requestLimit == 1 {
462+ w .Write ([]byte ("You can only initiate 1 request at the same time" ))
463+ } else {
464+ w .Write ([]byte (fmt .Sprint ("You can only initiate " , requestLimit , " requests at the same time" )))
465+ }
466+ return
467+ }
468+ next .ServeHTTP (w , r )
469+ RequestLimiter .Leave (ip )
470+ })
471+ }
472+
388473func finalHandle () http.Handler {
389474 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
390475 finalHandler (r ).ServeHTTP (w , r )
0 commit comments