@@ -3,6 +3,7 @@ package main
33import (
44 "bufio"
55 "context"
6+ "crypto/tls"
67 "fmt"
78 "io"
89 "net"
5657 bandwidthLimit int
5758 denyWebPage bool
5859 requestLimit int
60+ certPath string
61+ keyPath string
5962)
6063
6164type AccessRecord struct {
@@ -126,6 +129,85 @@ var AcceptDomain = []string{
126129 "api.github.com" ,
127130}
128131
132+ type CertContainer struct {
133+ CertPEM []byte
134+ KeyPEM []byte
135+ CertWatcher * fswatch.Watcher
136+ KeyWatcher * fswatch.Watcher
137+ Cert tls.Certificate
138+ }
139+
140+ func NewCertContainer () (* CertContainer , error ) {
141+ CertPEM , err := os .ReadFile (certPath )
142+ if err != nil {
143+ return nil , E .Cause (err , "Read cert file" )
144+ }
145+ KeyPEM , err := os .ReadFile (keyPath )
146+ if err != nil {
147+ return nil , E .Cause (err , "Read key file" )
148+ }
149+ Cert , err := tls .X509KeyPair (CertPEM , KeyPEM )
150+ if err != nil {
151+ return nil , E .Cause (err , "Check key pair" )
152+ }
153+ container := & CertContainer {
154+ CertPEM : CertPEM ,
155+ KeyPEM : KeyPEM ,
156+ Cert : Cert ,
157+ }
158+ if CertWatcher , err := fswatch .NewWatcher (fswatch.Options {
159+ Path : []string {certPath },
160+ Callback : func (path string ) {
161+ var err error
162+ CertPEM , err = os .ReadFile (certPath )
163+ if err != nil {
164+ return
165+ }
166+ container .CertPEM = CertPEM
167+ container .Update ()
168+ },
169+ }); err == nil {
170+ err = CertWatcher .Start ()
171+ if err == nil {
172+ container .CertWatcher = CertWatcher
173+ }
174+ }
175+ if KeyWatcher , err := fswatch .NewWatcher (fswatch.Options {
176+ Path : []string {keyPath },
177+ Callback : func (path string ) {
178+ var err error
179+ KeyPEM , err = os .ReadFile (keyPath )
180+ if err == nil {
181+ return
182+ }
183+ container .KeyPEM = KeyPEM
184+ container .Update ()
185+ },
186+ }); err == nil {
187+ err = KeyWatcher .Start ()
188+ if err == nil {
189+ container .KeyWatcher = KeyWatcher
190+ }
191+ }
192+ return container , nil
193+ }
194+
195+ func (c * CertContainer ) Update () {
196+ Cert , err := tls .X509KeyPair (c .CertPEM , c .KeyPEM )
197+ if err == nil {
198+ c .Cert = Cert
199+ }
200+ }
201+
202+ func (c * CertContainer ) Close () {
203+ if c .CertWatcher != nil {
204+ c .CertWatcher .Close ()
205+ }
206+ if c .KeyWatcher != nil {
207+ c .KeyWatcher .Close ()
208+ }
209+ }
210+
129211var command = & cobra.Command {
130212 Use : "git-proxy" ,
131213 Short : "A HTTP service to proxy git requests" ,
@@ -140,6 +222,8 @@ func init() {
140222 command .PersistentFlags ().IntVarP (& bandwidthLimit , "bandwidth-limit" , "l" , 0 , "set total bandwidth limit (MB/s), 0 as no limit" )
141223 command .PersistentFlags ().IntVarP (& requestLimit , "request-limit" , "r" , 0 , "set request limit by ip, 0 as no limit" )
142224 command .PersistentFlags ().BoolVarP (& denyWebPage , "deny-web-page" , "" , false , "deny web page requests" )
225+ command .PersistentFlags ().StringVarP (& certPath , "cert-path" , "c" , "cert.pem" , "set tls cert path" )
226+ command .PersistentFlags ().StringVarP (& keyPath , "key-path" , "k" , "key.pem" , "set tls key path" )
143227}
144228
145229func main () {
@@ -198,6 +282,20 @@ func run(*cobra.Command, []string) {
198282 }
199283 listen := M .ParseSocksaddr (":" + strconv .Itoa (runningPort ))
200284 listener := listenTCP (listen )
285+ if len (certPath ) == 0 || len (keyPath ) == 0 {
286+ log .Info ("Listening TCP port " , listen .Port )
287+ } else if container , err := NewCertContainer (); err != nil {
288+ log .Warn (E .Cause (err , "Update TCP to TLS" ))
289+ log .Info ("Listening TCP port " , listen .Port )
290+ } else {
291+ log .Info ("Listening TLS port " , listen .Port )
292+ listener = tls .NewListener (listener , & tls.Config {
293+ GetCertificate : func (info * tls.ClientHelloInfo ) (* tls.Certificate , error ) {
294+ return & container .Cert , nil
295+ },
296+ })
297+ defer container .Close ()
298+ }
201299 chiRouter := chi .NewRouter ()
202300 chiRouter .Group (func (r chi.Router ) {
203301 r .Use (middleware .RealIP )
@@ -425,7 +523,6 @@ func listenTCP(address M.Socksaddr) net.Listener {
425523 }
426524 address .Port = address .Port + 1
427525 }
428- log .Info ("Listening tcp port " , address .Port )
429526 return listener
430527}
431528
0 commit comments