From 6ff5e9631d7c2d8d733d5af9ea95c82268bd7a38 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 01:15:15 +0200 Subject: [PATCH 01/11] Reworking handshake() to support unit testing. - Split handshake() into separate functions - Created ClientIO interface to enable mocking - Added unit tests for clientInit() --- client.go | 140 ++++++++++++++++++++++++++++++++++--------------- client_test.go | 45 +++++++++++++++- 2 files changed, 141 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 91402c9..77a60eb 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,9 @@ -// Package vnc implements a VNC client. -// -// References: -// [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +/* +Package vnc implements a VNC client. + +References: + [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +*/ package vnc import ( @@ -13,9 +15,11 @@ import ( "unicode" ) +// The ClientConn type holds client connection information. type ClientConn struct { c net.Conn config *ClientConfig + rw ClientIO // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -71,8 +75,25 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { c: c, config: cfg, } + conn.rw = NewClientIOReaderWriter(c, c) - if err := conn.handshake(); err != nil { + if err := conn.protocolVersionHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityResultHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.clientInit(); err != nil { + conn.Close() + return nil, err + } + if err := conn.serverInit(); err != nil { conn.Close() return nil, err } @@ -279,7 +300,13 @@ func (c *ClientConn) SetPixelFormat(format *PixelFormat) error { return nil } -const pvLen = 12 // ProtocolVersion message length. +const ( + pvLen = 12 // ProtocolVersion message length. + + // Supported protocol versions. + PROTO_VERS_UNSUP = "UNSUP" + PROTO_VERS_3_8 = "003.008" +) func parseProtocolVersion(pv []byte) (uint, uint, error) { var major, minor uint @@ -299,42 +326,48 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } -func (c *ClientConn) handshake() error { +// protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. +func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte - // 7.1.1, read the ProtocolVersion message sent by the server. + // Read the ProtocolVersion message sent by the server. if _, err := io.ReadFull(c.c, protocolVersion[:]); err != nil { return err } - maxMajor, maxMinor, err := parseProtocolVersion(protocolVersion[:]) + major, minor, err := parseProtocolVersion(protocolVersion[:]) if err != nil { return err } - if maxMajor < 3 { - return fmt.Errorf("unsupported major version, less than 3: %d", maxMajor) + pv := PROTO_VERS_UNSUP + if major == 3 && minor >= 8 { + pv = PROTO_VERS_3_8 } - if maxMinor < 8 { - return fmt.Errorf("unsupported minor version, less than 8: %d", maxMinor) + if pv == PROTO_VERS_UNSUP { + return fmt.Errorf("unsupported server ProtocolVersion '%v'", string(protocolVersion[:])) } // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB 003.008\n")); err != nil { + if _, err = c.c.Write([]byte("RFB " + pv + "\n")); err != nil { return err } - // 7.1.2 Security Handshake from server + return nil +} + +// securityHandshake implements §7.1.2 Security Handshake. +func (c *ClientConn) securityHandshake() error { var numSecurityTypes uint8 - if err = binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { + + if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err } - if numSecurityTypes == 0 { return fmt.Errorf("no security types: %s", c.readErrorReason()) } securityTypes := make([]uint8, numSecurityTypes) - if err = binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { return err } @@ -354,64 +387,65 @@ FindAuth: } } } - if auth == nil { return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) } // Respond back with the security type we'll use - if err = binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { + if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { return err } - - if err = auth.Handshake(c.c); err != nil { + if err := auth.Handshake(c.c); err != nil { return err } + return nil +} - // 7.1.3 SecurityResult Handshake +// securityResultHandshake implements §7.1.3 SecurityResult Handshake. +func (c *ClientConn) securityResultHandshake() error { var securityResult uint32 - if err = binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } - if securityResult == 1 { return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) } + return nil +} - // 7.3.1 ClientInit - var sharedFlag uint8 = 1 - if c.config.Exclusive { - sharedFlag = 0 +// clientInit implements §7.3.1 ClientInit. +func (c *ClientConn) clientInit() error { + var sharedFlag uint8 + if !c.config.Exclusive { + sharedFlag = 1 } - - if err = binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + //if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + if err := c.rw.Write(sharedFlag); err != nil { return err } + return nil +} - // 7.3.2 ServerInit - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { +// serverInit implements §7.3.2 ServerInit. +func (c *ClientConn) serverInit() error { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { return err } - - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { return err } - - // Read the pixel format - if err = readPixelFormat(c.c, &c.PixelFormat); err != nil { + if err := readPixelFormat(c.c, &c.PixelFormat); err != nil { return err } var nameLength uint32 - if err = binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { return err } - nameBytes := make([]uint8, nameLength) - if err = binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { return err } - c.DesktopName = string(nameBytes) return nil @@ -480,3 +514,25 @@ func (c *ClientConn) readErrorReason() string { return string(reason) } + +type ClientIO interface { + Read(data interface{}) error + Write(data interface{}) error +} + +type ClientIOReaderWriter struct { + reader io.Reader + writer io.Writer +} + +func NewClientIOReaderWriter(r io.Reader, w io.Writer) ClientIOReaderWriter { + return ClientIOReaderWriter{r, w} +} + +func (rw ClientIOReaderWriter) Read(data interface{}) error { + return binary.Read(rw.reader, binary.BigEndian, data) +} + +func (rw ClientIOReaderWriter) Write(data interface{}) error { + return binary.Write(rw.writer, binary.BigEndian, data) +} diff --git a/client_test.go b/client_test.go index 31591b4..c27082d 100644 --- a/client_test.go +++ b/client_test.go @@ -40,7 +40,7 @@ func TestClient_LowMajorVersion(t *testing.T) { t.Fatal("error expected") } - if err.Error() != "unsupported major version, less than 3: 2" { + if err.Error() != "unsupported server ProtocolVersion 'RFB 002.009\n'" { t.Fatalf("unexpected error: %s", err) } } @@ -56,7 +56,7 @@ func TestClient_LowMinorVersion(t *testing.T) { t.Fatal("error expected") } - if err.Error() != "unsupported minor version, less than 8: 7" { + if err.Error() != "unsupported server ProtocolVersion 'RFB 003.007\n'" { t.Fatalf("unexpected error: %s", err) } } @@ -93,3 +93,44 @@ func TestParseProtocolVersion(t *testing.T) { } } } + +type MockClientIOReaderWriter struct { + i, o interface{} +} + +func (rw *MockClientIOReaderWriter) Read(data interface{}) error { + rw.i = data + return nil +} + +func (rw *MockClientIOReaderWriter) Write(data interface{}) error { + rw.o = data + return nil +} + +func TestClientInit(t *testing.T) { + var err error + + tests := []struct { + exclusive bool + shared uint8 + }{ + {true, 0}, + {false, 1}, + } + + rw := &MockClientIOReaderWriter{} + cfg := &ClientConfig{} + conn := &ClientConn{config: cfg, rw: rw} + + for _, tt := range tests { + cfg.Exclusive = tt.exclusive + err = conn.clientInit() + if err != nil { + t.Fatalf("clientInit() error %v", err) + } + if rw.o != uint8(tt.shared) { + t.Errorf("clientInit() got = %v, want %v", rw.o, tt.shared) + } + } +} From 9fd5d84baa5a41af30f63101b82bf6860aeb1959 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 01:23:02 +0200 Subject: [PATCH 02/11] moving the client protocol version const lower --- client.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 77a60eb..7c7cfa7 100644 --- a/client.go +++ b/client.go @@ -300,13 +300,7 @@ func (c *ClientConn) SetPixelFormat(format *PixelFormat) error { return nil } -const ( - pvLen = 12 // ProtocolVersion message length. - - // Supported protocol versions. - PROTO_VERS_UNSUP = "UNSUP" - PROTO_VERS_3_8 = "003.008" -) +const pvLen = 12 // ProtocolVersion message length. func parseProtocolVersion(pv []byte) (uint, uint, error) { var major, minor uint @@ -326,6 +320,12 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } +const ( + // Client ProtocolVersions. + PROTO_VERS_UNSUP = "UNSUPPORTED" + PROTO_VERS_3_8 = "RFB 003.008\n" +) + // protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte @@ -348,7 +348,7 @@ func (c *ClientConn) protocolVersionHandshake() error { } // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB " + pv + "\n")); err != nil { + if _, err = c.c.Write([]byte(pv)); err != nil { return err } From a56035f1aef7205c0b1f87dbd1ddd99ebe1f1489 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 16:18:23 +0200 Subject: [PATCH 03/11] Reworked and simplified mocking, and now test securityResultHandshake(). --- client.go | 59 ++++++++++++++++---------------- client_test.go | 92 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 102 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index 7c7cfa7..c14a8ca 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,6 @@ import ( type ClientConn struct { c net.Conn config *ClientConfig - rw ClientIO // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -75,7 +74,6 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { c: c, config: cfg, } - conn.rw = NewClientIOReaderWriter(c, c) if err := conn.protocolVersionHandshake(); err != nil { conn.Close() @@ -89,7 +87,7 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn.Close() return nil, err } - if err := conn.clientInit(); err != nil { + if _, err := conn.clientInit(); err != nil { conn.Close() return nil, err } @@ -363,7 +361,11 @@ func (c *ClientConn) securityHandshake() error { return err } if numSecurityTypes == 0 { - return fmt.Errorf("no security types: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return fmt.Errorf("no security types: %s", reason) } securityTypes := make([]uint8, numSecurityTypes) @@ -404,26 +406,33 @@ FindAuth: // securityResultHandshake implements §7.1.3 SecurityResult Handshake. func (c *ClientConn) securityResultHandshake() error { var securityResult uint32 + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } if securityResult == 1 { - return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason)) } + return nil } // clientInit implements §7.3.1 ClientInit. -func (c *ClientConn) clientInit() error { +func (c *ClientConn) clientInit() (uint8, error) { var sharedFlag uint8 + if !c.config.Exclusive { sharedFlag = 1 } - //if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { - if err := c.rw.Write(sharedFlag); err != nil { - return err + if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + return 0, err } - return nil + + return sharedFlag, nil } // serverInit implements §7.3.2 ServerInit. @@ -501,38 +510,28 @@ func (c *ClientConn) mainLoop() { } } -func (c *ClientConn) readErrorReason() string { +func (c *ClientConn) readErrorReason() (string, error) { var reasonLen uint32 if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil { - return "" + return "", err } reason := make([]uint8, reasonLen) if err := binary.Read(c.c, binary.BigEndian, &reason); err != nil { - return "" + return "", err } - return string(reason) -} - -type ClientIO interface { - Read(data interface{}) error - Write(data interface{}) error -} - -type ClientIOReaderWriter struct { - reader io.Reader - writer io.Writer + return string(reason), nil } -func NewClientIOReaderWriter(r io.Reader, w io.Writer) ClientIOReaderWriter { - return ClientIOReaderWriter{r, w} +type vncError struct { + s string } -func (rw ClientIOReaderWriter) Read(data interface{}) error { - return binary.Read(rw.reader, binary.BigEndian, data) +func NewVNCError(s string) error { + return &vncError{s} } -func (rw ClientIOReaderWriter) Write(data interface{}) error { - return binary.Write(rw.writer, binary.BigEndian, data) +func (e vncError) Error() string { + return e.s } diff --git a/client_test.go b/client_test.go index c27082d..2603d8a 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,13 @@ package vnc import ( + "bytes" + "encoding/binary" "fmt" "net" + "reflect" "testing" + "time" ) func newMockServer(t *testing.T, version string) string { @@ -94,23 +98,47 @@ func TestParseProtocolVersion(t *testing.T) { } } -type MockClientIOReaderWriter struct { - i, o interface{} -} +func TestSecurityResultHandshake(t *testing.T) { + tests := []struct { + result uint32 + ok bool + reason string + }{ + {0, true, ""}, + {1, false, "SecurityResult error"}, + } -func (rw *MockClientIOReaderWriter) Read(data interface{}) error { - rw.i = data - return nil -} + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } -func (rw *MockClientIOReaderWriter) Write(data interface{}) error { - rw.o = data - return nil + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.result); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + + err := conn.securityResultHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) + } + if err != nil { + if verr, ok := err.(*vncError); !ok { + t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + } } func TestClientInit(t *testing.T) { - var err error - tests := []struct { exclusive bool shared uint8 @@ -119,18 +147,44 @@ func TestClientInit(t *testing.T) { {false, 1}, } - rw := &MockClientIOReaderWriter{} - cfg := &ClientConfig{} - conn := &ClientConn{config: cfg, rw: rw} + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } for _, tt := range tests { - cfg.Exclusive = tt.exclusive - err = conn.clientInit() + mockConn.Reset() + conn.config.Exclusive = tt.exclusive + + shared, err := conn.clientInit() if err != nil { t.Fatalf("clientInit() error %v", err) } - if rw.o != uint8(tt.shared) { - t.Errorf("clientInit() got = %v, want %v", rw.o, tt.shared) + if shared != tt.shared { + t.Errorf("clientInit() got = %v, want %v", shared, tt.shared) } } } + +// MockConn implements the net.Conn interface. +type MockConn struct { + b bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (int, error) { + return m.b.Read(b) +} +func (m *MockConn) Write(b []byte) (int, error) { + return m.b.Write(b) +} +func (m *MockConn) Close() error { return nil } +func (m *MockConn) LocalAddr() net.Addr { return nil } +func (m *MockConn) RemoteAddr() net.Addr { return nil } +func (m *MockConn) SetDeadline(t time.Time) error { return nil } +func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } +// Implement additional buffer.Buffer functions. +func (m *MockConn) Reset() { + m.b.Reset() +} From 533f16f8c597412ec98a284c6a85250aa7213e71 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 16:24:31 +0200 Subject: [PATCH 04/11] renamed vncError to VNCError to make it externally visible --- client.go | 8 +++++--- client_test.go | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c14a8ca..9f81973 100644 --- a/client.go +++ b/client.go @@ -524,14 +524,16 @@ func (c *ClientConn) readErrorReason() (string, error) { return string(reason), nil } -type vncError struct { +// VNCError implements error interface. +type VNCError struct { s string } +// NewVNCError returns a custom VNCError error. func NewVNCError(s string) error { - return &vncError{s} + return &VNCError{s} } -func (e vncError) Error() string { +func (e VNCError) Error() string { return e.s } diff --git a/client_test.go b/client_test.go index 2603d8a..6f6cec3 100644 --- a/client_test.go +++ b/client_test.go @@ -131,7 +131,7 @@ func TestSecurityResultHandshake(t *testing.T) { t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) } if err != nil { - if verr, ok := err.(*vncError); !ok { + if verr, ok := err.(*VNCError); !ok { t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) } } From d1f6b14a56b4c46ff31dfa9408cb36ed2a291de1 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 17:13:38 +0200 Subject: [PATCH 05/11] added unit test for serverInit() --- client_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/client_test.go b/client_test.go index 6f6cec3..1c219ae 100644 --- a/client_test.go +++ b/client_test.go @@ -167,6 +167,56 @@ func TestClientInit(t *testing.T) { } } +func TestServerInit(t *testing.T) { + tests := []struct { + fbWidth, fbHeight uint16 + pixelFormat [16]byte // TODO(kward): replace with PixelFormat + desktopName string + }{ + {100, 200, [16]byte{}, "foo"}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } + + err := conn.serverInit() + if err != nil { + t.Fatalf("serverInit() error %v", err) + } + if conn.FrameBufferWidth != tt.fbWidth { + t.Errorf("serverInit() FrameBufferWidth: got = %v, want = %v", conn.FrameBufferWidth, tt.fbWidth) + } + if conn.FrameBufferHeight != tt.fbHeight { + t.Errorf("serverInit() FrameBufferHeight: got = %v, want = %v", conn.FrameBufferHeight, tt.fbHeight) + } + // TODO(kward): add test for PixelFormat. + if conn.DesktopName != tt.desktopName { + t.Errorf("serverInit() DesktopName: got = %v, want = %v", conn.DesktopName, tt.desktopName) + } + } +} + // MockConn implements the net.Conn interface. type MockConn struct { b bytes.Buffer @@ -184,6 +234,7 @@ func (m *MockConn) RemoteAddr() net.Addr { return nil } func (m *MockConn) SetDeadline(t time.Time) error { return nil } func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } + // Implement additional buffer.Buffer functions. func (m *MockConn) Reset() { m.b.Reset() From 306a2489e6efb54ec9f45fbf8778a4fb4aca891a Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 17:37:15 +0200 Subject: [PATCH 06/11] add invalid protocol tests to TestServerInit() --- client_test.go | 53 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/client_test.go b/client_test.go index 1c219ae..62293ea 100644 --- a/client_test.go +++ b/client_test.go @@ -168,12 +168,26 @@ func TestClientInit(t *testing.T) { } func TestServerInit(t *testing.T) { + const ( + none = iota + fbw + fbh + pf + dn + ) tests := []struct { + eof int fbWidth, fbHeight uint16 pixelFormat [16]byte // TODO(kward): replace with PixelFormat desktopName string }{ - {100, 200, [16]byte{}, "foo"}, + // Valid protocol. + {dn, 100, 200, [16]byte{}, "foo"}, + // Invalid protocol (missing fields). + {eof: none}, + {eof: fbw, fbWidth: 1}, + {eof: fbh, fbWidth: 2, fbHeight: 1}, + {eof: pf, fbWidth: 3, fbHeight: 2, pixelFormat: [16]byte{}}, } mockConn := &MockConn{} @@ -184,23 +198,38 @@ func TestServerInit(t *testing.T) { for _, tt := range tests { mockConn.Reset() - if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { - t.Fatal(err) - } - if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { - t.Fatal(err) + if tt.eof >= fbw { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { - t.Fatal(err) + if tt.eof >= fbh { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { - t.Fatal(err) + if tt.eof >= pf { + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } } - if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { - t.Fatal(err) + if tt.eof >= dn { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } } err := conn.serverInit() + if tt.eof < dn && err == nil { + t.Fatalf("serverInit() expected error") + } + if tt.eof < dn { + // If the protocol was incomplete, there is no point in checking values. + continue + } if err != nil { t.Fatalf("serverInit() error %v", err) } From f190d8e05217c4053181497c1b41badaa0cb11f2 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 22:09:40 +0200 Subject: [PATCH 07/11] added constants for standard security types --- client_auth.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/client_auth.go b/client_auth.go index c88f911..c6dd037 100644 --- a/client_auth.go +++ b/client_auth.go @@ -4,6 +4,12 @@ import ( "net" ) +const ( + secTypeInvalid = iota + secTypeNone + secTypeVNCAuth +) + // A ClientAuth implements a method of authenticating with a remote server. type ClientAuth interface { // SecurityType returns the byte identifier sent by the server to @@ -16,10 +22,10 @@ type ClientAuth interface { } // ClientAuthNone is the "none" authentication. See 7.1.2 -type ClientAuthNone byte +type ClientAuthNone struct{} func (*ClientAuthNone) SecurityType() uint8 { - return 1 + return secTypeNone } func (*ClientAuthNone) Handshake(net.Conn) error { From 2f17980fb9f01043e38e1828fa848e657317303f Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 22:10:18 +0200 Subject: [PATCH 08/11] added tests for protocolVersionHandshake() and securityHandshake() --- client.go | 15 +++--- client_test.go | 137 +++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 134 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index 9f81973..8352bf0 100644 --- a/client.go +++ b/client.go @@ -87,7 +87,7 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn.Close() return nil, err } - if _, err := conn.clientInit(); err != nil { + if err := conn.clientInit(); err != nil { conn.Close() return nil, err } @@ -342,7 +342,7 @@ func (c *ClientConn) protocolVersionHandshake() error { pv = PROTO_VERS_3_8 } if pv == PROTO_VERS_UNSUP { - return fmt.Errorf("unsupported server ProtocolVersion '%v'", string(protocolVersion[:])) + return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) } // Respond with the version we will support @@ -356,7 +356,6 @@ func (c *ClientConn) protocolVersionHandshake() error { // securityHandshake implements §7.1.2 Security Handshake. func (c *ClientConn) securityHandshake() error { var numSecurityTypes uint8 - if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err } @@ -365,7 +364,7 @@ func (c *ClientConn) securityHandshake() error { if err != nil { return err } - return fmt.Errorf("no security types: %s", reason) + return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason)) } securityTypes := make([]uint8, numSecurityTypes) @@ -390,7 +389,7 @@ FindAuth: } } if auth == nil { - return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) + return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes)) } // Respond back with the security type we'll use @@ -422,17 +421,17 @@ func (c *ClientConn) securityResultHandshake() error { } // clientInit implements §7.3.1 ClientInit. -func (c *ClientConn) clientInit() (uint8, error) { +func (c *ClientConn) clientInit() error { var sharedFlag uint8 if !c.config.Exclusive { sharedFlag = 1 } if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { - return 0, err + return err } - return sharedFlag, nil + return nil } // serverInit implements §7.3.2 ServerInit. diff --git a/client_test.go b/client_test.go index 62293ea..0faa55a 100644 --- a/client_test.go +++ b/client_test.go @@ -43,9 +43,10 @@ func TestClient_LowMajorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported server ProtocolVersion 'RFB 002.009\n'" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -59,9 +60,10 @@ func TestClient_LowMinorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported server ProtocolVersion 'RFB 003.007\n'" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -98,6 +100,116 @@ func TestParseProtocolVersion(t *testing.T) { } } +func TestProtocolVersionHandshake(t *testing.T) { + tests := []struct { + server string + client string + ok bool + }{ + // Supported versions. + {"RFB 003.008\n", "RFB 003.008\n", true}, + {"RFB 003.389\n", "RFB 003.008\n", true}, + // Unsupported versions. + {server: "RFB 003.003\n", ok: false}, + {server: "RFB 002.009\n", ok: false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + + // Validate server message. + err := conn.protocolVersionHandshake() + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() expected error for server protocol version %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("protocolVersionHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var client [pvLen]byte + err = binary.Read(conn.c, binary.BigEndian, &client) + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() unexpected error: %v", err) + } + if string(client[:]) != tt.client && tt.ok { + t.Errorf("protocolVersionHandshake() client version: got = %v, want = %v", string(client[:]), tt.client) + } + } +} + +func TestSecurityHandshake(t *testing.T) { + tests := []struct { + server []uint8 + client []ClientAuth + secType uint8 + ok bool + }{ + //-- Supported security types. -- + // Both server and client support the None security type. + {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + // Server supports None and VNCAuth, client supports only None. + {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + //-- Unsupported security types. -- + // Server provided no security types. + {[]uint8{}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + // Client and server don't support same security types. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if len(tt.server) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + } + + // Validate server message. + conn.config.Auth = tt.client + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if len(tt.server) == 0 { + // The protocol was incomplete; no point in checking values. + continue + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var secType uint8 + err = binary.Read(conn.c, binary.BigEndian, &secType) + if secType != tt.secType { + t.Errorf("securityHandshake() secType: got = %v, want = %v", secType, tt.secType) + } + } +} + func TestSecurityResultHandshake(t *testing.T) { tests := []struct { result uint32 @@ -126,6 +238,7 @@ func TestSecurityResultHandshake(t *testing.T) { t.Fatal(err) } + // Validate server message. err := conn.securityResultHandshake() if err == nil && !tt.ok { t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) @@ -157,12 +270,15 @@ func TestClientInit(t *testing.T) { mockConn.Reset() conn.config.Exclusive = tt.exclusive - shared, err := conn.clientInit() + // Validate client response. + err := conn.clientInit() if err != nil { - t.Fatalf("clientInit() error %v", err) + t.Fatalf("clientInit() unexpected error %v", err) } + var shared uint8 + err = binary.Read(conn.c, binary.BigEndian, &shared) if shared != tt.shared { - t.Errorf("clientInit() got = %v, want %v", shared, tt.shared) + t.Errorf("clientInit() shared: got = %v, want = %v", shared, tt.shared) } } } @@ -222,12 +338,13 @@ func TestServerInit(t *testing.T) { } } + // Validate server message. err := conn.serverInit() if tt.eof < dn && err == nil { t.Fatalf("serverInit() expected error") } if tt.eof < dn { - // If the protocol was incomplete, there is no point in checking values. + // The protocol was incomplete; no point in checking values. continue } if err != nil { From ca57dc6fcaaba7f4a571139d34e3a6806a6dd7f4 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Tue, 12 May 2015 22:38:35 +0200 Subject: [PATCH 09/11] add support for RFB 003.003 --- client.go | 9 +++++++-- client_test.go | 5 +++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 8352bf0..6a6b6cd 100644 --- a/client.go +++ b/client.go @@ -321,6 +321,7 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { const ( // Client ProtocolVersions. PROTO_VERS_UNSUP = "UNSUPPORTED" + PROTO_VERS_3_3 = "RFB 003.003\n" PROTO_VERS_3_8 = "RFB 003.008\n" ) @@ -338,8 +339,12 @@ func (c *ClientConn) protocolVersionHandshake() error { return err } pv := PROTO_VERS_UNSUP - if major == 3 && minor >= 8 { - pv = PROTO_VERS_3_8 + if major == 3 { + if minor >= 8 { + pv = PROTO_VERS_3_8 + } else if minor >= 3 { + pv = PROTO_VERS_3_3 + } } if pv == PROTO_VERS_UNSUP { return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) diff --git a/client_test.go b/client_test.go index 0faa55a..84e3028 100644 --- a/client_test.go +++ b/client_test.go @@ -51,7 +51,7 @@ func TestClient_LowMajorVersion(t *testing.T) { } func TestClient_LowMinorVersion(t *testing.T) { - nc, err := net.Dial("tcp", newMockServer(t, "003.007")) + nc, err := net.Dial("tcp", newMockServer(t, "003.002")) if err != nil { t.Fatalf("error connecting to mock server: %s", err) } @@ -107,10 +107,11 @@ func TestProtocolVersionHandshake(t *testing.T) { ok bool }{ // Supported versions. + {"RFB 003.003\n", "RFB 003.003\n", true}, + {"RFB 003.006\n", "RFB 003.003\n", true}, {"RFB 003.008\n", "RFB 003.008\n", true}, {"RFB 003.389\n", "RFB 003.008\n", true}, // Unsupported versions. - {server: "RFB 003.003\n", ok: false}, {server: "RFB 002.009\n", ok: false}, } From 044241562a5eb288b93d92b8d90192accdd54c86 Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Thu, 14 May 2015 15:01:22 +0200 Subject: [PATCH 10/11] support full RFB 003.003 protocol --- client_auth_vnc.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 client_auth_vnc.go diff --git a/client_auth_vnc.go b/client_auth_vnc.go new file mode 100644 index 0000000..9d7495f --- /dev/null +++ b/client_auth_vnc.go @@ -0,0 +1,72 @@ +/* +ClientAuthVNC implements the ClientAuth interface to provide support for +VNC Authentication. + +See http://tools.ietf.org/html/rfc6143#section-7.2.2 for more info. +*/ +package vnc + +import ( + "crypto/des" + "encoding/binary" + "net" +) + +// ClientAuthVNC is the standard password authentication +type ClientAuthVNC struct { + Password string +} + +func (*ClientAuthVNC) SecurityType() uint8 { + return secTypeVNCAuth +} + +// 7.2.2. VNC Authentication uses a 16-byte challenge. +const vncAuthChallengeSize = 16 + +func (auth *ClientAuthVNC) Handshake(conn net.Conn) error { + + if auth.Password == "" { + return NewVNCError("securityHandshake: handshake failed; no password provided for VNCAuth.") + } + + // Read challenge block + var challenge [vncAuthChallengeSize]byte + if err := binary.Read(conn, binary.BigEndian, &challenge); err != nil { + return err + } + + auth.encode(&challenge) + + // Send the encrypted challenge back to server + if err := binary.Write(conn, binary.BigEndian, challenge); err != nil { + return err + } + + return nil +} + +func (auth *ClientAuthVNC) encode(c *[vncAuthChallengeSize]byte) error { + // Copy password string to 8 byte 0-padded slice + key := make([]byte, 8) + copy(key, auth.Password) + + // Each byte of the password needs to be reversed. This is a + // non RFC-documented behaviour of VNC clients and servers + for i := range key { + key[i] = (key[i]&0x55)<<1 | (key[i]&0xAA)>>1 // Swap adjacent bits + key[i] = (key[i]&0x33)<<2 | (key[i]&0xCC)>>2 // Swap adjacent pairs + key[i] = (key[i]&0x0F)<<4 | (key[i]&0xF0)>>4 // Swap the 2 halves + } + + // Encrypt challenge with key. + cipher, err := des.NewCipher(key) + if err != nil { + return err + } + for i := 0; i < vncAuthChallengeSize; i += cipher.BlockSize() { + cipher.Encrypt(c[i:i+cipher.BlockSize()], c[i:i+cipher.BlockSize()]) + } + + return nil +} From e118cb2562de2fb6dab1243469fcc4f8147bdcda Mon Sep 17 00:00:00 2001 From: Kate Ward Date: Thu, 14 May 2015 15:01:56 +0200 Subject: [PATCH 11/11] support full RFB 003.003 protocol --- client.go | 96 +++++++++++++++++++++------ client_auth_test.go | 51 ++++++++++++++- client_test.go | 153 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 258 insertions(+), 42 deletions(-) diff --git a/client.go b/client.go index 6a6b6cd..e628bd6 100644 --- a/client.go +++ b/client.go @@ -17,8 +17,9 @@ import ( // The ClientConn type holds client connection information. type ClientConn struct { - c net.Conn - config *ClientConfig + c net.Conn + config *ClientConfig + protocolVersion string // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -51,6 +52,9 @@ type ClientConfig struct { // suitable by the server will be used to authenticate. Auth []ClientAuth + // Password for servers that require authentication. + Password string + // Exclusive determines whether the connection is shared with other // clients. If true, then all other clients connected will be // disconnected when a connection is established to the VNC server. @@ -69,6 +73,13 @@ type ClientConfig struct { ServerMessages []ServerMessage } +func NewClientConfig(p string) *ClientConfig { + return &ClientConfig{ + Auth: []ClientAuth{&ClientAuthNone{}, &ClientAuthVNC{p}}, + Password: p, + } +} + func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn := &ClientConn{ c: c, @@ -171,7 +182,6 @@ func (c *ClientConn) FramebufferUpdateRequest(incremental bool, x, y, width, hei return err } } - if _, err := c.c.Write(buf.Bytes()[0:10]); err != nil { return err } @@ -179,7 +189,7 @@ func (c *ClientConn) FramebufferUpdateRequest(incremental bool, x, y, width, hei return nil } -// KeyEvent indiciates a key press or release and sends it to the server. +// KeyEvent indicates a key press or release and sends it to the server. // The key is indicated using the X Window System "keysym" value. Use // Google to find a reference of these values. To simulate a key press, // you must send a key with both a down event, and a non-down event. @@ -349,6 +359,7 @@ func (c *ClientConn) protocolVersionHandshake() error { if pv == PROTO_VERS_UNSUP { return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) } + c.protocolVersion = pv // Respond with the version we will support if _, err = c.c.Write([]byte(pv)); err != nil { @@ -360,6 +371,54 @@ func (c *ClientConn) protocolVersionHandshake() error { // securityHandshake implements §7.1.2 Security Handshake. func (c *ClientConn) securityHandshake() error { + + switch c.protocolVersion { + case PROTO_VERS_3_3: + err := c.securityHandshake33() + if err != nil { + return err + } + case PROTO_VERS_3_8: + err := c.securityHandshake38() + if err != nil { + return err + } + default: + return NewVNCError(fmt.Sprintf("Security handshake failed; unsupported protocol")) + } + return nil +} + +func (c *ClientConn) securityHandshake33() error { + var secType uint32 + if err := binary.Read(c.c, binary.BigEndian, &secType); err != nil { + return err + } + + var auth ClientAuth + switch secType { + case secTypeInvalid: // Connection failed. + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("Security handshake failed; connection failed: %s", reason)) + case secTypeNone: + auth = &ClientAuthNone{} + case secTypeVNCAuth: + auth = &ClientAuthVNC{c.config.Password} + default: + return NewVNCError(fmt.Sprintf("Security handshake failed; invalid security type: %v", secType)) + } + if err := auth.Handshake(c.c); err != nil { + return err + } + + return nil +} + +func (c *ClientConn) securityHandshake38() error { + // Determine server supported security types. var numSecurityTypes uint8 if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err @@ -371,24 +430,20 @@ func (c *ClientConn) securityHandshake() error { } return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason)) } - securityTypes := make([]uint8, numSecurityTypes) if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { return err } - clientSecurityTypes := c.config.Auth - if clientSecurityTypes == nil { - clientSecurityTypes = []ClientAuth{new(ClientAuthNone)} - } - + // Choose client security type. + // TODO(kward): try "better" security types first. var auth ClientAuth FindAuth: - for _, curAuth := range clientSecurityTypes { - for _, securityType := range securityTypes { - if curAuth.SecurityType() == securityType { - // We use the first matching supported authentication - auth = curAuth + for _, securityType := range securityTypes { + for _, a := range c.config.Auth { + if a.SecurityType() == securityType { + // We use the first matching supported authentication. + auth = a break FindAuth } } @@ -396,11 +451,10 @@ FindAuth: if auth == nil { return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes)) } - - // Respond back with the security type we'll use if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { return err } + if err := auth.Handshake(c.c); err != nil { return err } @@ -409,8 +463,12 @@ FindAuth: // securityResultHandshake implements §7.1.3 SecurityResult Handshake. func (c *ClientConn) securityResultHandshake() error { - var securityResult uint32 + if c.protocolVersion == PROTO_VERS_3_3 { + // Not required for 3.3. + return nil + } + var securityResult uint32 if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } @@ -421,7 +479,6 @@ func (c *ClientConn) securityResultHandshake() error { } return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason)) } - return nil } @@ -514,6 +571,7 @@ func (c *ClientConn) mainLoop() { } } +// TODO(kward): need a context for timeout func (c *ClientConn) readErrorReason() (string, error) { var reasonLen uint32 if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil { diff --git a/client_auth_test.go b/client_auth_test.go index 746b46f..f216db1 100644 --- a/client_auth_test.go +++ b/client_auth_test.go @@ -1,6 +1,10 @@ package vnc -import "testing" +import ( + "encoding/hex" + "strings" + "testing" +) func TestClientAuthNone_Impl(t *testing.T) { var raw interface{} @@ -9,3 +13,48 @@ func TestClientAuthNone_Impl(t *testing.T) { t.Fatal("ClientAuthNone doesn't implement ClientAuth") } } + +func TestClientAuthVNC_Impl(t *testing.T) { + var raw interface{} + raw = new(ClientAuthVNC) + if _, ok := raw.(ClientAuth); !ok { + t.Fatal("ClientAuthVNC doesn't implement ClientAuth") + } +} + +// wiresharkToChallenge converts VNC authentication challenge and response +// values captured with Wireshark (https://www.wireshark.org) into usable byte +// streams. +func wiresharkToChallenge(h string) [vncAuthChallengeSize]byte { + var c [vncAuthChallengeSize]byte + r := strings.NewReplacer(":", "") + b, err := hex.DecodeString(r.Replace(h)) + if err != nil { + return c + } + copy(c[:], b) + return c +} + +func TestClientAuthVNCEncode(t *testing.T) { + tests := []struct { + pw string + challenge, response string + }{ + {".", "7f:e2:e1:3d:a4:ae:10:9c:54:c5:5f:52:74:aa:db:31", "1d:86:92:71:1f:00:24:35:02:d3:91:ef:e9:bc:c5:d5"}, + {"12345678", "13:8e:a4:2e:0e:66:f3:ad:2d:f3:08:c3:04:cd:c4:2a", "5b:e1:56:fa:49:49:ef:56:d3:f8:44:97:73:27:95:9f"}, + {"abc123", "c6:30:45:d2:57:9e:e7:f2:f9:0c:62:3e:52:40:86:c6", "a3:63:59:e4:28:c8:7f:b3:45:2c:d7:e0:ca:d6:70:3e"}, + } + + for _, tt := range tests { + challenge := wiresharkToChallenge(tt.challenge) + a := ClientAuthVNC{tt.pw} + if err := a.encode(&challenge); err != nil { + t.Errorf("ClientAuthVNC.encode() failed: key=%v, err=%v", tt.pw, err) + } + response := wiresharkToChallenge(tt.response) + if challenge != response { + t.Errorf("ClientAuthVNC.encode() failed: key=%v got=%v, want=%v", tt.pw, challenge, response) + } + } +} diff --git a/client_test.go b/client_test.go index 84e3028..6423542 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io" "net" "reflect" "testing" @@ -88,6 +89,7 @@ func TestParseProtocolVersion(t *testing.T) { if err != nil && !tt.isErr { t.Fatalf("parseProtocolVersion(%v) unexpected error %v", tt.proto, err) } + // TODO(kward): validate VNCError thrown. if err == nil && tt.isErr { t.Fatalf("parseProtocolVersion(%v) expected error", tt.proto) } @@ -127,7 +129,7 @@ func TestProtocolVersionHandshake(t *testing.T) { t.Fatal(err) } - // Validate server message. + // Validate server message handling. err := conn.protocolVersionHandshake() if err == nil && !tt.ok { t.Fatalf("protocolVersionHandshake() expected error for server protocol version %v", tt.server) @@ -150,57 +152,159 @@ func TestProtocolVersionHandshake(t *testing.T) { } } -func TestSecurityHandshake(t *testing.T) { +func writeVNCAuthChallenge(w io.Writer) error { + var c [vncAuthChallengeSize]uint8 + for i := 0; i < vncAuthChallengeSize; i++ { + c[i] = uint8(i) + } + if err := binary.Write(w, binary.BigEndian, c); err != nil { + return err + } + return nil +} + +func readVNCAuthChallenge(r io.Reader) error { + var c [vncAuthChallengeSize]uint8 + if err := binary.Read(r, binary.BigEndian, &c); err != nil { + return fmt.Errorf("error reading back VNCAuth challenge") + } + return nil +} + +func TestSecurityHandshake33(t *testing.T) { + tests := []struct { + server uint32 + ok bool + reason string + }{ + //-- Supported security types. -- + // Server supports None. + {secTypeNone, true, ""}, + // Server supports VNCAuth. + {secTypeVNCAuth, true, ""}, + //-- Unsupported security types. -- + {secTypeInvalid, false, "some reason"}, + {255, false, ""}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: NewClientConfig("."), + protocolVersion: PROTO_VERS_3_3, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.server); err != nil { + t.Fatal(err) + } + if len(tt.reason) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + } + if tt.server == secTypeVNCAuth { + if err := writeVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + + // Validate server message handling. + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + if !tt.ok { + continue + } + + // Validate client response. + if tt.server == secTypeVNCAuth { + if err := readVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + } +} + +func TestSecurityHandshake38(t *testing.T) { tests := []struct { server []uint8 client []ClientAuth secType uint8 ok bool + reason string }{ //-- Supported security types. -- - // Both server and client support the None security type. - {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, - // Server supports None and VNCAuth, client supports only None. - {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + // Server and client support None. + {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true, ""}, + // Server and client support VNCAuth. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthVNC{"."}}, secTypeVNCAuth, true, ""}, + // Server and client both support VNCAuth and None. + {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthVNC{"."}, &ClientAuthNone{}}, secTypeVNCAuth, true, ""}, + // Server supports unknown #255, VNCAuth and None. + {[]uint8{255, secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthVNC{"."}, &ClientAuthNone{}}, secTypeVNCAuth, true, ""}, //-- Unsupported security types. -- - // Server provided no security types. - {[]uint8{}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + // Server provided no valid security types. + {[]uint8{secTypeInvalid}, []ClientAuth{}, secTypeInvalid, false, "some reason"}, // Client and server don't support same security types. - {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false, ""}, + // Server supports only unknown #255. + {[]uint8{255}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false, ""}, } mockConn := &MockConn{} conn := &ClientConn{ - c: mockConn, - config: &ClientConfig{}, + c: mockConn, + config: &ClientConfig{}, + protocolVersion: PROTO_VERS_3_8, } for _, tt := range tests { mockConn.Reset() - if len(tt.server) > 0 { - if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + if len(tt.reason) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { t.Fatal(err) } - if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + } + if tt.secType == secTypeVNCAuth { + if err := writeVNCAuthChallenge(conn.c); err != nil { t.Fatal(err) } } - - // Validate server message. conn.config.Auth = tt.client + + // Validate server message handling. err := conn.securityHandshake() if err == nil && !tt.ok { t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) } - if len(tt.server) == 0 { - // The protocol was incomplete; no point in checking values. - continue - } if err != nil { if verr, ok := err.(*VNCError); !ok { t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) } } + if !tt.ok { + continue + } // Validate client response. var secType uint8 @@ -208,6 +312,11 @@ func TestSecurityHandshake(t *testing.T) { if secType != tt.secType { t.Errorf("securityHandshake() secType: got = %v, want = %v", secType, tt.secType) } + if tt.secType == secTypeVNCAuth { + if err := readVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } } } @@ -239,7 +348,7 @@ func TestSecurityResultHandshake(t *testing.T) { t.Fatal(err) } - // Validate server message. + // Validate server message handling. err := conn.securityResultHandshake() if err == nil && !tt.ok { t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) @@ -339,7 +448,7 @@ func TestServerInit(t *testing.T) { } } - // Validate server message. + // Validate server message handling. err := conn.serverInit() if tt.eof < dn && err == nil { t.Fatalf("serverInit() expected error")