diff --git a/internal/internal_test.go b/internal/internal_test.go new file mode 100644 index 0000000..4eaafed --- /dev/null +++ b/internal/internal_test.go @@ -0,0 +1,248 @@ +package internal + +import ( + "bytes" + "io" + "math/rand" + "testing" +) + +func TestRing(t *testing.T) { + rng := rand.New(rand.NewSource(0)) + const bufSize = 10 + r := &Ring{ + Buf: make([]byte, bufSize), + } + const data = "hello" + _, err := r.Write([]byte(data)) + if err != nil { + t.Error(err) + } + // Case where data is contiguous and at start of buffer. + var buf [bufSize]byte + n, err := fragmentReadInto(r, buf[:]) + if err != nil { + t.Fatal(err) + } + if string(buf[:n]) != data { + t.Fatalf("got %q; want %q", buf[:n], data) + } + + // Case where data overwrites end of buffer. + const overdata = "hello world" + n, err = r.Write([]byte(overdata)) + if err == nil || n > 0 { + t.Fatal(err, n) + } + + // Set Random data in ring buffer and read it back. + for i := 0; i < 32; i++ { + n := rng.Intn(bufSize) + copy(buf[:], overdata[:n]) + offset := rng.Intn(bufSize - 1) + setRingData(t, r, offset, buf[:n]) + + // Case where data wraps around end of buffer. + n, err = r.Read(buf[:]) + if err != nil { + break + } + if string(buf[:n]) != overdata[:n] { + t.Error("got", buf[:n], "want", overdata[:n]) + } + } + + // Set random data and write some more and read it back. + for i := 0; i < 32; i++ { + nfirst := rng.Intn(bufSize) / 2 + nsecond := rng.Intn(bufSize) / 2 + if nfirst+nsecond > bufSize { + nfirst = bufSize - nsecond + } + offset := rng.Intn(bufSize - 1) + + copy(buf[:], overdata[:nfirst]) + setRingData(t, r, offset, buf[:nfirst]) + // println("test", r.end, r.off, offset, r) + ngot, err := r.Write([]byte(overdata[nfirst : nfirst+nsecond])) + if err != nil { + t.Fatal(err) + } + if ngot != nsecond { + t.Errorf("%d did not write data correctly: got %d; want %d", i, ngot, nsecond) + } + buf = [bufSize]byte{} + // Case where data wraps around end of buffer. + n, err = r.Read(buf[:]) + if err != nil { + break + } + + if n != nfirst+nsecond { + t.Errorf("got %d; want %d (%d+%d)", n, nfirst+nsecond, nfirst, nsecond) + } + if string(buf[:n]) != overdata[:n] { + t.Errorf("got %q; want %q", buf[:n], overdata[:n]) + } + } + _ = r.string() +} + +func TestRing2(t *testing.T) { + const maxsize = 6 + const ntests = 800 + rng := rand.New(rand.NewSource(0)) + data := make([]byte, maxsize) + ringbuf := make([]byte, maxsize) + auxbuf := make([]byte, maxsize) + rng.Read(data) + // TODO(soypat): This test fails for greater ntests. + // It was not fixed because of a compiler bug: https://github.com/golang/go/issues/64854 + // and since the benefits of the changes in this PR are already much better than what we previously had. + for i := 0; i < ntests; i++ { + dsize := max(rng.Intn(len(data)), 1) + if !testRing1_loopback(t, rng, ringbuf, data[:dsize], auxbuf) { + t.Fatalf("failed test %d", i) + } + } +} + +func TestRing_findcrash(t *testing.T) { + const maxsize = 33 + const ntests = 800000 + r := Ring{ + Buf: make([]byte, maxsize*6), + } + rng := rand.New(rand.NewSource(0)) + data := make([]byte, maxsize) + + for i := 0; i < ntests; i++ { + free := r.Free() + if free < 0 { + t.Fatal("free < 0") + } + if rng.Intn(2) == 0 { + l := max(rng.Intn(len(data)), 1) + if l > free { + continue // Buffer full. + } + n, err := r.Write(data[:l]) + expectFree := free - n + free = r.Free() + if n != l { + t.Fatal(i, "write failed", n, l, err) + } else if expectFree != free { + t.Fatal(i, "free not updated correctly", expectFree, free) + } + } + buffered := r.Buffered() + if buffered < 0 { + t.Fatal("buffered < 0") + } + if rng.Intn(2) == 0 { + l := max(rng.Intn(len(data)), 1) + n, err := r.Read(data[:l]) + expectRead := min(buffered, l) + expectBuffered := buffered - n + buffered = r.Buffered() + if n != expectRead { + t.Fatal(i, "read failed", n, l, expectRead, err) + } else if buffered != expectBuffered { + t.Fatal(i, "buffered not updated correctly", expectBuffered, buffered) + } + } + } +} + +func testRing1_loopback(t *testing.T, rng *rand.Rand, ringbuf, data, auxbuf []byte) bool { + if len(data) > len(ringbuf) || len(data) > len(auxbuf) { + panic("invalid ringbuf or data") + } + dsize := len(data) + var r Ring + r.Buf = ringbuf + + nfirst := rng.Intn(dsize) / 2 + nsecond := rng.Intn(dsize) / 2 + if nfirst == 0 || nsecond == 0 { + return true + } + offset := rng.Intn(dsize - 1) + + setRingData(t, &r, offset, data[:nfirst]) + ngot, err := r.Write(data[nfirst : nfirst+nsecond]) + if err != nil { + t.Error(err) + return false + } + if ngot != nsecond { + t.Errorf("did not write data correctly: got %d; want %d", ngot, nsecond) + } + // Case where data wraps around end of buffer. + n, err := r.Read(auxbuf[:]) + if err != nil { + t.Error(err) + return false + } + + if n != nfirst+nsecond { + t.Errorf("got %d; want %d (%d+%d)", n, nfirst+nsecond, nfirst, nsecond) + } + if !bytes.Equal(auxbuf[:n], data[:n]) { + t.Errorf("got %q; want %q", auxbuf[:n], data[:n]) + } + return !t.Failed() +} + +func fragmentReadInto(r io.Reader, buf []byte) (n int, _ error) { + maxSize := len(buf) / 4 + for { + ntop := min(n+rand.Intn(maxSize)+1, len(buf)) + ngot, err := r.Read(buf[n:ntop]) + n += ngot + if err != nil { + if err == io.EOF { + return n, nil + } + return n, err + } + if n == len(buf) { + return n, nil + } + } +} + +func setRingData(t *testing.T, r *Ring, offset int, data []byte) { + t.Helper() + if len(data) > len(r.Buf) { + panic("data too large") + } + n := copy(r.Buf[offset:], data) + r.End = offset + n + if len(data)+offset > len(r.Buf) { + // End of buffer not enough to hold data, wrap around. + n = copy(r.Buf, data[n:]) + r.End = n + } + r.Off = offset + r.onReadEnd() + // println("buf:", len(r.Buf), "end:", r.End, "off:", r.Off, offset, "data:", len(data)) + free := r.Free() + wantFree := len(r.Buf) - len(data) + if free != wantFree { + t.Fatalf("free got %d; want %d", free, wantFree) + } + buffered := r.Buffered() + wantBuffered := len(data) + if buffered != wantBuffered { + t.Fatalf("buffered got %d; want %d", buffered, wantBuffered) + } + end := r.End + off := r.Off + sdata := r.string() + if sdata != string(data) { + t.Fatalf("data got %q; want %q", sdata, data) + } + r.End = end + r.Off = off +} diff --git a/internal/ringbuffer.go b/internal/ringbuffer.go new file mode 100644 index 0000000..55aad38 --- /dev/null +++ b/internal/ringbuffer.go @@ -0,0 +1,155 @@ +package internal + +import ( + "bytes" + "errors" + "io" +) + +var ( + errRingBufferFull = errors.New("seqs: ringbuffer full") + errRingBufferSmall = errors.New("seqs: ringbuffer too small") +) + +// Ring is a ring buffer implementation. +type Ring struct { + Buf []byte + Off int + End int +} + +func (r *Ring) Write(b []byte) (int, error) { + free := r.Free() + if len(b) > free { + return 0, errRingBufferFull + } + midFree := r.midFree() + if midFree > 0 { + // start end off len(buf) + // | used | mfree | used | + n := copy(r.Buf[r.End:r.Off], b) + r.End += n + return n, nil + } + // start off end len(buf) + // | sfree | used | efree | + n := copy(r.Buf[r.End:], b) + r.End += n + if n < len(b) { + n2 := copy(r.Buf, b[n:]) + r.End = n2 + n += n2 + } + return n, nil +} + +// WriteLimited performs a write that does not write over the ring buffer's +// limitOffset index, which points to a position to r.Buf. +func (r *Ring) WriteLimited(b []byte, limitOffset int) (int, error) { + if limitOffset > len(r.Buf) { + panic("bad limit offset") + } + if len(b) > len(r.Buf) { + return 0, errRingBufferSmall + } + writeEnd := r.Off + len(b) + if limitOffset >= r.Off && writeEnd > limitOffset { + return 0, errRingBufferFull + } else if writeEnd > len(r.Buf) { + writeEnd %= len(r.Buf) + if writeEnd > limitOffset { + return 0, errRingBufferFull + } + } + return r.Write(b) +} + +func (r *Ring) Read(b []byte) (int, error) { + if r.Buffered() == 0 { + return 0, io.EOF + } + + if r.End > r.Off { + // start off end len(buf) + // | sfree | used | efree | + n := copy(b, r.Buf[r.Off:r.End]) + r.Off += n + r.onReadEnd() + return n, nil + } + // start end off len(buf) + // | used | mfree | used | + n := copy(b, r.Buf[r.Off:]) + r.Off += n + if n < len(b) { + n2 := copy(b[n:], r.Buf[:r.End]) + r.Off = n2 + n += n2 + } + r.onReadEnd() + return n, nil +} + +func (r *Ring) Buffered() int { + return len(r.Buf) - r.Free() +} + +func (r *Ring) Reset() { + r.Off = 0 + r.End = 0 +} + +func (r *Ring) Free() int { + if r.Off == 0 { + return len(r.Buf) - r.End + } + if r.Off < r.End { + // start off end len(buf) + // | sfree | used | efree | + startFree := r.Off + endFree := len(r.Buf) - r.End + return startFree + endFree + } + // start end off len(buf) + // | used | mfree | used | + return r.Off - r.End +} + +func (r *Ring) midFree() int { + if r.End >= r.Off { + return 0 + } + return r.Off - r.End +} + +func (r *Ring) onReadEnd() { + if r.End == len(r.Buf) { + r.End = 0 // Wrap around. + } + if r.Off == len(r.Buf) { + r.Off = 0 // Wrap around. + } + if r.Off == r.End { + r.Reset() // We read everything, reset. + } +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (r *Ring) string() string { + var b bytes.Buffer + b.ReadFrom(r) + return b.String() +} diff --git a/ring.go b/ring.go new file mode 100644 index 0000000..8830ccc --- /dev/null +++ b/ring.go @@ -0,0 +1,143 @@ +package seqs + +import ( + "errors" + "time" + + "github.com/soypat/seqs/internal" +) + +func NewRingTx(buf []byte, maxQueuedPackets int) *RingTx { + if maxQueuedPackets <= 0 || len(buf) < 2 || len(buf) < maxQueuedPackets { + panic("invalid argument to NewRingTx") + } + return &RingTx{ + rawbuf: buf, + packets: make([]ringidx, maxQueuedPackets), + } +} + +// RingTx is a ring buffer with retransmission queue functionality added. +type RingTx struct { + // rawbuf contains the ring buffer of ordered bytes. It should be the size of the window. + rawbuf []byte + // packets contains + packets []ringidx + // firstPkt is the index of the oldest packet in the packets field. + firstPkt int + lastPkt int + // unsentOff is the offset of start of unsent data into rawbuf. + unsentoff int + // unsentend is the offset of end of unsent data in rawbuf. + unsentend int +} + +// ringidx represents packet data inside RingTx +type ringidx struct { + // off is data start offset of packet data inside buf. + off int + // end is the ringed data end offset, non-inclusive. + end int + // seq is the sequence number of the packet. + seq Value + t time.Time + // acked flags if this packet has been acknowledged. Useful for SACK (selective acknowledgement) + // acked bool +} + +// Buffered returns the amount of unsent bytes. +func (tx *RingTx) Buffered() int { + r := tx.unsentRing() + return r.Buffered() +} + +// BufferedSent returns the total amount of bytes sent but not acked. +func (tx *RingTx) BufferedSent() int { + r := tx.sentRing() + return r.Buffered() +} + +// Write writes data to the underlying unsent data ring buffer. +func (tx *RingTx) Write(b []byte) (int, error) { + first := tx.packets[tx.firstPkt] + r := tx.unsentRing() + if first.off < 0 { + // No packets in queue case. + return r.Write(b) + } + return r.WriteLimited(b, first.off) +} + +// ReadPacket reads from the unsent data ring buffer and generates a new packet segment. +// It fails if the sent packet queue is full. +func (tx *RingTx) NewPacketAndRead(b []byte) (int, error) { + nxtpkt := (tx.lastPkt + 1) % len(tx.packets) + if tx.firstPkt == nxtpkt { + return 0, errors.New("packet queue full") + } + + r := tx.unsentRing() + start := r.Off + n, err := r.Read(b) + if err != nil { + return n, err + } + last := &tx.packets[tx.lastPkt] + rlast := tx.packetRing(tx.lastPkt) + tx.packets[nxtpkt].off = start + tx.packets[nxtpkt].end = r.Off + tx.packets[nxtpkt].seq = last.seq + Value(rlast.Buffered()) + tx.lastPkt = nxtpkt + tx.unsentoff = r.Off + return n, nil +} + +// IsQueueFull returns true if the sent packet queue is full in which +// case a call to ReadPacket is guaranteed to fail. +func (tx *RingTx) IsQueueFull() bool { + return tx.firstPkt == (tx.lastPkt+1)%len(tx.packets) +} + +func (tx *RingTx) packetRing(i int) internal.Ring { + pkt := tx.packets[i] + if pkt.off < 0 { + return internal.Ring{} + } + return tx.ring(pkt.off, pkt.end) +} + +// RecvSegment processes an incoming segment and updates the sent packet queue +func (tx *RingTx) RecvACK(ack Value) error { + i := tx.firstPkt + for { + pkt := &tx.packets[i] + if ack >= pkt.seq { + // Packet was received by remote. Mark it as acked. + pkt.off = -1 + tx.firstPkt++ + continue + } + if i == tx.lastPkt { + break + } + i = (i + 1) % len(tx.packets) + } + return nil +} + +func (tx *RingTx) unsentRing() internal.Ring { + return tx.ring(tx.unsentoff, tx.unsentend) +} + +func (tx *RingTx) sentRing() internal.Ring { + first := tx.packets[tx.firstPkt] + if first.off < 0 { + return tx.ring(0, 0) + } + last := tx.packets[tx.lastPkt] + return tx.ring(first.off, last.end) +} + +func (tx *RingTx) ring(off, end int) internal.Ring { + return internal.Ring{Buf: tx.rawbuf, Off: off, End: end} +} diff --git a/seqs_test.go b/seqs_test.go index 61b6c5c..4a02b13 100644 --- a/seqs_test.go +++ b/seqs_test.go @@ -795,6 +795,7 @@ func TestIssue19(t *testing.T) { } } + func FuzzTCBActions(f *testing.F) { const mtu = 2048 const ( diff --git a/stacks/intern_test.go b/stacks/intern_test.go index 0ccb63b..9913e90 100644 --- a/stacks/intern_test.go +++ b/stacks/intern_test.go @@ -1,254 +1,10 @@ package stacks import ( - "bytes" - "io" - "math/rand" - "testing" - "github.com/soypat/seqs" + "github.com/soypat/seqs/internal" ) -func TestRing(t *testing.T) { - rng := rand.New(rand.NewSource(0)) - const bufSize = 10 - r := &ring{ - buf: make([]byte, bufSize), - } - const data = "hello" - _, err := r.Write([]byte(data)) - if err != nil { - t.Error(err) - } - // Case where data is contiguous and at start of buffer. - var buf [bufSize]byte - n, err := fragmentReadInto(r, buf[:]) - if err != nil { - t.Fatal(err) - } - if string(buf[:n]) != data { - t.Fatalf("got %q; want %q", buf[:n], data) - } - - // Case where data overwrites end of buffer. - const overdata = "hello world" - n, err = r.Write([]byte(overdata)) - if err == nil || n > 0 { - t.Fatal(err, n) - } - - // Set Random data in ring buffer and read it back. - for i := 0; i < 32; i++ { - n := rng.Intn(bufSize) - copy(buf[:], overdata[:n]) - offset := rng.Intn(bufSize - 1) - setRingData(t, r, offset, buf[:n]) - - // Case where data wraps around end of buffer. - n, err = r.Read(buf[:]) - if err != nil { - break - } - if string(buf[:n]) != overdata[:n] { - t.Error("got", buf[:n], "want", overdata[:n]) - } - } - - // Set random data and write some more and read it back. - for i := 0; i < 32; i++ { - nfirst := rng.Intn(bufSize) / 2 - nsecond := rng.Intn(bufSize) / 2 - if nfirst+nsecond > bufSize { - nfirst = bufSize - nsecond - } - offset := rng.Intn(bufSize - 1) - - copy(buf[:], overdata[:nfirst]) - setRingData(t, r, offset, buf[:nfirst]) - // println("test", r.end, r.off, offset, r) - ngot, err := r.Write([]byte(overdata[nfirst : nfirst+nsecond])) - if err != nil { - t.Fatal(err) - } - if ngot != nsecond { - t.Errorf("%d did not write data correctly: got %d; want %d", i, ngot, nsecond) - } - buf = [bufSize]byte{} - // Case where data wraps around end of buffer. - n, err = r.Read(buf[:]) - if err != nil { - break - } - - if n != nfirst+nsecond { - t.Errorf("got %d; want %d (%d+%d)", n, nfirst+nsecond, nfirst, nsecond) - } - if string(buf[:n]) != overdata[:n] { - t.Errorf("got %q; want %q", buf[:n], overdata[:n]) - } - } - _ = r.string() -} - -func TestRing2(t *testing.T) { - const maxsize = 6 - const ntests = 800 - rng := rand.New(rand.NewSource(0)) - data := make([]byte, maxsize) - ringbuf := make([]byte, maxsize) - auxbuf := make([]byte, maxsize) - rng.Read(data) - // TODO(soypat): This test fails for greater ntests. - // It was not fixed because of a compiler bug: https://github.com/golang/go/issues/64854 - // and since the benefits of the changes in this PR are already much better than what we previously had. - for i := 0; i < ntests; i++ { - dsize := max(rng.Intn(len(data)), 1) - if !testRing1_loopback(t, rng, ringbuf, data[:dsize], auxbuf) { - t.Fatalf("failed test %d", i) - } - } -} - -func TestRing_findcrash(t *testing.T) { - const maxsize = 33 - const ntests = 800000 - r := ring{ - buf: make([]byte, maxsize*6), - } - rng := rand.New(rand.NewSource(0)) - data := make([]byte, maxsize) - - for i := 0; i < ntests; i++ { - free := r.Free() - if free < 0 { - t.Fatal("free < 0") - } - if rng.Intn(2) == 0 { - l := max(rng.Intn(len(data)), 1) - if l > free { - continue // Buffer full. - } - n, err := r.Write(data[:l]) - expectFree := free - n - free = r.Free() - if n != l { - t.Fatal(i, "write failed", n, l, err) - } else if expectFree != free { - t.Fatal(i, "free not updated correctly", expectFree, free) - } - } - buffered := r.Buffered() - if buffered < 0 { - t.Fatal("buffered < 0") - } - if rng.Intn(2) == 0 { - l := max(rng.Intn(len(data)), 1) - n, err := r.Read(data[:l]) - expectRead := min(buffered, l) - expectBuffered := buffered - n - buffered = r.Buffered() - if n != expectRead { - t.Fatal(i, "read failed", n, l, expectRead, err) - } else if buffered != expectBuffered { - t.Fatal(i, "buffered not updated correctly", expectBuffered, buffered) - } - } - } -} - -func testRing1_loopback(t *testing.T, rng *rand.Rand, ringbuf, data, auxbuf []byte) bool { - if len(data) > len(ringbuf) || len(data) > len(auxbuf) { - panic("invalid ringbuf or data") - } - dsize := len(data) - var r ring - r.buf = ringbuf - - nfirst := rng.Intn(dsize) / 2 - nsecond := rng.Intn(dsize) / 2 - if nfirst == 0 || nsecond == 0 { - return true - } - offset := rng.Intn(dsize - 1) - - setRingData(t, &r, offset, data[:nfirst]) - ngot, err := r.Write(data[nfirst : nfirst+nsecond]) - if err != nil { - t.Error(err) - return false - } - if ngot != nsecond { - t.Errorf("did not write data correctly: got %d; want %d", ngot, nsecond) - } - // Case where data wraps around end of buffer. - n, err := r.Read(auxbuf[:]) - if err != nil { - t.Error(err) - return false - } - - if n != nfirst+nsecond { - t.Errorf("got %d; want %d (%d+%d)", n, nfirst+nsecond, nfirst, nsecond) - } - if !bytes.Equal(auxbuf[:n], data[:n]) { - t.Errorf("got %q; want %q", auxbuf[:n], data[:n]) - } - return !t.Failed() -} - -func fragmentReadInto(r io.Reader, buf []byte) (n int, _ error) { - maxSize := len(buf) / 4 - for { - ntop := min(n+rand.Intn(maxSize)+1, len(buf)) - ngot, err := r.Read(buf[n:ntop]) - n += ngot - if err != nil { - if err == io.EOF { - return n, nil - } - return n, err - } - if n == len(buf) { - return n, nil - } - } -} - -func setRingData(t *testing.T, r *ring, offset int, data []byte) { - t.Helper() - if len(data) > len(r.buf) { - panic("data too large") - } - n := copy(r.buf[offset:], data) - r.end = offset + n - if len(data)+offset > len(r.buf) { - // End of buffer not enough to hold data, wrap around. - n = copy(r.buf, data[n:]) - r.end = n - } - r.off = offset - r.onReadEnd() - // println("buf:", len(r.buf), "end:", r.end, "off:", r.off, offset, "data:", len(data)) - free := r.Free() - wantFree := len(r.buf) - len(data) - if free != wantFree { - t.Fatalf("free got %d; want %d", free, wantFree) - } - buffered := r.Buffered() - wantBuffered := len(data) - if buffered != wantBuffered { - t.Fatalf("buffered got %d; want %d", buffered, wantBuffered) - } - end := r.end - off := r.off - sdata := r.string() - if sdata != string(data) { - t.Fatalf("data got %q; want %q", sdata, data) - } - r.end = end - r.off = off -} - // SCB is an internal routine for testing which returns the control block, // which is a simplified implementation of the TCB of RFC9293. func (tcp *TCPConn) SCB() *seqs.ControlBlock { return &tcp.scb } @@ -256,6 +12,6 @@ func (tcp *TCPConn) SCB() *seqs.ControlBlock { return &tcp.scb } func (dhcpc *DHCPClient) PortStack() *PortStack { return dhcpc.stack } func (dhcps *DHCPServer) PortStack() *PortStack { return dhcps.stack } -func (tcp *TCPConn) RingBuffers() (rx, tx *ring) { +func (tcp *TCPConn) RingBuffers() (rx, tx *internal.Ring) { return &tcp.rx, &tcp.tx } diff --git a/stacks/ring.go b/stacks/ring.go deleted file mode 100644 index 93843de..0000000 --- a/stacks/ring.go +++ /dev/null @@ -1,130 +0,0 @@ -package stacks - -import ( - "bytes" - "errors" - "io" -) - -var errRingBufferFull = errors.New("seqs/ring: buffer full") - -type ring struct { - buf []byte - off int - end int -} - -func (r *ring) Write(b []byte) (int, error) { - free := r.Free() - if len(b) > free { - return 0, errRingBufferFull - } - midFree := r.midFree() - if midFree > 0 { - // start end off len(buf) - // | used | mfree | used | - n := copy(r.buf[r.end:r.off], b) - r.end += n - return n, nil - } - // start off end len(buf) - // | sfree | used | efree | - n := copy(r.buf[r.end:], b) - r.end += n - if n < len(b) { - n2 := copy(r.buf, b[n:]) - r.end = n2 - n += n2 - } - return n, nil -} - -func (r *ring) Read(b []byte) (int, error) { - if r.Buffered() == 0 { - return 0, io.EOF - } - - if r.end > r.off { - // start off end len(buf) - // | sfree | used | efree | - n := copy(b, r.buf[r.off:r.end]) - r.off += n - r.onReadEnd() - return n, nil - } - // start end off len(buf) - // | used | mfree | used | - n := copy(b, r.buf[r.off:]) - r.off += n - if n < len(b) { - n2 := copy(b[n:], r.buf[:r.end]) - r.off = n2 - n += n2 - } - r.onReadEnd() - return n, nil -} - -func (r *ring) Buffered() int { - return len(r.buf) - r.Free() -} - -func (r *ring) Reset() { - r.off = 0 - r.end = 0 -} - -func (r *ring) Free() int { - if r.off == 0 { - return len(r.buf) - r.end - } - if r.off < r.end { - // start off end len(buf) - // | sfree | used | efree | - startFree := r.off - endFree := len(r.buf) - r.end - return startFree + endFree - } - // start end off len(buf) - // | used | mfree | used | - return r.off - r.end -} - -func (r *ring) midFree() int { - if r.end >= r.off { - return 0 - } - return r.off - r.end -} - -func (r *ring) onReadEnd() { - if r.end == len(r.buf) { - r.end = 0 // Wrap around. - } - if r.off == len(r.buf) { - r.off = 0 // Wrap around. - } - if r.off == r.end { - r.Reset() // We read everything, reset. - } -} - -func max(a, b int) int { - if a > b { - return a - } - return b -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func (r *ring) string() string { - var b bytes.Buffer - b.ReadFrom(r) - return b.String() -} diff --git a/stacks/tcpconn.go b/stacks/tcpconn.go index 544e885..2def073 100644 --- a/stacks/tcpconn.go +++ b/stacks/tcpconn.go @@ -34,8 +34,8 @@ type TCPConn struct { lastRx time.Time pkt TCPPacket scb seqs.ControlBlock - tx ring - rx ring + tx internal.Ring + rx internal.Ring // remote is the IP+port address of remote. remote netip.AddrPort localPort uint16 @@ -71,8 +71,8 @@ func NewTCPConn(stack *PortStack, cfg TCPConnConfig) (*TCPConn, error) { func makeTCPConn(stack *PortStack, tx, rx []byte) TCPConn { return TCPConn{ stack: stack, - tx: ring{buf: tx}, - rx: ring{buf: rx}, + tx: internal.Ring{Buf: tx}, + rx: internal.Ring{Buf: rx}, } } @@ -279,7 +279,7 @@ func (sock *TCPConn) openstack(state seqs.State, localPortNum uint16, iss seqs.V } func (sock *TCPConn) open(state seqs.State, localPortNum uint16, iss seqs.Value, remoteMAC [6]byte, remoteAddr netip.AddrPort) error { - err := sock.scb.Open(iss, seqs.Size(len(sock.rx.buf)), state) + err := sock.scb.Open(iss, seqs.Size(len(sock.rx.Buf)), state) if err != nil { return err } @@ -468,8 +468,8 @@ func (sock *TCPConn) deleteState() { sock.trace("TCPConn.deleteState", slog.Uint64("port", uint64(sock.localPort))) *sock = TCPConn{ stack: sock.stack, - rx: ring{buf: sock.rx.buf}, - tx: ring{buf: sock.tx.buf}, + rx: internal.Ring{Buf: sock.rx.Buf}, + tx: internal.Ring{Buf: sock.tx.Buf}, connid: sock.connid + 1, } } @@ -533,3 +533,17 @@ func (sock *TCPConn) info(msg string, attrs ...slog.Attr) { func (sock *TCPConn) logerr(msg string, attrs ...slog.Attr) { internal.LogAttrs(sock.stack.logger, slog.LevelError, msg, attrs...) } + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func min(a, b int) int { + if a < b { + return a + } + return b +}