diff --git a/cmd/proxy-buffer/buffer.go b/cmd/proxy-buffer/buffer.go index da95fcb..7f77db4 100644 --- a/cmd/proxy-buffer/buffer.go +++ b/cmd/proxy-buffer/buffer.go @@ -281,6 +281,41 @@ func (pb *PacketBuffer) deserializePacket(data []byte) (*packet.DataPacket, erro return dataPacket, nil } +// ProcessErrorPacket processes an error packet and returns a BufferedPacket for forwarding. +// Error packets are not buffered or fragmented - they fit in one MTU and are forwarded directly. +func (pb *PacketBuffer) ProcessErrorPacket(data []byte, src *net.UDPAddr) (*util.BufferedPacket, error) { + // Deserialize error packet + codec := &packet.ErrorPacketCodec{} + packetAny, err := codec.Deserialize(data) + if err != nil { + return nil, fmt.Errorf("failed to deserialize error packet: %w", err) + } + + errorPacket, ok := packetAny.(*packet.ErrorPacket) + if !ok { + return nil, fmt.Errorf("unexpected packet type in ProcessErrorPacket") + } + + // Create destination address from the packet's routing info + peer := &net.UDPAddr{IP: net.IP(errorPacket.DstIP[:]), Port: int(errorPacket.DstPort)} + + // Create BufferedPacket with error message as payload + bufferedPacket := &util.BufferedPacket{ + Payload: []byte(errorPacket.ErrorMsg), + Source: src, + Peer: peer, + PacketType: util.PacketTypeError, + RPCID: errorPacket.RPCID, + DstIP: errorPacket.DstIP, + DstPort: errorPacket.DstPort, + SrcIP: errorPacket.SrcIP, + SrcPort: errorPacket.SrcPort, + TotalPackets: 1, + } + + return bufferedPacket, nil +} + // cleanupRoutine periodically cleans up expired fragments func (pb *PacketBuffer) cleanupRoutine() { for { diff --git a/cmd/proxy-buffer/main.go b/cmd/proxy-buffer/main.go index 78ba561..9a204d3 100644 --- a/cmd/proxy-buffer/main.go +++ b/cmd/proxy-buffer/main.go @@ -12,6 +12,7 @@ import ( "github.com/appnet-org/arpc/cmd/proxy-buffer/util" "github.com/appnet-org/arpc/pkg/logging" + "github.com/appnet-org/arpc/pkg/packet" "github.com/appnet-org/arpc/pkg/transport" "go.uber.org/zap" ) @@ -205,6 +206,48 @@ func runProxyServer(port int, state *ProxyState, config *Config) error { func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data []byte, config *Config) { ctx := context.Background() + // Check if this is an error packet (PacketTypeID == 3) + if len(data) > 0 && data[0] == byte(packet.PacketTypeError.TypeID) { + // Process error packet - forward directly without element chain + bufferedPacket, err := state.packetBuffer.ProcessErrorPacket(data, src) + if err != nil { + logging.Error("Error processing error packet", zap.Error(err)) + return + } + + // Serialize the error packet for forwarding + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: bufferedPacket.RPCID, + DstIP: bufferedPacket.DstIP, + DstPort: bufferedPacket.DstPort, + SrcIP: bufferedPacket.SrcIP, + SrcPort: bufferedPacket.SrcPort, + ErrorMsg: string(bufferedPacket.Payload), + } + + codec := &packet.ErrorPacketCodec{} + serialized, err := codec.Serialize(errorPacket, nil) + if err != nil { + logging.Error("Failed to serialize error packet for forwarding", zap.Error(err)) + return + } + + // Forward the error packet to the destination + if _, err := conn.WriteToUDP(serialized, bufferedPacket.Peer); err != nil { + logging.Error("Failed to forward error packet", zap.Error(err)) + return + } + + logging.Debug("Forwarded error packet", + zap.Uint64("rpcID", bufferedPacket.RPCID), + zap.String("from", bufferedPacket.Source.String()), + zap.String("to", bufferedPacket.Peer.String()), + zap.String("errorMsg", string(bufferedPacket.Payload))) + + return + } + // Process packet - returns nil if still buffering fragments // Returns a complete BufferedPacket only when ALL fragments have been received bufferedPacket, err := state.packetBuffer.ProcessPacket(data, src) @@ -252,7 +295,7 @@ func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data [ logging.Error("Error processing packet through element chain or packet was dropped", zap.Error(err)) // Send error packet back to the source - if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error()); sendErr != nil { + if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error(), bufferedPacket.SrcIP, bufferedPacket.SrcPort, bufferedPacket.DstIP, bufferedPacket.DstPort); sendErr != nil { logging.Error("Failed to send error packet", zap.Error(sendErr)) } return diff --git a/cmd/proxy-buffer/util/send_packet.go b/cmd/proxy-buffer/util/send_packet.go index 51f8d77..8a2b358 100644 --- a/cmd/proxy-buffer/util/send_packet.go +++ b/cmd/proxy-buffer/util/send_packet.go @@ -9,12 +9,16 @@ import ( "go.uber.org/zap" ) -// sendErrorPacket sends an error packet back to the source -func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string) error { +// SendErrorPacket sends an error packet back to the source with routing information +func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string, dstIP [4]byte, dstPort uint16, srcIP [4]byte, srcPort uint16) error { // Create error packet errorPacket := &packet.ErrorPacket{ PacketTypeID: packet.PacketTypeError.TypeID, RPCID: rpcID, + DstIP: dstIP, + DstPort: dstPort, + SrcIP: srcIP, + SrcPort: srcPort, ErrorMsg: errorMsg, } @@ -37,4 +41,3 @@ func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMs return nil } - diff --git a/cmd/proxy/buffer.go b/cmd/proxy/buffer.go index 67d6ad1..1c96ca6 100644 --- a/cmd/proxy/buffer.go +++ b/cmd/proxy/buffer.go @@ -429,6 +429,43 @@ func (pb *PacketBuffer) deserializePacket(data []byte) (*packet.DataPacket, erro return dataPacket, nil } +// ProcessErrorPacket processes an error packet and returns a BufferedPacket for forwarding. +// Error packets are not buffered or fragmented - they fit in one MTU and are forwarded directly. +func (pb *PacketBuffer) ProcessErrorPacket(data []byte, src *net.UDPAddr) (*util.BufferedPacket, error) { + // Deserialize error packet + codec := &packet.ErrorPacketCodec{} + packetAny, err := codec.Deserialize(data) + if err != nil { + return nil, fmt.Errorf("failed to deserialize error packet: %w", err) + } + + errorPacket, ok := packetAny.(*packet.ErrorPacket) + if !ok { + return nil, fmt.Errorf("unexpected packet type in ProcessErrorPacket") + } + + // Create destination address from the packet's routing info + peer := &net.UDPAddr{IP: net.IP(errorPacket.DstIP[:]), Port: int(errorPacket.DstPort)} + + // Create BufferedPacket with error message as payload + bufferedPacket := &util.BufferedPacket{ + Payload: []byte(errorPacket.ErrorMsg), + Source: src, + Peer: peer, + PacketType: util.PacketTypeError, + RPCID: errorPacket.RPCID, + DstIP: errorPacket.DstIP, + DstPort: errorPacket.DstPort, + SrcIP: errorPacket.SrcIP, + SrcPort: errorPacket.SrcPort, + IsFull: true, + SeqNumber: -1, + TotalPackets: 1, + } + + return bufferedPacket, nil +} + // CleanupUsedFragments removes fragments up to and including the lastUsedSeqNum // This should be called after the public segment has been successfully forwarded func (pb *PacketBuffer) CleanupUsedFragments(connKey string, rpcID uint64, lastUsedSeqNum uint16) { diff --git a/cmd/proxy/buffer_test.go b/cmd/proxy/buffer_test.go index 7d75bc0..a0ee944 100644 --- a/cmd/proxy/buffer_test.go +++ b/cmd/proxy/buffer_test.go @@ -1368,3 +1368,194 @@ func TestFragmentPacketForForward_LastUsedSeqNum_NormalCase(t *testing.T) { t.Log("SUCCESS: Normal case without packing works correctly") } + +// TestErrorPacketCodec_SerializeDeserialize tests error packet serialization and deserialization with IP/port fields +func TestErrorPacketCodec_SerializeDeserialize(t *testing.T) { + codec := &packet.ErrorPacketCodec{} + + // Create an error packet with routing information + originalPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: uint64(12345), + DstIP: [4]byte{192, 168, 1, 100}, + DstPort: 8080, + SrcIP: [4]byte{10, 0, 0, 50}, + SrcPort: 9090, + ErrorMsg: "Test error message", + } + + // Serialize the packet + serialized, err := codec.Serialize(originalPacket, nil) + if err != nil { + t.Fatalf("Failed to serialize error packet: %v", err) + } + + // Verify serialized data has correct length (29 bytes header + message length) + expectedLen := 29 + len("Test error message") + if len(serialized) != expectedLen { + t.Errorf("Expected serialized length %d, got %d", expectedLen, len(serialized)) + } + + // Deserialize the packet + deserializedAny, err := codec.Deserialize(serialized) + if err != nil { + t.Fatalf("Failed to deserialize error packet: %v", err) + } + + deserializedPacket, ok := deserializedAny.(*packet.ErrorPacket) + if !ok { + t.Fatalf("Deserialized packet is not an ErrorPacket") + } + + // Verify all fields match + if deserializedPacket.PacketTypeID != originalPacket.PacketTypeID { + t.Errorf("PacketTypeID mismatch: expected %d, got %d", originalPacket.PacketTypeID, deserializedPacket.PacketTypeID) + } + if deserializedPacket.RPCID != originalPacket.RPCID { + t.Errorf("RPCID mismatch: expected %d, got %d", originalPacket.RPCID, deserializedPacket.RPCID) + } + if deserializedPacket.DstIP != originalPacket.DstIP { + t.Errorf("DstIP mismatch: expected %v, got %v", originalPacket.DstIP, deserializedPacket.DstIP) + } + if deserializedPacket.DstPort != originalPacket.DstPort { + t.Errorf("DstPort mismatch: expected %d, got %d", originalPacket.DstPort, deserializedPacket.DstPort) + } + if deserializedPacket.SrcIP != originalPacket.SrcIP { + t.Errorf("SrcIP mismatch: expected %v, got %v", originalPacket.SrcIP, deserializedPacket.SrcIP) + } + if deserializedPacket.SrcPort != originalPacket.SrcPort { + t.Errorf("SrcPort mismatch: expected %d, got %d", originalPacket.SrcPort, deserializedPacket.SrcPort) + } + if deserializedPacket.ErrorMsg != originalPacket.ErrorMsg { + t.Errorf("ErrorMsg mismatch: expected %s, got %s", originalPacket.ErrorMsg, deserializedPacket.ErrorMsg) + } +} + +// TestErrorPacketCodec_MTUValidation tests that error messages exceeding MTU are rejected +func TestErrorPacketCodec_MTUValidation(t *testing.T) { + codec := &packet.ErrorPacketCodec{} + + // Create an error packet with a message that's too long + // MaxUDPPayloadSize is typically 1400, header is 29 bytes, so max message is 1371 bytes + longMessage := make([]byte, packet.MaxUDPPayloadSize-28) // One byte too long + for i := range longMessage { + longMessage[i] = 'A' + } + + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: uint64(12345), + DstIP: [4]byte{192, 168, 1, 100}, + DstPort: 8080, + SrcIP: [4]byte{10, 0, 0, 50}, + SrcPort: 9090, + ErrorMsg: string(longMessage), + } + + // Attempt to serialize - should fail + _, err := codec.Serialize(errorPacket, nil) + if err == nil { + t.Error("Expected error when serializing message that exceeds MTU, but got none") + } + if err != nil && err.Error() != "error message too long, must fit in one MTU" { + t.Errorf("Expected MTU error message, got: %v", err) + } + + // Test with a message that fits + validMessage := make([]byte, packet.MaxUDPPayloadSize-30) // Should fit + for i := range validMessage { + validMessage[i] = 'B' + } + errorPacket.ErrorMsg = string(validMessage) + + _, err = codec.Serialize(errorPacket, nil) + if err != nil { + t.Errorf("Expected no error for valid message size, got: %v", err) + } +} + +// TestPacketBuffer_ProcessErrorPacket tests the ProcessErrorPacket method +func TestPacketBuffer_ProcessErrorPacket(t *testing.T) { + pb := NewPacketBuffer(5 * time.Second) + defer pb.Close() + + // Create and serialize an error packet + codec := &packet.ErrorPacketCodec{} + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: uint64(99999), + DstIP: [4]byte{192, 168, 1, 200}, + DstPort: 7070, + SrcIP: [4]byte{10, 0, 0, 100}, + SrcPort: 6060, + ErrorMsg: "Connection timeout", + } + + serialized, err := codec.Serialize(errorPacket, nil) + if err != nil { + t.Fatalf("Failed to serialize error packet: %v", err) + } + + // Process the error packet + src := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 100), Port: 6060} + bufferedPacket, err := pb.ProcessErrorPacket(serialized, src) + if err != nil { + t.Fatalf("Failed to process error packet: %v", err) + } + + // Verify the buffered packet + if bufferedPacket == nil { + t.Fatal("Expected buffered packet, got nil") + } + + if bufferedPacket.RPCID != errorPacket.RPCID { + t.Errorf("RPCID mismatch: expected %d, got %d", errorPacket.RPCID, bufferedPacket.RPCID) + } + + if bufferedPacket.PacketType != util.PacketTypeError { + t.Errorf("PacketType mismatch: expected Error, got %v", bufferedPacket.PacketType) + } + + if string(bufferedPacket.Payload) != errorPacket.ErrorMsg { + t.Errorf("Payload mismatch: expected %s, got %s", errorPacket.ErrorMsg, string(bufferedPacket.Payload)) + } + + if bufferedPacket.DstIP != errorPacket.DstIP { + t.Errorf("DstIP mismatch: expected %v, got %v", errorPacket.DstIP, bufferedPacket.DstIP) + } + + if bufferedPacket.DstPort != errorPacket.DstPort { + t.Errorf("DstPort mismatch: expected %d, got %d", errorPacket.DstPort, bufferedPacket.DstPort) + } + + if bufferedPacket.SrcIP != errorPacket.SrcIP { + t.Errorf("SrcIP mismatch: expected %v, got %v", errorPacket.SrcIP, bufferedPacket.SrcIP) + } + + if bufferedPacket.SrcPort != errorPacket.SrcPort { + t.Errorf("SrcPort mismatch: expected %d, got %d", errorPacket.SrcPort, bufferedPacket.SrcPort) + } + + // Verify peer address is constructed from DstIP and DstPort + expectedPeerIP := net.IP(errorPacket.DstIP[:]) + if !bufferedPacket.Peer.IP.Equal(expectedPeerIP) { + t.Errorf("Peer IP mismatch: expected %v, got %v", expectedPeerIP, bufferedPacket.Peer.IP) + } + + if bufferedPacket.Peer.Port != int(errorPacket.DstPort) { + t.Errorf("Peer Port mismatch: expected %d, got %d", errorPacket.DstPort, bufferedPacket.Peer.Port) + } + + // Verify packet is marked as full (no fragmentation) + if !bufferedPacket.IsFull { + t.Error("Expected IsFull to be true for error packet") + } + + if bufferedPacket.SeqNumber != -1 { + t.Errorf("Expected SeqNumber -1, got %d", bufferedPacket.SeqNumber) + } + + if bufferedPacket.TotalPackets != 1 { + t.Errorf("Expected TotalPackets 1, got %d", bufferedPacket.TotalPackets) + } +} diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 32c15f2..02454f4 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -12,6 +12,7 @@ import ( "github.com/appnet-org/arpc/cmd/proxy/util" "github.com/appnet-org/arpc/pkg/logging" + "github.com/appnet-org/arpc/pkg/packet" "github.com/appnet-org/arpc/pkg/transport" "go.uber.org/zap" ) @@ -71,7 +72,7 @@ func (c *Config) SetEncryption(key []byte) { func getLoggingConfig() *logging.Config { level := os.Getenv("LOG_LEVEL") if level == "" { - level = "info" + level = "debug" } format := os.Getenv("LOG_FORMAT") @@ -208,6 +209,48 @@ func runProxyServer(port int, state *ProxyState, config *Config) error { func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data []byte, config *Config) { ctx := context.Background() + // Check if this is an error packet (PacketTypeID == 3) + if len(data) > 0 && data[0] == byte(packet.PacketTypeError.TypeID) { + // Process error packet - forward directly without element chain + bufferedPacket, err := state.packetBuffer.ProcessErrorPacket(data, src) + if err != nil { + logging.Error("Error processing error packet", zap.Error(err)) + return + } + + // Serialize the error packet for forwarding + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: bufferedPacket.RPCID, + DstIP: bufferedPacket.DstIP, + DstPort: bufferedPacket.DstPort, + SrcIP: bufferedPacket.SrcIP, + SrcPort: bufferedPacket.SrcPort, + ErrorMsg: string(bufferedPacket.Payload), + } + + codec := &packet.ErrorPacketCodec{} + serialized, err := codec.Serialize(errorPacket, nil) + if err != nil { + logging.Error("Failed to serialize error packet for forwarding", zap.Error(err)) + return + } + + // Forward the error packet to the destination + if _, err := conn.WriteToUDP(serialized, bufferedPacket.Peer); err != nil { + logging.Error("Failed to forward error packet", zap.Error(err)) + return + } + + logging.Debug("Forwarded error packet", + zap.Uint64("rpcID", bufferedPacket.RPCID), + zap.String("from", bufferedPacket.Source.String()), + zap.String("to", bufferedPacket.Peer.String()), + zap.String("errorMsg", string(bufferedPacket.Payload))) + + return + } + // Process packet (may return nil if still buffering fragments). // Returns a buffered packet when: // - We have enough data to cover the public segment, OR @@ -274,7 +317,7 @@ func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data [ if err != nil { logging.Error("Error processing packet through element chain or packet was dropped by an element", zap.Error(err)) // Send error packet back to the source - if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error()); sendErr != nil { + if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error(), bufferedPacket.SrcIP, bufferedPacket.SrcPort, bufferedPacket.DstIP, bufferedPacket.DstPort); sendErr != nil { logging.Error("Failed to send error packet", zap.Error(sendErr)) } return diff --git a/cmd/proxy/main_test.go b/cmd/proxy/main_test.go index 9af5ed8..0f17423 100644 --- a/cmd/proxy/main_test.go +++ b/cmd/proxy/main_test.go @@ -639,3 +639,145 @@ func TestLargeMessage_FragmentZeroDelayed(t *testing.T) { t.Log("SUCCESS: All fragments accounted for even with delayed fragment 0") } + +// TestHandlePacket_ErrorPacketDetection tests that error packets are detected correctly +func TestHandlePacket_ErrorPacketDetection(t *testing.T) { + // Create proxy state + state := &ProxyState{ + elementChain: NewRPCElementChain(), + packetBuffer: NewPacketBuffer(5 * time.Second), + } + defer state.packetBuffer.Close() + + // Create an error packet + codec := &packet.ErrorPacketCodec{} + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: uint64(77777), + DstIP: [4]byte{127, 0, 0, 1}, + DstPort: 8080, + SrcIP: [4]byte{192, 168, 1, 50}, + SrcPort: 5050, + ErrorMsg: "Test error detection", + } + + serialized, err := codec.Serialize(errorPacket, nil) + if err != nil { + t.Fatalf("Failed to serialize error packet: %v", err) + } + + // Verify the packet type ID is correct + if len(serialized) == 0 { + t.Fatal("Serialized packet is empty") + } + + if serialized[0] != byte(packet.PacketTypeError.TypeID) { + t.Errorf("Expected packet type ID %d, got %d", packet.PacketTypeError.TypeID, serialized[0]) + } + + // Process the error packet through the buffer + src := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 50), Port: 5050} + bufferedPacket, err := state.packetBuffer.ProcessErrorPacket(serialized, src) + if err != nil { + t.Fatalf("Failed to process error packet: %v", err) + } + + // Verify the buffered packet has correct routing info + if bufferedPacket.PacketType != util.PacketTypeError { + t.Errorf("Expected PacketType Error, got %v", bufferedPacket.PacketType) + } + + if bufferedPacket.RPCID != errorPacket.RPCID { + t.Errorf("RPCID mismatch: expected %d, got %d", errorPacket.RPCID, bufferedPacket.RPCID) + } + + if string(bufferedPacket.Payload) != errorPacket.ErrorMsg { + t.Errorf("ErrorMsg mismatch: expected %s, got %s", errorPacket.ErrorMsg, string(bufferedPacket.Payload)) + } + + // Verify routing fields + if bufferedPacket.DstIP != errorPacket.DstIP { + t.Errorf("DstIP mismatch: expected %v, got %v", errorPacket.DstIP, bufferedPacket.DstIP) + } + + if bufferedPacket.DstPort != errorPacket.DstPort { + t.Errorf("DstPort mismatch: expected %d, got %d", errorPacket.DstPort, bufferedPacket.DstPort) + } + + if bufferedPacket.SrcIP != errorPacket.SrcIP { + t.Errorf("SrcIP mismatch: expected %v, got %v", errorPacket.SrcIP, bufferedPacket.SrcIP) + } + + if bufferedPacket.SrcPort != errorPacket.SrcPort { + t.Errorf("SrcPort mismatch: expected %d, got %d", errorPacket.SrcPort, bufferedPacket.SrcPort) + } +} + +// TestSendErrorPacket_WithRoutingInfo tests SendErrorPacket creates packets with correct routing fields +func TestSendErrorPacket_WithRoutingInfo(t *testing.T) { + // Test that SendErrorPacket creates an error packet with routing information + // We verify this by creating the packet manually and checking the serialization + + rpcID := uint64(55555) + errorMsg := "SendErrorPacket test" + dstIP := [4]byte{10, 0, 0, 200} + dstPort := uint16(3030) + srcIP := [4]byte{192, 168, 10, 10} + srcPort := uint16(2020) + + // Create error packet manually (simulating what SendErrorPacket does) + errorPacket := &packet.ErrorPacket{ + PacketTypeID: packet.PacketTypeError.TypeID, + RPCID: rpcID, + DstIP: dstIP, + DstPort: dstPort, + SrcIP: srcIP, + SrcPort: srcPort, + ErrorMsg: errorMsg, + } + + // Serialize the packet + codec := &packet.ErrorPacketCodec{} + serialized, err := codec.Serialize(errorPacket, nil) + if err != nil { + t.Fatalf("Failed to serialize error packet: %v", err) + } + + // Deserialize to verify all fields are preserved + receivedAny, err := codec.Deserialize(serialized) + if err != nil { + t.Fatalf("Failed to deserialize error packet: %v", err) + } + + receivedPacket, ok := receivedAny.(*packet.ErrorPacket) + if !ok { + t.Fatal("Deserialized packet is not an ErrorPacket") + } + + // Verify all fields + if receivedPacket.RPCID != rpcID { + t.Errorf("RPCID mismatch: expected %d, got %d", rpcID, receivedPacket.RPCID) + } + + if receivedPacket.ErrorMsg != errorMsg { + t.Errorf("ErrorMsg mismatch: expected %s, got %s", errorMsg, receivedPacket.ErrorMsg) + } + + if receivedPacket.DstIP != dstIP { + t.Errorf("DstIP mismatch: expected %v, got %v", dstIP, receivedPacket.DstIP) + } + + if receivedPacket.DstPort != dstPort { + t.Errorf("DstPort mismatch: expected %d, got %d", dstPort, receivedPacket.DstPort) + } + + if receivedPacket.SrcIP != srcIP { + t.Errorf("SrcIP mismatch: expected %v, got %v", srcIP, receivedPacket.SrcIP) + } + + if receivedPacket.SrcPort != srcPort { + t.Errorf("SrcPort mismatch: expected %d, got %d", srcPort, receivedPacket.SrcPort) + } + + t.Log("SendErrorPacket routing fields verified successfully") +} diff --git a/cmd/proxy/util/send_packet.go b/cmd/proxy/util/send_packet.go index 96ac792..8a2b358 100644 --- a/cmd/proxy/util/send_packet.go +++ b/cmd/proxy/util/send_packet.go @@ -9,12 +9,16 @@ import ( "go.uber.org/zap" ) -// sendErrorPacket sends an error packet back to the source -func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string) error { +// SendErrorPacket sends an error packet back to the source with routing information +func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string, dstIP [4]byte, dstPort uint16, srcIP [4]byte, srcPort uint16) error { // Create error packet errorPacket := &packet.ErrorPacket{ PacketTypeID: packet.PacketTypeError.TypeID, RPCID: rpcID, + DstIP: dstIP, + DstPort: dstPort, + SrcIP: srcIP, + SrcPort: srcPort, ErrorMsg: errorMsg, } diff --git a/pkg/packet/builtin_packets.go b/pkg/packet/builtin_packets.go index 4e9a690..48f3332 100644 --- a/pkg/packet/builtin_packets.go +++ b/pkg/packet/builtin_packets.go @@ -42,11 +42,15 @@ type ResponsePacket struct { DataPacket } -// ErrorPacket has exactly two fields as specified +// ErrorPacket has routing information similar to DataPacket type ErrorPacket struct { PacketTypeID PacketTypeID - RPCID uint64 // RPC ID that caused the error - ErrorMsg string // Error message string (must fit in on one MTU) + RPCID uint64 // RPC ID that caused the error + DstIP [4]byte // Destination IP address (4 bytes) + DstPort uint16 // Destination port + SrcIP [4]byte // Source IP address (4 bytes) + SrcPort uint16 // Source port + ErrorMsg string // Error message string (must fit in one MTU) } // DataPacketCodec implements DataPacket serialization for both Request and Response packets @@ -159,7 +163,7 @@ func (c *DataPacketCodec) Deserialize(data []byte) (any, error) { type ErrorPacketCodec struct{} // Serialize encodes an ErrorPacket into binary format: -// [PacketTypeID(1B)][RPCID(8B)][MsgLen(4B)][Msg] +// [PacketTypeID(1B)][RPCID(8B)][DstIP(4B)][DstPort(2B)][SrcIP(4B)][SrcPort(2B)][MsgLen(4B)][Msg] func (c *ErrorPacketCodec) Serialize(packet any, pool *common.BufferPool) ([]byte, error) { p, ok := packet.(*ErrorPacket) if !ok { @@ -167,11 +171,11 @@ func (c *ErrorPacketCodec) Serialize(packet any, pool *common.BufferPool) ([]byt } msgBytes := []byte(p.ErrorMsg) - if len(msgBytes) > MaxUDPPayloadSize-13 { // 1+8+4 header = 13B + if len(msgBytes) > MaxUDPPayloadSize-29 { // 1+8+4+2+4+2+4 header = 29B return nil, errors.New("error message too long, must fit in one MTU") } - totalSize := 13 + len(msgBytes) + totalSize := 29 + len(msgBytes) var buf []byte if pool != nil { @@ -183,8 +187,24 @@ func (c *ErrorPacketCodec) Serialize(packet any, pool *common.BufferPool) ([]byt // Write fields buf[0] = byte(p.PacketTypeID) binary.LittleEndian.PutUint64(buf[1:9], p.RPCID) - binary.LittleEndian.PutUint32(buf[9:13], uint32(len(msgBytes))) - copy(buf[13:], msgBytes) + + // Copy destination IP (4 bytes) + copy(buf[9:13], p.DstIP[:]) + + // Write destination port + binary.LittleEndian.PutUint16(buf[13:15], p.DstPort) + + // Copy source IP (4 bytes) + copy(buf[15:19], p.SrcIP[:]) + + // Write source port + binary.LittleEndian.PutUint16(buf[19:21], p.SrcPort) + + // Write message length + binary.LittleEndian.PutUint32(buf[21:25], uint32(len(msgBytes))) + + // Copy message + copy(buf[25:], msgBytes) // Note: We don't return the buffer to the pool here because it's returned to the caller // The caller (transport.Send) is responsible for returning it after WriteToUDP @@ -192,20 +212,35 @@ func (c *ErrorPacketCodec) Serialize(packet any, pool *common.BufferPool) ([]byt } // Deserialize decodes binary data into an ErrorPacket +// Format: [PacketTypeID(1B)][RPCID(8B)][DstIP(4B)][DstPort(2B)][SrcIP(4B)][SrcPort(2B)][MsgLen(4B)][Msg] func (c *ErrorPacketCodec) Deserialize(data []byte) (any, error) { - if len(data) < 13 { + if len(data) < 29 { return nil, errors.New("data too short for ErrorPacket header") } pkt := &ErrorPacket{} pkt.PacketTypeID = PacketTypeID(data[0]) pkt.RPCID = binary.LittleEndian.Uint64(data[1:9]) - msgLen := binary.LittleEndian.Uint32(data[9:13]) - if len(data) < 13+int(msgLen) { + // Copy destination IP (4 bytes) + copy(pkt.DstIP[:], data[9:13]) + + // Read destination port + pkt.DstPort = binary.LittleEndian.Uint16(data[13:15]) + + // Copy source IP (4 bytes) + copy(pkt.SrcIP[:], data[15:19]) + + // Read source port + pkt.SrcPort = binary.LittleEndian.Uint16(data[19:21]) + + // Read message length + msgLen := binary.LittleEndian.Uint32(data[21:25]) + + if len(data) < 29+int(msgLen) { return nil, errors.New("data too short for declared error message length") } - pkt.ErrorMsg = string(data[13 : 13+msgLen]) + pkt.ErrorMsg = string(data[25 : 25+msgLen]) return pkt, nil }