diff --git a/inproc.go b/inproc.go index 0f2acaa..56b6643 100644 --- a/inproc.go +++ b/inproc.go @@ -6,7 +6,6 @@ package inproc import ( "errors" - "io" "net" "sync" "sync/atomic" @@ -46,10 +45,11 @@ func init() { // conn implements the net.Conn interface. type conn struct { - r io.ReadCloser - w io.WriteCloser - laddr addr - raddr addr + r net.Conn + w net.Conn + laddr addr + raddr addr + rDeadline time.Time } // Read reads data from the connection. @@ -85,24 +85,35 @@ func (c *conn) RemoteAddr() net.Addr { // SetDeadline implements the net.Conn SetDeadline method. func (c *conn) SetDeadline(t time.Time) error { - return errors.New("not supported") + if err := c.r.SetReadDeadline(t); err != nil { + return err + } + if err := c.w.SetWriteDeadline(t); err != nil { + // if setting read deadline succeeded, but setting write deadline + // failed, revert old read deadline. + c.r.SetReadDeadline(c.rDeadline) + return err + } + // update read deadline + c.rDeadline = t + return nil } // SetReadDeadline implements the net.Conn SetReadDeadline method. func (c *conn) SetReadDeadline(t time.Time) error { - return errors.New("not supported") + return c.r.SetReadDeadline(t) } // SetWriteDeadline implements the net.Conn SetWriteDeadline method. func (c *conn) SetWriteDeadline(t time.Time) error { - return errors.New("not supported") + return c.w.SetWriteDeadline(t) } // Dial connects to an address. func Dial(address string) (net.Conn, error) { raddr := addr{network: network, address: address} var accepter *accepter - r, w := io.Pipe() + r, w := net.Pipe() conn := &conn{w: w, laddr: raddr} addrs.locker.RLock() l, ok := addrs.listeners[raddr] @@ -141,7 +152,7 @@ type listener struct { type accepter struct { *conn - reader io.ReadCloser + reader net.Conn done chan struct{} } @@ -162,7 +173,7 @@ func Listen(address string) (net.Listener, error) { // Accept waits for and returns the next connection to the listener. func (l *listener) Accept() (net.Conn, error) { - r, w := io.Pipe() + r, w := net.Pipe() accepter := &accepter{conn: &conn{w: w, laddr: l.laddr}, reader: r} accepter.done = make(chan struct{}) l.locker.Lock() diff --git a/inproc_test.go b/inproc_test.go index 57a542e..18ca40b 100644 --- a/inproc_test.go +++ b/inproc_test.go @@ -74,3 +74,107 @@ func TestINPROC(t *testing.T) { l.Close() wg.Wait() } + +func TestDeadline(t *testing.T) { + address := ":9999" + l, err := Listen(address) + if err != nil { + t.Error(err) + return + } + defer l.Close() + go func() { + for i := 0; i < 2; i++ { + conn, err := l.Accept() + if err != nil { + t.Error(err) + return + } + if _, err := conn.Read(make([]byte, 1)); err != nil { + t.Error(err) + return + } + if _, err := conn.Write([]byte{0}); err != nil { + t.Error(err) + return + } + conn.Close() + } + }() + conn, err := Dial(address) + if err != nil { + t.Error(err) + return + } + + // Test SetReadDeadline + conn.SetReadDeadline(time.Now().Add(1)) + n, _ := conn.Read(make([]byte, 1)) + if n != 0 { + t.Errorf("read deadline error %d != %d", n, 0) + return + } + // Write value to progress listener status + if _, err := conn.Write([]byte{0}); err != nil { + t.Error(err) + return + } + // Reset read deadline + conn.SetReadDeadline(time.Time{}) + + // Test SetWriteDeadline + conn.SetWriteDeadline(time.Now().Add(1)) + if _, err := conn.Write([]byte{0}); err == nil { + t.Error(err) + return + } + n, _ = conn.Write([]byte{0}) + if n != 0 { + t.Errorf("write deadline error %d != %d", n, 0) + return + } + // Read value to progress listener status, then close connection + if _, err := conn.Read(make([]byte, 1)); err != nil { + t.Error(err) + return + } + // Reset write deadline + conn.SetWriteDeadline(time.Time{}) + conn.Close() + + // Test SetDeadline + conn, err = Dial(address) + if err != nil { + t.Error(err) + return + } + conn.SetDeadline(time.Now().Add(1)) + // Test read + n, _ = conn.Read(make([]byte, 1)) + if n != 0 { + t.Errorf("deadline error %d != %d", n, 0) + return + } + // Reset deadline + conn.SetDeadline(time.Time{}) + + // Write value to progress listener status + if _, err := conn.Write([]byte{0}); err != nil { + t.Error(err) + return + } + conn.SetDeadline(time.Now().Add(1)) + // Test write + n, _ = conn.Write([]byte{0}) + if n != 0 { + t.Errorf("deadline error %d != %d", n, 0) + return + } + // Read value to progress listener status, then close connection + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Error(err) + return + } + conn.Close() +}