@@ -2,6 +2,7 @@ package proxyproto
22
33import (
44 "bufio"
5+ "bytes"
56 "errors"
67 "fmt"
78 "io"
@@ -51,7 +52,8 @@ type Conn struct {
5152 once sync.Once
5253 readErr error
5354 conn net.Conn
54- bufReader * bufio.Reader
55+ reader io.Reader
56+ headerBuf * bytes.Buffer
5557 header * Header
5658 ProxyHeaderPolicy Policy
5759 Validate Validator
@@ -150,14 +152,8 @@ func (p *Listener) Addr() net.Addr {
150152// NewConn is used to wrap a net.Conn that may be speaking
151153// the proxy protocol into a proxyproto.Conn
152154func NewConn (conn net.Conn , opts ... func (* Conn )) * Conn {
153- // For v1 the header length is at most 108 bytes.
154- // For v2 the header length is at most 52 bytes plus the length of the TLVs.
155- // We use 256 bytes to be safe.
156- const bufSize = 256
157-
158155 pConn := & Conn {
159- bufReader : bufio .NewReaderSize (conn , bufSize ),
160- conn : conn ,
156+ conn : conn ,
161157 }
162158
163159 for _ , opt := range opts {
@@ -178,7 +174,7 @@ func (p *Conn) Read(b []byte) (int, error) {
178174 return 0 , p .readErr
179175 }
180176
181- return p .bufReader .Read (b )
177+ return p .reader .Read (b )
182178}
183179
184180// Write wraps original conn.Write
@@ -294,7 +290,22 @@ func (p *Conn) readHeader() error {
294290 }
295291 }
296292
297- header , err := Read (p .bufReader )
293+ // For v1 the header length is at most 108 bytes.
294+ // For v2 the header length is at most 52 bytes plus the length of the TLVs.
295+ // We use 256 bytes to be safe.
296+ const bufSize = 256
297+
298+ bb := bytes .NewBuffer (make ([]byte , 0 , bufSize ))
299+ tr := io .TeeReader (p .conn , bb )
300+ br := bufio .NewReaderSize (tr , bufSize )
301+
302+ header , err := Read (br )
303+
304+ if err == nil {
305+ io .CopyN (io .Discard , bb , int64 (header .length ))
306+ }
307+ p .headerBuf = bb
308+ p .reader = io .MultiReader (bb , p .conn )
298309
299310 // If the connection's readHeaderTimeout is more than 0, undo the change to the
300311 // deadline that we made above. Because we retain the readDeadline as part of our
@@ -360,5 +371,9 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
360371 if p .readErr != nil {
361372 return 0 , p .readErr
362373 }
363- return p .bufReader .WriteTo (w )
374+
375+ if n , err := p .headerBuf .WriteTo (w ); err != nil {
376+ return n , err
377+ }
378+ return io .Copy (w , p .conn )
364379}
0 commit comments