Forráskód Böngészése

feat(discovery): 优先按接收网卡返回管理地址

优化UDP发现逻辑,确保响应IP与请求网卡一致,提升多网卡环境下的连接准确性。
yangkaixiang 1 hónapja
szülő
commit
f6ccbebaf4

+ 1 - 1
docs/05-Server模块设计.md

@@ -88,7 +88,7 @@
 
 1. 监听 UDP 广播请求
 2. 解析发现报文
-3. 生成发现响应
+3. 优先按收到广播的本地网卡生成发现响应
 4. 返回设备基础信息
 
 边界:

+ 1 - 1
docs/07-Server首阶段实现清单.md

@@ -166,7 +166,7 @@
 
 1. 启动 UDP 监听
 2. 收包后读取 `deviceinfo`
-3. 收包后读取管理地址
+3. 收包后优先读取接收网卡上的管理地址
 4. 返回发现响应
 
 ## 6. 第一阶段建议数据结构

+ 1 - 1
server/internal/config/config.go

@@ -6,7 +6,7 @@ import (
 	"net"
 )
 
-const ServerVersion = "2026.05.14.1353"
+const ServerVersion = "2026.05.14.1741"
 
 type Config struct {
 	HTTPHost         string

+ 65 - 14
server/internal/discovery/discovery.go

@@ -40,9 +40,13 @@ func (s *Server) Run(ctx context.Context) error {
 	}()
 
 	s.log.Info("udp discovery listening", "addr", conn.LocalAddr().String())
+	if err := enableLocalAddrControl(conn); err != nil {
+		s.log.Warn("udp discovery local interface detection is unavailable", "error", err.Error())
+	}
+
 	buf := make([]byte, 2048)
 	for {
-		n, remote, err := conn.ReadFromUDP(buf)
+		n, remote, packetInfo, err := readFromUDP(conn, buf)
 		if err != nil {
 			if ctx.Err() != nil {
 				return nil
@@ -55,7 +59,7 @@ func (s *Server) Run(ctx context.Context) error {
 			continue
 		}
 
-		lan2IP, mac := s.maintenanceEndpoint()
+		lan2IP, mac := s.maintenanceEndpoint(packetInfo)
 		if lan2IP == "" {
 			s.log.Warn("skip discovery response because no 169.254 maintenance address was found")
 			continue
@@ -79,36 +83,83 @@ func (s *Server) Run(ctx context.Context) error {
 	}
 }
 
-func (s *Server) maintenanceEndpoint() (string, string) {
+type udpPacketInfo struct {
+	localIP net.IP
+	ifIndex int
+}
+
+func (s *Server) maintenanceEndpoint(packetInfo udpPacketInfo) (string, string) {
 	if s.cfg.MaintenanceIP != "" {
 		mac := findMACByIP(s.cfg.MaintenanceIP)
 		if mac != "" {
 			return s.cfg.MaintenanceIP, mac
 		}
 	}
+	if packetInfo.ifIndex > 0 {
+		lan2IP, mac := findLinkLocalEndpointByInterfaceIndex(packetInfo.ifIndex)
+		if lan2IP != "" {
+			return lan2IP, mac
+		}
+	}
+	if packetInfo.localIP != nil {
+		lan2IP, mac := findLinkLocalEndpointByInterfaceIP(packetInfo.localIP.String())
+		if lan2IP != "" {
+			return lan2IP, mac
+		}
+	}
 	return findFirstLinkLocalEndpoint()
 }
 
+func findLinkLocalEndpointByInterfaceIndex(index int) (string, string) {
+	iface, err := net.InterfaceByIndex(index)
+	if err != nil {
+		return "", ""
+	}
+	return findLinkLocalEndpointOnInterface(*iface)
+}
+
+func findLinkLocalEndpointByInterfaceIP(ip string) (string, string) {
+	ifaces, err := net.Interfaces()
+	if err != nil {
+		return "", ""
+	}
+	for _, iface := range ifaces {
+		lan2IP, mac := findLinkLocalEndpointOnInterface(iface)
+		if lan2IP == ip {
+			return lan2IP, mac
+		}
+	}
+	return "", ""
+}
+
 func findFirstLinkLocalEndpoint() (string, string) {
 	ifaces, err := net.Interfaces()
 	if err != nil {
 		return "", ""
 	}
 	for _, iface := range ifaces {
-		if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 || len(iface.HardwareAddr) == 0 {
-			continue
+		lan2IP, mac := findLinkLocalEndpointOnInterface(iface)
+		if lan2IP != "" {
+			return lan2IP, mac
 		}
-		addrs, err := iface.Addrs()
-		if err != nil {
+	}
+	return "", ""
+}
+
+func findLinkLocalEndpointOnInterface(iface net.Interface) (string, string) {
+	if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 || len(iface.HardwareAddr) == 0 {
+		return "", ""
+	}
+	addrs, err := iface.Addrs()
+	if err != nil {
+		return "", ""
+	}
+	for _, addr := range addrs {
+		current := ipv4FromAddr(addr)
+		if current == nil || !strings.HasPrefix(current.String(), "169.254.") {
 			continue
 		}
-		for _, addr := range addrs {
-			current := ipv4FromAddr(addr)
-			if current == nil || !strings.HasPrefix(current.String(), "169.254.") {
-				continue
-			}
-			return current.String(), iface.HardwareAddr.String()
-		}
+		return current.String(), iface.HardwareAddr.String()
 	}
 	return "", ""
 }

+ 67 - 0
server/internal/discovery/packetinfo_linux.go

@@ -0,0 +1,67 @@
+//go:build linux
+
+package discovery
+
+import (
+	"encoding/binary"
+	"net"
+	"syscall"
+)
+
+func enableLocalAddrControl(conn *net.UDPConn) error {
+	rawConn, err := conn.SyscallConn()
+	if err != nil {
+		return err
+	}
+	var controlErr error
+	err = rawConn.Control(func(fd uintptr) {
+		controlErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1)
+	})
+	if err != nil {
+		return err
+	}
+	return controlErr
+}
+
+func readFromUDP(conn *net.UDPConn, buf []byte) (int, *net.UDPAddr, udpPacketInfo, error) {
+	oob := make([]byte, 128)
+	n, oobn, _, remote, err := conn.ReadMsgUDP(buf, oob)
+	if err != nil {
+		return 0, nil, udpPacketInfo{}, err
+	}
+	packetInfo := parsePacketInfo(oob[:oobn])
+	return n, remote, packetInfo, nil
+}
+
+func parsePacketInfo(oob []byte) udpPacketInfo {
+	messages, err := syscall.ParseSocketControlMessage(oob)
+	if err != nil {
+		return udpPacketInfo{}
+	}
+	for _, message := range messages {
+		if message.Header.Level != syscall.IPPROTO_IP || message.Header.Type != syscall.IP_PKTINFO || len(message.Data) < 12 {
+			continue
+		}
+		specDst := [4]byte(message.Data[4:8])
+		addr := [4]byte(message.Data[8:12])
+		return udpPacketInfo{
+			localIP: packetInfoIP(specDst, addr),
+			ifIndex: int(int32(binary.LittleEndian.Uint32(message.Data[0:4]))),
+		}
+	}
+	return udpPacketInfo{}
+}
+
+func packetInfoIP(specDst [4]byte, addr [4]byte) net.IP {
+	if ip := ipv4FromBytes(specDst); ip != nil {
+		return ip
+	}
+	return ipv4FromBytes(addr)
+}
+
+func ipv4FromBytes(value [4]byte) net.IP {
+	if value == [4]byte{} || value == [4]byte{255, 255, 255, 255} {
+		return nil
+	}
+	return net.IPv4(value[0], value[1], value[2], value[3]).To4()
+}

+ 14 - 0
server/internal/discovery/packetinfo_other.go

@@ -0,0 +1,14 @@
+//go:build !linux
+
+package discovery
+
+import "net"
+
+func enableLocalAddrControl(conn *net.UDPConn) error {
+	return nil
+}
+
+func readFromUDP(conn *net.UDPConn, buf []byte) (int, *net.UDPAddr, udpPacketInfo, error) {
+	n, remote, err := conn.ReadFromUDP(buf)
+	return n, remote, udpPacketInfo{}, err
+}