Skip to content

Commit 45d17da

Browse files
Fix the bind problem by just recreating the dev
TODO: WHY CANT WE REBIND TO A PORT - WE NEED TO FIX THIS BETTER
1 parent dfba35f commit 45d17da

File tree

3 files changed

+185
-14
lines changed

3 files changed

+185
-14
lines changed

clients.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ func setupClientsNetstack(client *websocket.Client, host string) {
5454
}
5555
})
5656

57+
wgService.SetOnNetstackClose(func() {
58+
if wgTesterServer != nil {
59+
wgTesterServer.Stop()
60+
wgTesterServer = nil
61+
}
62+
})
63+
5764
client.OnTokenUpdate(func(token string) {
5865
wgService.SetToken(token)
5966
})

proxy/manager.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ func (pm *ProxyManager) Stop() error {
191191
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
192192
}
193193

194-
// Clear the target maps
195-
for k := range pm.tcpTargets {
196-
delete(pm.tcpTargets, k)
197-
}
198-
for k := range pm.udpTargets {
199-
delete(pm.udpTargets, k)
200-
}
194+
// // Clear the target maps
195+
// for k := range pm.tcpTargets {
196+
// delete(pm.tcpTargets, k)
197+
// }
198+
// for k := range pm.udpTargets {
199+
// delete(pm.udpTargets, k)
200+
// }
201201

202202
// Give active connections a chance to close gracefully
203203
time.Sleep(100 * time.Millisecond)
@@ -368,3 +368,23 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
368368
}
369369
}
370370
}
371+
372+
// write a function to print out the current targets in the ProxyManager
373+
func (pm *ProxyManager) PrintTargets() {
374+
pm.mutex.RLock()
375+
defer pm.mutex.RUnlock()
376+
377+
logger.Info("Current TCP Targets:")
378+
for listenIP, targets := range pm.tcpTargets {
379+
for port, targetAddr := range targets {
380+
logger.Info("TCP %s:%d -> %s", listenIP, port, targetAddr)
381+
}
382+
}
383+
384+
logger.Info("Current UDP Targets:")
385+
for listenIP, targets := range pm.udpTargets {
386+
for port, targetAddr := range targets {
387+
logger.Info("UDP %s:%d -> %s", listenIP, port, targetAddr)
388+
}
389+
}
390+
}

wgnetstack/wgnetstack.go

Lines changed: 151 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ type WireGuardService struct {
8585
dns []netip.Addr
8686
// Callback for when netstack is ready
8787
onNetstackReady func(*netstack.Net)
88+
// Callback for when netstack is closed
89+
onNetstackClose func()
8890
othertnet *netstack.Net
8991
// Proxy manager for tunnel
9092
proxyManager *proxy.ProxyManager
@@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
254256
}
255257

256258
if len(targetData.Targets) > 0 {
257-
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
259+
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
258260
}
259261
}
260262

@@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
274276
}
275277

276278
if len(targetData.Targets) > 0 {
277-
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
279+
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
278280
}
279281
}
280282

@@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
294296
}
295297

296298
if len(targetData.Targets) > 0 {
297-
updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
299+
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
298300
}
299301
}
300302

@@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
314316
}
315317

316318
if len(targetData.Targets) > 0 {
317-
updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
319+
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
318320
}
319321
}
320322

@@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
392394
s.onNetstackReady = callback
393395
}
394396

397+
func (s *WireGuardService) SetOnNetstackClose(callback func()) {
398+
s.onNetstackClose = callback
399+
}
400+
395401
func (s *WireGuardService) LoadRemoteConfig() error {
396402
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
397403
"publicKey": s.key.PublicKey().String(),
@@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
438444

439445
// add the targets if there are any
440446
if len(config.Targets.TCP) > 0 {
441-
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
447+
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
442448
}
443449

444450
if len(config.Targets.UDP) > 0 {
445-
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
451+
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
446452
}
447453

448454
// Create ProxyManager for this tunnel
@@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
10771083
}
10781084
}
10791085

1080-
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
1086+
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
1087+
var replace = true
10811088
for _, t := range targetData.Targets {
10821089
// Split the first number off of the target with : separator and use as the port
10831090
parts := strings.Split(t, ":")
@@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
11061113
// Ignore "target not found" errors as this is expected for new targets
11071114
if !strings.Contains(err.Error(), "target not found") {
11081115
logger.Error("Failed to remove existing target: %v", err)
1116+
} else {
1117+
replace = false // If we got here, it means the target didn't exist, so we can add it without replacing
11091118
}
11101119
}
11111120

@@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
11231132
}
11241133
}
11251134

1135+
if replace {
1136+
// If we replaced any targets, we need to hot swap the netstack
1137+
if err := s.ReplaceNetstack(s.dns); err != nil {
1138+
logger.Error("Failed to replace netstack after updating targets: %v", err)
1139+
return err
1140+
}
1141+
logger.Info("Netstack replaced successfully after updating targets")
1142+
} else {
1143+
logger.Info("No targets updated, no netstack replacement needed")
1144+
}
1145+
11261146
return nil
11271147
}
11281148

@@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) {
11401160
}
11411161
return targetData, nil
11421162
}
1163+
1164+
// Add this method to WireGuardService
1165+
func (s *WireGuardService) ReplaceNetstack(newDNS []netip.Addr) error {
1166+
s.mu.Lock()
1167+
defer s.mu.Unlock()
1168+
1169+
if s.device == nil || s.tun == nil {
1170+
return fmt.Errorf("WireGuard device not initialized")
1171+
}
1172+
1173+
// Parse the current tunnel IP from the existing config
1174+
parts := strings.Split(s.config.IpAddress, "/")
1175+
if len(parts) != 2 {
1176+
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
1177+
}
1178+
tunnelIP := netip.MustParseAddr(parts[0])
1179+
1180+
// Stop the proxy manager temporarily
1181+
s.proxyManager.Stop()
1182+
1183+
// Create new TUN device and netstack with new DNS
1184+
newTun, newTnet, err := netstack.CreateNetTUN(
1185+
[]netip.Addr{tunnelIP},
1186+
newDNS,
1187+
s.mtu)
1188+
if err != nil {
1189+
// Restart proxy manager with old tnet on failure
1190+
s.proxyManager.Start()
1191+
return fmt.Errorf("failed to create new TUN device: %v", err)
1192+
}
1193+
1194+
// Get current device config before closing
1195+
currentConfig, err := s.device.IpcGet()
1196+
if err != nil {
1197+
newTun.Close()
1198+
s.proxyManager.Start()
1199+
return fmt.Errorf("failed to get current device config: %v", err)
1200+
}
1201+
1202+
// Filter out read-only fields from the config
1203+
filteredConfig := s.filterReadOnlyFields(currentConfig)
1204+
1205+
// if onNetstackClose callback is set, call it
1206+
if s.onNetstackClose != nil {
1207+
s.onNetstackClose()
1208+
}
1209+
1210+
// Close old device (this closes the old TUN device)
1211+
s.device.Close()
1212+
1213+
// Update references
1214+
s.tun = newTun
1215+
s.tnet = newTnet
1216+
s.dns = newDNS
1217+
1218+
// Create new WireGuard device with same port
1219+
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
1220+
device.LogLevelSilent,
1221+
"wireguard: ",
1222+
))
1223+
1224+
// Restore the configuration (without read-only fields)
1225+
err = s.device.IpcSet(filteredConfig)
1226+
if err != nil {
1227+
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
1228+
}
1229+
1230+
// Bring up the device
1231+
err = s.device.Up()
1232+
if err != nil {
1233+
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
1234+
}
1235+
1236+
// Update proxy manager with new tnet and restart
1237+
s.proxyManager.SetTNet(s.tnet)
1238+
s.proxyManager.Start()
1239+
1240+
s.proxyManager.PrintTargets()
1241+
1242+
// Call the netstack ready callback if set
1243+
if s.onNetstackReady != nil {
1244+
go s.onNetstackReady(s.tnet)
1245+
}
1246+
1247+
logger.Info("Netstack replaced successfully with new DNS servers")
1248+
return nil
1249+
}
1250+
1251+
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
1252+
func (s *WireGuardService) filterReadOnlyFields(config string) string {
1253+
lines := strings.Split(config, "\n")
1254+
var filteredLines []string
1255+
1256+
// List of read-only fields that should not be included in IpcSet
1257+
readOnlyFields := map[string]bool{
1258+
"last_handshake_time_sec": true,
1259+
"last_handshake_time_nsec": true,
1260+
"rx_bytes": true,
1261+
"tx_bytes": true,
1262+
"protocol_version": true,
1263+
}
1264+
1265+
for _, line := range lines {
1266+
if line == "" {
1267+
continue
1268+
}
1269+
1270+
// Check if this line contains a read-only field
1271+
isReadOnly := false
1272+
for field := range readOnlyFields {
1273+
if strings.HasPrefix(line, field+"=") {
1274+
isReadOnly = true
1275+
break
1276+
}
1277+
}
1278+
1279+
// Only include non-read-only lines
1280+
if !isReadOnly {
1281+
filteredLines = append(filteredLines, line)
1282+
}
1283+
}
1284+
1285+
return strings.Join(filteredLines, "\n")
1286+
}

0 commit comments

Comments
 (0)