@@ -85,6 +85,8 @@ type WireGuardService struct {
8585 dns []netip.Addr
8686 // Callback for when netstack is ready
8787 onNetstackReady func (* netstack.Net )
88+ // Callback for when netstack is closed
89+ onNetstackClose func ()
8890 othertnet * netstack.Net
8991 // Proxy manager for tunnel
9092 proxyManager * proxy.ProxyManager
@@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
254256 }
255257
256258 if len (targetData .Targets ) > 0 {
257- updateTargets (s .proxyManager , "add" , s .TunnelIP , "tcp" , targetData )
259+ s . updateTargets (s .proxyManager , "add" , s .TunnelIP , "tcp" , targetData )
258260 }
259261}
260262
@@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
274276 }
275277
276278 if len (targetData .Targets ) > 0 {
277- updateTargets (s .proxyManager , "add" , s .TunnelIP , "udp" , targetData )
279+ s . updateTargets (s .proxyManager , "add" , s .TunnelIP , "udp" , targetData )
278280 }
279281}
280282
@@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
294296 }
295297
296298 if len (targetData .Targets ) > 0 {
297- updateTargets (s .proxyManager , "remove" , s .TunnelIP , "udp" , targetData )
299+ s . updateTargets (s .proxyManager , "remove" , s .TunnelIP , "udp" , targetData )
298300 }
299301}
300302
@@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
314316 }
315317
316318 if len (targetData .Targets ) > 0 {
317- updateTargets (s .proxyManager , "remove" , s .TunnelIP , "tcp" , targetData )
319+ s . updateTargets (s .proxyManager , "remove" , s .TunnelIP , "tcp" , targetData )
318320 }
319321}
320322
@@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
392394 s .onNetstackReady = callback
393395}
394396
397+ func (s * WireGuardService ) SetOnNetstackClose (callback func ()) {
398+ s .onNetstackClose = callback
399+ }
400+
395401func (s * WireGuardService ) LoadRemoteConfig () error {
396402 s .stopGetConfig = s .client .SendMessageInterval ("newt/wg/get-config" , map [string ]interface {}{
397403 "publicKey" : s .key .PublicKey ().String (),
@@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
438444
439445 // add the targets if there are any
440446 if len (config .Targets .TCP ) > 0 {
441- updateTargets (s .proxyManager , "add" , s .TunnelIP , "tcp" , TargetData {Targets : config .Targets .TCP })
447+ s . updateTargets (s .proxyManager , "add" , s .TunnelIP , "tcp" , TargetData {Targets : config .Targets .TCP })
442448 }
443449
444450 if len (config .Targets .UDP ) > 0 {
445- updateTargets (s .proxyManager , "add" , s .TunnelIP , "udp" , TargetData {Targets : config .Targets .UDP })
451+ s . updateTargets (s .proxyManager , "add" , s .TunnelIP , "udp" , TargetData {Targets : config .Targets .UDP })
446452 }
447453
448454 // Create ProxyManager for this tunnel
@@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
10771083 }
10781084}
10791085
1080- func updateTargets (pm * proxy.ProxyManager , action string , tunnelIP string , proto string , targetData TargetData ) error {
1086+ func (s * WireGuardService ) updateTargets (pm * proxy.ProxyManager , action string , tunnelIP string , proto string , targetData TargetData ) error {
1087+ var replace = true
10811088 for _ , t := range targetData .Targets {
10821089 // Split the first number off of the target with : separator and use as the port
10831090 parts := strings .Split (t , ":" )
@@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
11061113 // Ignore "target not found" errors as this is expected for new targets
11071114 if ! strings .Contains (err .Error (), "target not found" ) {
11081115 logger .Error ("Failed to remove existing target: %v" , err )
1116+ } else {
1117+ replace = false // If we got here, it means the target didn't exist, so we can add it without replacing
11091118 }
11101119 }
11111120
@@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
11231132 }
11241133 }
11251134
1135+ if replace {
1136+ // If we replaced any targets, we need to hot swap the netstack
1137+ if err := s .ReplaceNetstack (s .dns ); err != nil {
1138+ logger .Error ("Failed to replace netstack after updating targets: %v" , err )
1139+ return err
1140+ }
1141+ logger .Info ("Netstack replaced successfully after updating targets" )
1142+ } else {
1143+ logger .Info ("No targets updated, no netstack replacement needed" )
1144+ }
1145+
11261146 return nil
11271147}
11281148
@@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) {
11401160 }
11411161 return targetData , nil
11421162}
1163+
1164+ // Add this method to WireGuardService
1165+ func (s * WireGuardService ) ReplaceNetstack (newDNS []netip.Addr ) error {
1166+ s .mu .Lock ()
1167+ defer s .mu .Unlock ()
1168+
1169+ if s .device == nil || s .tun == nil {
1170+ return fmt .Errorf ("WireGuard device not initialized" )
1171+ }
1172+
1173+ // Parse the current tunnel IP from the existing config
1174+ parts := strings .Split (s .config .IpAddress , "/" )
1175+ if len (parts ) != 2 {
1176+ return fmt .Errorf ("invalid IP address format: %s" , s .config .IpAddress )
1177+ }
1178+ tunnelIP := netip .MustParseAddr (parts [0 ])
1179+
1180+ // Stop the proxy manager temporarily
1181+ s .proxyManager .Stop ()
1182+
1183+ // Create new TUN device and netstack with new DNS
1184+ newTun , newTnet , err := netstack .CreateNetTUN (
1185+ []netip.Addr {tunnelIP },
1186+ newDNS ,
1187+ s .mtu )
1188+ if err != nil {
1189+ // Restart proxy manager with old tnet on failure
1190+ s .proxyManager .Start ()
1191+ return fmt .Errorf ("failed to create new TUN device: %v" , err )
1192+ }
1193+
1194+ // Get current device config before closing
1195+ currentConfig , err := s .device .IpcGet ()
1196+ if err != nil {
1197+ newTun .Close ()
1198+ s .proxyManager .Start ()
1199+ return fmt .Errorf ("failed to get current device config: %v" , err )
1200+ }
1201+
1202+ // Filter out read-only fields from the config
1203+ filteredConfig := s .filterReadOnlyFields (currentConfig )
1204+
1205+ // if onNetstackClose callback is set, call it
1206+ if s .onNetstackClose != nil {
1207+ s .onNetstackClose ()
1208+ }
1209+
1210+ // Close old device (this closes the old TUN device)
1211+ s .device .Close ()
1212+
1213+ // Update references
1214+ s .tun = newTun
1215+ s .tnet = newTnet
1216+ s .dns = newDNS
1217+
1218+ // Create new WireGuard device with same port
1219+ s .device = device .NewDevice (s .tun , NewFixedPortBind (s .Port ), device .NewLogger (
1220+ device .LogLevelSilent ,
1221+ "wireguard: " ,
1222+ ))
1223+
1224+ // Restore the configuration (without read-only fields)
1225+ err = s .device .IpcSet (filteredConfig )
1226+ if err != nil {
1227+ return fmt .Errorf ("failed to restore WireGuard configuration: %v" , err )
1228+ }
1229+
1230+ // Bring up the device
1231+ err = s .device .Up ()
1232+ if err != nil {
1233+ return fmt .Errorf ("failed to bring up new WireGuard device: %v" , err )
1234+ }
1235+
1236+ // Update proxy manager with new tnet and restart
1237+ s .proxyManager .SetTNet (s .tnet )
1238+ s .proxyManager .Start ()
1239+
1240+ s .proxyManager .PrintTargets ()
1241+
1242+ // Call the netstack ready callback if set
1243+ if s .onNetstackReady != nil {
1244+ go s .onNetstackReady (s .tnet )
1245+ }
1246+
1247+ logger .Info ("Netstack replaced successfully with new DNS servers" )
1248+ return nil
1249+ }
1250+
1251+ // filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
1252+ func (s * WireGuardService ) filterReadOnlyFields (config string ) string {
1253+ lines := strings .Split (config , "\n " )
1254+ var filteredLines []string
1255+
1256+ // List of read-only fields that should not be included in IpcSet
1257+ readOnlyFields := map [string ]bool {
1258+ "last_handshake_time_sec" : true ,
1259+ "last_handshake_time_nsec" : true ,
1260+ "rx_bytes" : true ,
1261+ "tx_bytes" : true ,
1262+ "protocol_version" : true ,
1263+ }
1264+
1265+ for _ , line := range lines {
1266+ if line == "" {
1267+ continue
1268+ }
1269+
1270+ // Check if this line contains a read-only field
1271+ isReadOnly := false
1272+ for field := range readOnlyFields {
1273+ if strings .HasPrefix (line , field + "=" ) {
1274+ isReadOnly = true
1275+ break
1276+ }
1277+ }
1278+
1279+ // Only include non-read-only lines
1280+ if ! isReadOnly {
1281+ filteredLines = append (filteredLines , line )
1282+ }
1283+ }
1284+
1285+ return strings .Join (filteredLines , "\n " )
1286+ }
0 commit comments