Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cmd/proxy-buffer/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,41 @@ func (pb *PacketBuffer) deserializePacket(data []byte) (*packet.DataPacket, erro
return dataPacket, nil
}

// ProcessErrorPacket processes an error packet and returns a BufferedPacket for forwarding.
// Error packets are not buffered or fragmented - they fit in one MTU and are forwarded directly.
func (pb *PacketBuffer) ProcessErrorPacket(data []byte, src *net.UDPAddr) (*util.BufferedPacket, error) {
// Deserialize error packet
codec := &packet.ErrorPacketCodec{}
packetAny, err := codec.Deserialize(data)
if err != nil {
return nil, fmt.Errorf("failed to deserialize error packet: %w", err)
}

errorPacket, ok := packetAny.(*packet.ErrorPacket)
if !ok {
return nil, fmt.Errorf("unexpected packet type in ProcessErrorPacket")
}

// Create destination address from the packet's routing info
peer := &net.UDPAddr{IP: net.IP(errorPacket.DstIP[:]), Port: int(errorPacket.DstPort)}

// Create BufferedPacket with error message as payload
bufferedPacket := &util.BufferedPacket{
Payload: []byte(errorPacket.ErrorMsg),
Source: src,
Peer: peer,
PacketType: util.PacketTypeError,
RPCID: errorPacket.RPCID,
DstIP: errorPacket.DstIP,
DstPort: errorPacket.DstPort,
SrcIP: errorPacket.SrcIP,
SrcPort: errorPacket.SrcPort,
TotalPackets: 1,
}

return bufferedPacket, nil
}

// cleanupRoutine periodically cleans up expired fragments
func (pb *PacketBuffer) cleanupRoutine() {
for {
Expand Down
45 changes: 44 additions & 1 deletion cmd/proxy-buffer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/appnet-org/arpc/cmd/proxy-buffer/util"
"github.com/appnet-org/arpc/pkg/logging"
"github.com/appnet-org/arpc/pkg/packet"
"github.com/appnet-org/arpc/pkg/transport"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -205,6 +206,48 @@ func runProxyServer(port int, state *ProxyState, config *Config) error {
func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data []byte, config *Config) {
ctx := context.Background()

// Check if this is an error packet (PacketTypeID == 3)
if len(data) > 0 && data[0] == byte(packet.PacketTypeError.TypeID) {
// Process error packet - forward directly without element chain
bufferedPacket, err := state.packetBuffer.ProcessErrorPacket(data, src)
if err != nil {
logging.Error("Error processing error packet", zap.Error(err))
return
}

// Serialize the error packet for forwarding
errorPacket := &packet.ErrorPacket{
PacketTypeID: packet.PacketTypeError.TypeID,
RPCID: bufferedPacket.RPCID,
DstIP: bufferedPacket.DstIP,
DstPort: bufferedPacket.DstPort,
SrcIP: bufferedPacket.SrcIP,
SrcPort: bufferedPacket.SrcPort,
ErrorMsg: string(bufferedPacket.Payload),
}

codec := &packet.ErrorPacketCodec{}
serialized, err := codec.Serialize(errorPacket, nil)
if err != nil {
logging.Error("Failed to serialize error packet for forwarding", zap.Error(err))
return
}

// Forward the error packet to the destination
if _, err := conn.WriteToUDP(serialized, bufferedPacket.Peer); err != nil {
logging.Error("Failed to forward error packet", zap.Error(err))
return
}

logging.Debug("Forwarded error packet",
zap.Uint64("rpcID", bufferedPacket.RPCID),
zap.String("from", bufferedPacket.Source.String()),
zap.String("to", bufferedPacket.Peer.String()),
zap.String("errorMsg", string(bufferedPacket.Payload)))

return
}

// Process packet - returns nil if still buffering fragments
// Returns a complete BufferedPacket only when ALL fragments have been received
bufferedPacket, err := state.packetBuffer.ProcessPacket(data, src)
Expand Down Expand Up @@ -252,7 +295,7 @@ func handlePacket(conn *net.UDPConn, state *ProxyState, src *net.UDPAddr, data [
logging.Error("Error processing packet through element chain or packet was dropped",
zap.Error(err))
// Send error packet back to the source
if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error()); sendErr != nil {
if sendErr := util.SendErrorPacket(conn, bufferedPacket.Source, bufferedPacket.RPCID, err.Error(), bufferedPacket.SrcIP, bufferedPacket.SrcPort, bufferedPacket.DstIP, bufferedPacket.DstPort); sendErr != nil {
logging.Error("Failed to send error packet", zap.Error(sendErr))
}
return
Expand Down
9 changes: 6 additions & 3 deletions cmd/proxy-buffer/util/send_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import (
"go.uber.org/zap"
)

// sendErrorPacket sends an error packet back to the source
func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string) error {
// SendErrorPacket sends an error packet back to the source with routing information
func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMsg string, dstIP [4]byte, dstPort uint16, srcIP [4]byte, srcPort uint16) error {
// Create error packet
errorPacket := &packet.ErrorPacket{
PacketTypeID: packet.PacketTypeError.TypeID,
RPCID: rpcID,
DstIP: dstIP,
DstPort: dstPort,
SrcIP: srcIP,
SrcPort: srcPort,
ErrorMsg: errorMsg,
}

Expand All @@ -37,4 +41,3 @@ func SendErrorPacket(conn *net.UDPConn, dest *net.UDPAddr, rpcID uint64, errorMs

return nil
}

37 changes: 37 additions & 0 deletions cmd/proxy/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,43 @@ func (pb *PacketBuffer) deserializePacket(data []byte) (*packet.DataPacket, erro
return dataPacket, nil
}

// ProcessErrorPacket processes an error packet and returns a BufferedPacket for forwarding.
// Error packets are not buffered or fragmented - they fit in one MTU and are forwarded directly.
func (pb *PacketBuffer) ProcessErrorPacket(data []byte, src *net.UDPAddr) (*util.BufferedPacket, error) {
// Deserialize error packet
codec := &packet.ErrorPacketCodec{}
packetAny, err := codec.Deserialize(data)
if err != nil {
return nil, fmt.Errorf("failed to deserialize error packet: %w", err)
}

errorPacket, ok := packetAny.(*packet.ErrorPacket)
if !ok {
return nil, fmt.Errorf("unexpected packet type in ProcessErrorPacket")
}

// Create destination address from the packet's routing info
peer := &net.UDPAddr{IP: net.IP(errorPacket.DstIP[:]), Port: int(errorPacket.DstPort)}

// Create BufferedPacket with error message as payload
bufferedPacket := &util.BufferedPacket{
Payload: []byte(errorPacket.ErrorMsg),
Source: src,
Peer: peer,
PacketType: util.PacketTypeError,
RPCID: errorPacket.RPCID,
DstIP: errorPacket.DstIP,
DstPort: errorPacket.DstPort,
SrcIP: errorPacket.SrcIP,
SrcPort: errorPacket.SrcPort,
IsFull: true,
SeqNumber: -1,
TotalPackets: 1,
}

return bufferedPacket, nil
}

// CleanupUsedFragments removes fragments up to and including the lastUsedSeqNum
// This should be called after the public segment has been successfully forwarded
func (pb *PacketBuffer) CleanupUsedFragments(connKey string, rpcID uint64, lastUsedSeqNum uint16) {
Expand Down
191 changes: 191 additions & 0 deletions cmd/proxy/buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1368,3 +1368,194 @@ func TestFragmentPacketForForward_LastUsedSeqNum_NormalCase(t *testing.T) {

t.Log("SUCCESS: Normal case without packing works correctly")
}

// TestErrorPacketCodec_SerializeDeserialize tests error packet serialization and deserialization with IP/port fields
func TestErrorPacketCodec_SerializeDeserialize(t *testing.T) {
codec := &packet.ErrorPacketCodec{}

// Create an error packet with routing information
originalPacket := &packet.ErrorPacket{
PacketTypeID: packet.PacketTypeError.TypeID,
RPCID: uint64(12345),
DstIP: [4]byte{192, 168, 1, 100},
DstPort: 8080,
SrcIP: [4]byte{10, 0, 0, 50},
SrcPort: 9090,
ErrorMsg: "Test error message",
}

// Serialize the packet
serialized, err := codec.Serialize(originalPacket, nil)
if err != nil {
t.Fatalf("Failed to serialize error packet: %v", err)
}

// Verify serialized data has correct length (29 bytes header + message length)
expectedLen := 29 + len("Test error message")
if len(serialized) != expectedLen {
t.Errorf("Expected serialized length %d, got %d", expectedLen, len(serialized))
}

// Deserialize the packet
deserializedAny, err := codec.Deserialize(serialized)
if err != nil {
t.Fatalf("Failed to deserialize error packet: %v", err)
}

deserializedPacket, ok := deserializedAny.(*packet.ErrorPacket)
if !ok {
t.Fatalf("Deserialized packet is not an ErrorPacket")
}

// Verify all fields match
if deserializedPacket.PacketTypeID != originalPacket.PacketTypeID {
t.Errorf("PacketTypeID mismatch: expected %d, got %d", originalPacket.PacketTypeID, deserializedPacket.PacketTypeID)
}
if deserializedPacket.RPCID != originalPacket.RPCID {
t.Errorf("RPCID mismatch: expected %d, got %d", originalPacket.RPCID, deserializedPacket.RPCID)
}
if deserializedPacket.DstIP != originalPacket.DstIP {
t.Errorf("DstIP mismatch: expected %v, got %v", originalPacket.DstIP, deserializedPacket.DstIP)
}
if deserializedPacket.DstPort != originalPacket.DstPort {
t.Errorf("DstPort mismatch: expected %d, got %d", originalPacket.DstPort, deserializedPacket.DstPort)
}
if deserializedPacket.SrcIP != originalPacket.SrcIP {
t.Errorf("SrcIP mismatch: expected %v, got %v", originalPacket.SrcIP, deserializedPacket.SrcIP)
}
if deserializedPacket.SrcPort != originalPacket.SrcPort {
t.Errorf("SrcPort mismatch: expected %d, got %d", originalPacket.SrcPort, deserializedPacket.SrcPort)
}
if deserializedPacket.ErrorMsg != originalPacket.ErrorMsg {
t.Errorf("ErrorMsg mismatch: expected %s, got %s", originalPacket.ErrorMsg, deserializedPacket.ErrorMsg)
}
}

// TestErrorPacketCodec_MTUValidation tests that error messages exceeding MTU are rejected
func TestErrorPacketCodec_MTUValidation(t *testing.T) {
codec := &packet.ErrorPacketCodec{}

// Create an error packet with a message that's too long
// MaxUDPPayloadSize is typically 1400, header is 29 bytes, so max message is 1371 bytes
longMessage := make([]byte, packet.MaxUDPPayloadSize-28) // One byte too long
for i := range longMessage {
longMessage[i] = 'A'
}

errorPacket := &packet.ErrorPacket{
PacketTypeID: packet.PacketTypeError.TypeID,
RPCID: uint64(12345),
DstIP: [4]byte{192, 168, 1, 100},
DstPort: 8080,
SrcIP: [4]byte{10, 0, 0, 50},
SrcPort: 9090,
ErrorMsg: string(longMessage),
}

// Attempt to serialize - should fail
_, err := codec.Serialize(errorPacket, nil)
if err == nil {
t.Error("Expected error when serializing message that exceeds MTU, but got none")
}
if err != nil && err.Error() != "error message too long, must fit in one MTU" {
t.Errorf("Expected MTU error message, got: %v", err)
}

// Test with a message that fits
validMessage := make([]byte, packet.MaxUDPPayloadSize-30) // Should fit
for i := range validMessage {
validMessage[i] = 'B'
}
errorPacket.ErrorMsg = string(validMessage)

_, err = codec.Serialize(errorPacket, nil)
if err != nil {
t.Errorf("Expected no error for valid message size, got: %v", err)
}
}

// TestPacketBuffer_ProcessErrorPacket tests the ProcessErrorPacket method
func TestPacketBuffer_ProcessErrorPacket(t *testing.T) {
pb := NewPacketBuffer(5 * time.Second)
defer pb.Close()

// Create and serialize an error packet
codec := &packet.ErrorPacketCodec{}
errorPacket := &packet.ErrorPacket{
PacketTypeID: packet.PacketTypeError.TypeID,
RPCID: uint64(99999),
DstIP: [4]byte{192, 168, 1, 200},
DstPort: 7070,
SrcIP: [4]byte{10, 0, 0, 100},
SrcPort: 6060,
ErrorMsg: "Connection timeout",
}

serialized, err := codec.Serialize(errorPacket, nil)
if err != nil {
t.Fatalf("Failed to serialize error packet: %v", err)
}

// Process the error packet
src := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 100), Port: 6060}
bufferedPacket, err := pb.ProcessErrorPacket(serialized, src)
if err != nil {
t.Fatalf("Failed to process error packet: %v", err)
}

// Verify the buffered packet
if bufferedPacket == nil {
t.Fatal("Expected buffered packet, got nil")
}

if bufferedPacket.RPCID != errorPacket.RPCID {
t.Errorf("RPCID mismatch: expected %d, got %d", errorPacket.RPCID, bufferedPacket.RPCID)
}

if bufferedPacket.PacketType != util.PacketTypeError {
t.Errorf("PacketType mismatch: expected Error, got %v", bufferedPacket.PacketType)
}

if string(bufferedPacket.Payload) != errorPacket.ErrorMsg {
t.Errorf("Payload mismatch: expected %s, got %s", errorPacket.ErrorMsg, string(bufferedPacket.Payload))
}

if bufferedPacket.DstIP != errorPacket.DstIP {
t.Errorf("DstIP mismatch: expected %v, got %v", errorPacket.DstIP, bufferedPacket.DstIP)
}

if bufferedPacket.DstPort != errorPacket.DstPort {
t.Errorf("DstPort mismatch: expected %d, got %d", errorPacket.DstPort, bufferedPacket.DstPort)
}

if bufferedPacket.SrcIP != errorPacket.SrcIP {
t.Errorf("SrcIP mismatch: expected %v, got %v", errorPacket.SrcIP, bufferedPacket.SrcIP)
}

if bufferedPacket.SrcPort != errorPacket.SrcPort {
t.Errorf("SrcPort mismatch: expected %d, got %d", errorPacket.SrcPort, bufferedPacket.SrcPort)
}

// Verify peer address is constructed from DstIP and DstPort
expectedPeerIP := net.IP(errorPacket.DstIP[:])
if !bufferedPacket.Peer.IP.Equal(expectedPeerIP) {
t.Errorf("Peer IP mismatch: expected %v, got %v", expectedPeerIP, bufferedPacket.Peer.IP)
}

if bufferedPacket.Peer.Port != int(errorPacket.DstPort) {
t.Errorf("Peer Port mismatch: expected %d, got %d", errorPacket.DstPort, bufferedPacket.Peer.Port)
}

// Verify packet is marked as full (no fragmentation)
if !bufferedPacket.IsFull {
t.Error("Expected IsFull to be true for error packet")
}

if bufferedPacket.SeqNumber != -1 {
t.Errorf("Expected SeqNumber -1, got %d", bufferedPacket.SeqNumber)
}

if bufferedPacket.TotalPackets != 1 {
t.Errorf("Expected TotalPackets 1, got %d", bufferedPacket.TotalPackets)
}
}
Loading