diff --git a/client.go b/client.go index 91402c9..8352bf0 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,9 @@ -// Package vnc implements a VNC client. -// -// References: -// [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +/* +Package vnc implements a VNC client. + +References: + [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +*/ package vnc import ( @@ -13,6 +15,7 @@ import ( "unicode" ) +// The ClientConn type holds client connection information. type ClientConn struct { c net.Conn config *ClientConfig @@ -72,7 +75,23 @@ func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { config: cfg, } - if err := conn.handshake(); err != nil { + if err := conn.protocolVersionHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityResultHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.clientInit(); err != nil { + conn.Close() + return nil, err + } + if err := conn.serverInit(); err != nil { conn.Close() return nil, err } @@ -299,42 +318,57 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } -func (c *ClientConn) handshake() error { +const ( + // Client ProtocolVersions. + PROTO_VERS_UNSUP = "UNSUPPORTED" + PROTO_VERS_3_8 = "RFB 003.008\n" +) + +// protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. +func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte - // 7.1.1, read the ProtocolVersion message sent by the server. + // Read the ProtocolVersion message sent by the server. if _, err := io.ReadFull(c.c, protocolVersion[:]); err != nil { return err } - maxMajor, maxMinor, err := parseProtocolVersion(protocolVersion[:]) + major, minor, err := parseProtocolVersion(protocolVersion[:]) if err != nil { return err } - if maxMajor < 3 { - return fmt.Errorf("unsupported major version, less than 3: %d", maxMajor) + pv := PROTO_VERS_UNSUP + if major == 3 && minor >= 8 { + pv = PROTO_VERS_3_8 } - if maxMinor < 8 { - return fmt.Errorf("unsupported minor version, less than 8: %d", maxMinor) + if pv == PROTO_VERS_UNSUP { + return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) } // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB 003.008\n")); err != nil { + if _, err = c.c.Write([]byte(pv)); err != nil { return err } - // 7.1.2 Security Handshake from server + return nil +} + +// securityHandshake implements §7.1.2 Security Handshake. +func (c *ClientConn) securityHandshake() error { var numSecurityTypes uint8 - if err = binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { return err } - if numSecurityTypes == 0 { - return fmt.Errorf("no security types: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason)) } securityTypes := make([]uint8, numSecurityTypes) - if err = binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { return err } @@ -354,64 +388,72 @@ FindAuth: } } } - if auth == nil { - return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) + return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes)) } // Respond back with the security type we'll use - if err = binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { + if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { return err } - - if err = auth.Handshake(c.c); err != nil { + if err := auth.Handshake(c.c); err != nil { return err } + return nil +} - // 7.1.3 SecurityResult Handshake +// securityResultHandshake implements §7.1.3 SecurityResult Handshake. +func (c *ClientConn) securityResultHandshake() error { var securityResult uint32 - if err = binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { + + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } - if securityResult == 1 { - return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason)) } - // 7.3.1 ClientInit - var sharedFlag uint8 = 1 - if c.config.Exclusive { - sharedFlag = 0 - } + return nil +} - if err = binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { +// clientInit implements §7.3.1 ClientInit. +func (c *ClientConn) clientInit() error { + var sharedFlag uint8 + + if !c.config.Exclusive { + sharedFlag = 1 + } + if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { return err } - // 7.3.2 ServerInit - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { + return nil +} + +// serverInit implements §7.3.2 ServerInit. +func (c *ClientConn) serverInit() error { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { return err } - - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { return err } - - // Read the pixel format - if err = readPixelFormat(c.c, &c.PixelFormat); err != nil { + if err := readPixelFormat(c.c, &c.PixelFormat); err != nil { return err } var nameLength uint32 - if err = binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { return err } - nameBytes := make([]uint8, nameLength) - if err = binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { return err } - c.DesktopName = string(nameBytes) return nil @@ -467,16 +509,30 @@ func (c *ClientConn) mainLoop() { } } -func (c *ClientConn) readErrorReason() string { +func (c *ClientConn) readErrorReason() (string, error) { var reasonLen uint32 if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil { - return "" + return "", err } reason := make([]uint8, reasonLen) if err := binary.Read(c.c, binary.BigEndian, &reason); err != nil { - return "" + return "", err } - return string(reason) + return string(reason), nil +} + +// VNCError implements error interface. +type VNCError struct { + s string +} + +// NewVNCError returns a custom VNCError error. +func NewVNCError(s string) error { + return &VNCError{s} +} + +func (e VNCError) Error() string { + return e.s } diff --git a/client_auth.go b/client_auth.go index c88f911..c6dd037 100644 --- a/client_auth.go +++ b/client_auth.go @@ -4,6 +4,12 @@ import ( "net" ) +const ( + secTypeInvalid = iota + secTypeNone + secTypeVNCAuth +) + // A ClientAuth implements a method of authenticating with a remote server. type ClientAuth interface { // SecurityType returns the byte identifier sent by the server to @@ -16,10 +22,10 @@ type ClientAuth interface { } // ClientAuthNone is the "none" authentication. See 7.1.2 -type ClientAuthNone byte +type ClientAuthNone struct{} func (*ClientAuthNone) SecurityType() uint8 { - return 1 + return secTypeNone } func (*ClientAuthNone) Handshake(net.Conn) error { diff --git a/client_test.go b/client_test.go index 31591b4..0faa55a 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,13 @@ package vnc import ( + "bytes" + "encoding/binary" "fmt" "net" + "reflect" "testing" + "time" ) func newMockServer(t *testing.T, version string) string { @@ -39,9 +43,10 @@ func TestClient_LowMajorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported major version, less than 3: 2" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -55,9 +60,10 @@ func TestClient_LowMinorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported minor version, less than 8: 7" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -93,3 +99,289 @@ func TestParseProtocolVersion(t *testing.T) { } } } + +func TestProtocolVersionHandshake(t *testing.T) { + tests := []struct { + server string + client string + ok bool + }{ + // Supported versions. + {"RFB 003.008\n", "RFB 003.008\n", true}, + {"RFB 003.389\n", "RFB 003.008\n", true}, + // Unsupported versions. + {server: "RFB 003.003\n", ok: false}, + {server: "RFB 002.009\n", ok: false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + + // Validate server message. + err := conn.protocolVersionHandshake() + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() expected error for server protocol version %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("protocolVersionHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var client [pvLen]byte + err = binary.Read(conn.c, binary.BigEndian, &client) + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() unexpected error: %v", err) + } + if string(client[:]) != tt.client && tt.ok { + t.Errorf("protocolVersionHandshake() client version: got = %v, want = %v", string(client[:]), tt.client) + } + } +} + +func TestSecurityHandshake(t *testing.T) { + tests := []struct { + server []uint8 + client []ClientAuth + secType uint8 + ok bool + }{ + //-- Supported security types. -- + // Both server and client support the None security type. + {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + // Server supports None and VNCAuth, client supports only None. + {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true}, + //-- Unsupported security types. -- + // Server provided no security types. + {[]uint8{}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + // Client and server don't support same security types. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if len(tt.server) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + } + + // Validate server message. + conn.config.Auth = tt.client + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if len(tt.server) == 0 { + // The protocol was incomplete; no point in checking values. + continue + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var secType uint8 + err = binary.Read(conn.c, binary.BigEndian, &secType) + if secType != tt.secType { + t.Errorf("securityHandshake() secType: got = %v, want = %v", secType, tt.secType) + } + } +} + +func TestSecurityResultHandshake(t *testing.T) { + tests := []struct { + result uint32 + ok bool + reason string + }{ + {0, true, ""}, + {1, false, "SecurityResult error"}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.result); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + + // Validate server message. + err := conn.securityResultHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + } +} + +func TestClientInit(t *testing.T) { + tests := []struct { + exclusive bool + shared uint8 + }{ + {true, 0}, + {false, 1}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + conn.config.Exclusive = tt.exclusive + + // Validate client response. + err := conn.clientInit() + if err != nil { + t.Fatalf("clientInit() unexpected error %v", err) + } + var shared uint8 + err = binary.Read(conn.c, binary.BigEndian, &shared) + if shared != tt.shared { + t.Errorf("clientInit() shared: got = %v, want = %v", shared, tt.shared) + } + } +} + +func TestServerInit(t *testing.T) { + const ( + none = iota + fbw + fbh + pf + dn + ) + tests := []struct { + eof int + fbWidth, fbHeight uint16 + pixelFormat [16]byte // TODO(kward): replace with PixelFormat + desktopName string + }{ + // Valid protocol. + {dn, 100, 200, [16]byte{}, "foo"}, + // Invalid protocol (missing fields). + {eof: none}, + {eof: fbw, fbWidth: 1}, + {eof: fbh, fbWidth: 2, fbHeight: 1}, + {eof: pf, fbWidth: 3, fbHeight: 2, pixelFormat: [16]byte{}}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if tt.eof >= fbw { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } + } + if tt.eof >= fbh { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } + } + if tt.eof >= pf { + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } + } + if tt.eof >= dn { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } + } + + // Validate server message. + err := conn.serverInit() + if tt.eof < dn && err == nil { + t.Fatalf("serverInit() expected error") + } + if tt.eof < dn { + // The protocol was incomplete; no point in checking values. + continue + } + if err != nil { + t.Fatalf("serverInit() error %v", err) + } + if conn.FrameBufferWidth != tt.fbWidth { + t.Errorf("serverInit() FrameBufferWidth: got = %v, want = %v", conn.FrameBufferWidth, tt.fbWidth) + } + if conn.FrameBufferHeight != tt.fbHeight { + t.Errorf("serverInit() FrameBufferHeight: got = %v, want = %v", conn.FrameBufferHeight, tt.fbHeight) + } + // TODO(kward): add test for PixelFormat. + if conn.DesktopName != tt.desktopName { + t.Errorf("serverInit() DesktopName: got = %v, want = %v", conn.DesktopName, tt.desktopName) + } + } +} + +// MockConn implements the net.Conn interface. +type MockConn struct { + b bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (int, error) { + return m.b.Read(b) +} +func (m *MockConn) Write(b []byte) (int, error) { + return m.b.Write(b) +} +func (m *MockConn) Close() error { return nil } +func (m *MockConn) LocalAddr() net.Addr { return nil } +func (m *MockConn) RemoteAddr() net.Addr { return nil } +func (m *MockConn) SetDeadline(t time.Time) error { return nil } +func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } + +// Implement additional buffer.Buffer functions. +func (m *MockConn) Reset() { + m.b.Reset() +}