Skip to content

Commit 1055a78

Browse files
committed
feat: add tls support
1 parent 376d965 commit 1055a78

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
/venv/
1313
domainlist.txt
1414
blacklist.txt
15+
cert.pem
16+
key.pem

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ Usage:
1313
Flags:
1414
-l, --bandwidth-limit int set total bandwidth limit (MB/s), 0 as no limit
1515
-b, --blacklist-path string set repository blacklist (default "blacklist.txt")
16+
-c, --cert-path string set tls cert path (default "cert.pem")
1617
--deny-web-page deny web page requests
1718
--disable-color disable color output
1819
-d, --domain-list-path string set accept domain (default "domainlist.txt")
1920
-h, --help help for git-proxy
21+
-k, --key-path string set tls key path (default "key.pem")
2022
-r, --request-limit int set request limit by ip, 0 as no limit
2123
-p, --running-port int disable color output (default 30000)
2224
```

main.go

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"fmt"
78
"io"
89
"net"
@@ -56,6 +57,8 @@ var (
5657
bandwidthLimit int
5758
denyWebPage bool
5859
requestLimit int
60+
certPath string
61+
keyPath string
5962
)
6063

6164
type AccessRecord struct {
@@ -126,6 +129,85 @@ var AcceptDomain = []string{
126129
"api.github.com",
127130
}
128131

132+
type CertContainer struct {
133+
CertPEM []byte
134+
KeyPEM []byte
135+
CertWatcher *fswatch.Watcher
136+
KeyWatcher *fswatch.Watcher
137+
Cert tls.Certificate
138+
}
139+
140+
func NewCertContainer() (*CertContainer, error) {
141+
CertPEM, err := os.ReadFile(certPath)
142+
if err != nil {
143+
return nil, E.Cause(err, "Read cert file")
144+
}
145+
KeyPEM, err := os.ReadFile(keyPath)
146+
if err != nil {
147+
return nil, E.Cause(err, "Read key file")
148+
}
149+
Cert, err := tls.X509KeyPair(CertPEM, KeyPEM)
150+
if err != nil {
151+
return nil, E.Cause(err, "Check key pair")
152+
}
153+
container := &CertContainer{
154+
CertPEM: CertPEM,
155+
KeyPEM: KeyPEM,
156+
Cert: Cert,
157+
}
158+
if CertWatcher, err := fswatch.NewWatcher(fswatch.Options{
159+
Path: []string{certPath},
160+
Callback: func(path string) {
161+
var err error
162+
CertPEM, err = os.ReadFile(certPath)
163+
if err != nil {
164+
return
165+
}
166+
container.CertPEM = CertPEM
167+
container.Update()
168+
},
169+
}); err == nil {
170+
err = CertWatcher.Start()
171+
if err == nil {
172+
container.CertWatcher = CertWatcher
173+
}
174+
}
175+
if KeyWatcher, err := fswatch.NewWatcher(fswatch.Options{
176+
Path: []string{keyPath},
177+
Callback: func(path string) {
178+
var err error
179+
KeyPEM, err = os.ReadFile(keyPath)
180+
if err == nil {
181+
return
182+
}
183+
container.KeyPEM = KeyPEM
184+
container.Update()
185+
},
186+
}); err == nil {
187+
err = KeyWatcher.Start()
188+
if err == nil {
189+
container.KeyWatcher = KeyWatcher
190+
}
191+
}
192+
return container, nil
193+
}
194+
195+
func (c *CertContainer) Update() {
196+
Cert, err := tls.X509KeyPair(c.CertPEM, c.KeyPEM)
197+
if err == nil {
198+
c.Cert = Cert
199+
}
200+
}
201+
202+
func (c *CertContainer) Close() {
203+
if c.CertWatcher != nil {
204+
c.CertWatcher.Close()
205+
}
206+
if c.KeyWatcher != nil {
207+
c.KeyWatcher.Close()
208+
}
209+
}
210+
129211
var command = &cobra.Command{
130212
Use: "git-proxy",
131213
Short: "A HTTP service to proxy git requests",
@@ -140,6 +222,8 @@ func init() {
140222
command.PersistentFlags().IntVarP(&bandwidthLimit, "bandwidth-limit", "l", 0, "set total bandwidth limit (MB/s), 0 as no limit")
141223
command.PersistentFlags().IntVarP(&requestLimit, "request-limit", "r", 0, "set request limit by ip, 0 as no limit")
142224
command.PersistentFlags().BoolVarP(&denyWebPage, "deny-web-page", "", false, "deny web page requests")
225+
command.PersistentFlags().StringVarP(&certPath, "cert-path", "c", "cert.pem", "set tls cert path")
226+
command.PersistentFlags().StringVarP(&keyPath, "key-path", "k", "key.pem", "set tls key path")
143227
}
144228

145229
func main() {
@@ -198,6 +282,20 @@ func run(*cobra.Command, []string) {
198282
}
199283
listen := M.ParseSocksaddr(":" + strconv.Itoa(runningPort))
200284
listener := listenTCP(listen)
285+
if len(certPath) == 0 || len(keyPath) == 0 {
286+
log.Info("Listening TCP port ", listen.Port)
287+
} else if container, err := NewCertContainer(); err != nil {
288+
log.Warn(E.Cause(err, "Update TCP to TLS"))
289+
log.Info("Listening TCP port ", listen.Port)
290+
} else {
291+
log.Info("Listening TLS port ", listen.Port)
292+
listener = tls.NewListener(listener, &tls.Config{
293+
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
294+
return &container.Cert, nil
295+
},
296+
})
297+
defer container.Close()
298+
}
201299
chiRouter := chi.NewRouter()
202300
chiRouter.Group(func(r chi.Router) {
203301
r.Use(middleware.RealIP)
@@ -425,7 +523,6 @@ func listenTCP(address M.Socksaddr) net.Listener {
425523
}
426524
address.Port = address.Port + 1
427525
}
428-
log.Info("Listening tcp port ", address.Port)
429526
return listener
430527
}
431528

0 commit comments

Comments
 (0)