diff --git a/client.go b/client.go index 91402c9..e628bd6 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,9 @@ -// Package vnc implements a VNC client. -// -// References: -// [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +/* +Package vnc implements a VNC client. + +References: + [PROTOCOL]: http://tools.ietf.org/html/rfc6143 +*/ package vnc import ( @@ -13,9 +15,11 @@ import ( "unicode" ) +// The ClientConn type holds client connection information. type ClientConn struct { - c net.Conn - config *ClientConfig + c net.Conn + config *ClientConfig + protocolVersion string // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since @@ -48,6 +52,9 @@ type ClientConfig struct { // suitable by the server will be used to authenticate. Auth []ClientAuth + // Password for servers that require authentication. + Password string + // Exclusive determines whether the connection is shared with other // clients. If true, then all other clients connected will be // disconnected when a connection is established to the VNC server. @@ -66,13 +73,36 @@ type ClientConfig struct { ServerMessages []ServerMessage } +func NewClientConfig(p string) *ClientConfig { + return &ClientConfig{ + Auth: []ClientAuth{&ClientAuthNone{}, &ClientAuthVNC{p}}, + Password: p, + } +} + func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn := &ClientConn{ c: c, config: cfg, } - if err := conn.handshake(); err != nil { + if err := conn.protocolVersionHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.securityResultHandshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.clientInit(); err != nil { + conn.Close() + return nil, err + } + if err := conn.serverInit(); err != nil { conn.Close() return nil, err } @@ -152,7 +182,6 @@ func (c *ClientConn) FramebufferUpdateRequest(incremental bool, x, y, width, hei return err } } - if _, err := c.c.Write(buf.Bytes()[0:10]); err != nil { return err } @@ -160,7 +189,7 @@ func (c *ClientConn) FramebufferUpdateRequest(incremental bool, x, y, width, hei return nil } -// KeyEvent indiciates a key press or release and sends it to the server. +// KeyEvent indicates a key press or release and sends it to the server. // The key is indicated using the X Window System "keysym" value. Use // Google to find a reference of these values. To simulate a key press, // you must send a key with both a down event, and a non-down event. @@ -299,119 +328,194 @@ func parseProtocolVersion(pv []byte) (uint, uint, error) { return major, minor, nil } -func (c *ClientConn) handshake() error { +const ( + // Client ProtocolVersions. + PROTO_VERS_UNSUP = "UNSUPPORTED" + PROTO_VERS_3_3 = "RFB 003.003\n" + PROTO_VERS_3_8 = "RFB 003.008\n" +) + +// protocolVersionHandshake implements §7.1.1 ProtocolVersion Handshake. +func (c *ClientConn) protocolVersionHandshake() error { var protocolVersion [pvLen]byte - // 7.1.1, read the ProtocolVersion message sent by the server. + // Read the ProtocolVersion message sent by the server. if _, err := io.ReadFull(c.c, protocolVersion[:]); err != nil { return err } - maxMajor, maxMinor, err := parseProtocolVersion(protocolVersion[:]) + major, minor, err := parseProtocolVersion(protocolVersion[:]) if err != nil { return err } - if maxMajor < 3 { - return fmt.Errorf("unsupported major version, less than 3: %d", maxMajor) + pv := PROTO_VERS_UNSUP + if major == 3 { + if minor >= 8 { + pv = PROTO_VERS_3_8 + } else if minor >= 3 { + pv = PROTO_VERS_3_3 + } } - if maxMinor < 8 { - return fmt.Errorf("unsupported minor version, less than 8: %d", maxMinor) + if pv == PROTO_VERS_UNSUP { + return NewVNCError(fmt.Sprintf("ProtocolVersion handshake failed; unsupported version '%v'", string(protocolVersion[:]))) } + c.protocolVersion = pv // Respond with the version we will support - if _, err = c.c.Write([]byte("RFB 003.008\n")); err != nil { + if _, err = c.c.Write([]byte(pv)); err != nil { return err } - // 7.1.2 Security Handshake from server - var numSecurityTypes uint8 - if err = binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { - return err + return nil +} + +// securityHandshake implements §7.1.2 Security Handshake. +func (c *ClientConn) securityHandshake() error { + + switch c.protocolVersion { + case PROTO_VERS_3_3: + err := c.securityHandshake33() + if err != nil { + return err + } + case PROTO_VERS_3_8: + err := c.securityHandshake38() + if err != nil { + return err + } + default: + return NewVNCError(fmt.Sprintf("Security handshake failed; unsupported protocol")) } + return nil +} - if numSecurityTypes == 0 { - return fmt.Errorf("no security types: %s", c.readErrorReason()) +func (c *ClientConn) securityHandshake33() error { + var secType uint32 + if err := binary.Read(c.c, binary.BigEndian, &secType); err != nil { + return err } - securityTypes := make([]uint8, numSecurityTypes) - if err = binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { + var auth ClientAuth + switch secType { + case secTypeInvalid: // Connection failed. + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("Security handshake failed; connection failed: %s", reason)) + case secTypeNone: + auth = &ClientAuthNone{} + case secTypeVNCAuth: + auth = &ClientAuthVNC{c.config.Password} + default: + return NewVNCError(fmt.Sprintf("Security handshake failed; invalid security type: %v", secType)) + } + if err := auth.Handshake(c.c); err != nil { return err } - clientSecurityTypes := c.config.Auth - if clientSecurityTypes == nil { - clientSecurityTypes = []ClientAuth{new(ClientAuthNone)} + return nil +} + +func (c *ClientConn) securityHandshake38() error { + // Determine server supported security types. + var numSecurityTypes uint8 + if err := binary.Read(c.c, binary.BigEndian, &numSecurityTypes); err != nil { + return err + } + if numSecurityTypes == 0 { + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("Security handshake failed; no security types: %v", reason)) + } + securityTypes := make([]uint8, numSecurityTypes) + if err := binary.Read(c.c, binary.BigEndian, &securityTypes); err != nil { + return err } + // Choose client security type. + // TODO(kward): try "better" security types first. var auth ClientAuth FindAuth: - for _, curAuth := range clientSecurityTypes { - for _, securityType := range securityTypes { - if curAuth.SecurityType() == securityType { - // We use the first matching supported authentication - auth = curAuth + for _, securityType := range securityTypes { + for _, a := range c.config.Auth { + if a.SecurityType() == securityType { + // We use the first matching supported authentication. + auth = a break FindAuth } } } - if auth == nil { - return fmt.Errorf("no suitable auth schemes found. server supported: %#v", securityTypes) + return NewVNCError(fmt.Sprintf("Security handshake failed; no suitable auth schemes found; server supports: %#v", securityTypes)) } - - // Respond back with the security type we'll use - if err = binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { + if err := binary.Write(c.c, binary.BigEndian, auth.SecurityType()); err != nil { return err } - if err = auth.Handshake(c.c); err != nil { + if err := auth.Handshake(c.c); err != nil { return err } + return nil +} + +// securityResultHandshake implements §7.1.3 SecurityResult Handshake. +func (c *ClientConn) securityResultHandshake() error { + if c.protocolVersion == PROTO_VERS_3_3 { + // Not required for 3.3. + return nil + } - // 7.1.3 SecurityResult Handshake var securityResult uint32 - if err = binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &securityResult); err != nil { return err } - if securityResult == 1 { - return fmt.Errorf("security handshake failed: %s", c.readErrorReason()) + reason, err := c.readErrorReason() + if err != nil { + return err + } + return NewVNCError(fmt.Sprintf("SecurityResult handshake failed: %s", reason)) } + return nil +} - // 7.3.1 ClientInit - var sharedFlag uint8 = 1 - if c.config.Exclusive { - sharedFlag = 0 - } +// clientInit implements §7.3.1 ClientInit. +func (c *ClientConn) clientInit() error { + var sharedFlag uint8 - if err = binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { + if !c.config.Exclusive { + sharedFlag = 1 + } + if err := binary.Write(c.c, binary.BigEndian, sharedFlag); err != nil { return err } - // 7.3.2 ServerInit - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { + return nil +} + +// serverInit implements §7.3.2 ServerInit. +func (c *ClientConn) serverInit() error { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferWidth); err != nil { return err } - - if err = binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &c.FrameBufferHeight); err != nil { return err } - - // Read the pixel format - if err = readPixelFormat(c.c, &c.PixelFormat); err != nil { + if err := readPixelFormat(c.c, &c.PixelFormat); err != nil { return err } var nameLength uint32 - if err = binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameLength); err != nil { return err } - nameBytes := make([]uint8, nameLength) - if err = binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { + if err := binary.Read(c.c, binary.BigEndian, &nameBytes); err != nil { return err } - c.DesktopName = string(nameBytes) return nil @@ -467,16 +571,31 @@ func (c *ClientConn) mainLoop() { } } -func (c *ClientConn) readErrorReason() string { +// TODO(kward): need a context for timeout +func (c *ClientConn) readErrorReason() (string, error) { var reasonLen uint32 if err := binary.Read(c.c, binary.BigEndian, &reasonLen); err != nil { - return "" + return "", err } reason := make([]uint8, reasonLen) if err := binary.Read(c.c, binary.BigEndian, &reason); err != nil { - return "" + return "", err } - return string(reason) + return string(reason), nil +} + +// VNCError implements error interface. +type VNCError struct { + s string +} + +// NewVNCError returns a custom VNCError error. +func NewVNCError(s string) error { + return &VNCError{s} +} + +func (e VNCError) Error() string { + return e.s } diff --git a/client_auth.go b/client_auth.go index c88f911..c6dd037 100644 --- a/client_auth.go +++ b/client_auth.go @@ -4,6 +4,12 @@ import ( "net" ) +const ( + secTypeInvalid = iota + secTypeNone + secTypeVNCAuth +) + // A ClientAuth implements a method of authenticating with a remote server. type ClientAuth interface { // SecurityType returns the byte identifier sent by the server to @@ -16,10 +22,10 @@ type ClientAuth interface { } // ClientAuthNone is the "none" authentication. See 7.1.2 -type ClientAuthNone byte +type ClientAuthNone struct{} func (*ClientAuthNone) SecurityType() uint8 { - return 1 + return secTypeNone } func (*ClientAuthNone) Handshake(net.Conn) error { diff --git a/client_auth_test.go b/client_auth_test.go index 746b46f..f216db1 100644 --- a/client_auth_test.go +++ b/client_auth_test.go @@ -1,6 +1,10 @@ package vnc -import "testing" +import ( + "encoding/hex" + "strings" + "testing" +) func TestClientAuthNone_Impl(t *testing.T) { var raw interface{} @@ -9,3 +13,48 @@ func TestClientAuthNone_Impl(t *testing.T) { t.Fatal("ClientAuthNone doesn't implement ClientAuth") } } + +func TestClientAuthVNC_Impl(t *testing.T) { + var raw interface{} + raw = new(ClientAuthVNC) + if _, ok := raw.(ClientAuth); !ok { + t.Fatal("ClientAuthVNC doesn't implement ClientAuth") + } +} + +// wiresharkToChallenge converts VNC authentication challenge and response +// values captured with Wireshark (https://www.wireshark.org) into usable byte +// streams. +func wiresharkToChallenge(h string) [vncAuthChallengeSize]byte { + var c [vncAuthChallengeSize]byte + r := strings.NewReplacer(":", "") + b, err := hex.DecodeString(r.Replace(h)) + if err != nil { + return c + } + copy(c[:], b) + return c +} + +func TestClientAuthVNCEncode(t *testing.T) { + tests := []struct { + pw string + challenge, response string + }{ + {".", "7f:e2:e1:3d:a4:ae:10:9c:54:c5:5f:52:74:aa:db:31", "1d:86:92:71:1f:00:24:35:02:d3:91:ef:e9:bc:c5:d5"}, + {"12345678", "13:8e:a4:2e:0e:66:f3:ad:2d:f3:08:c3:04:cd:c4:2a", "5b:e1:56:fa:49:49:ef:56:d3:f8:44:97:73:27:95:9f"}, + {"abc123", "c6:30:45:d2:57:9e:e7:f2:f9:0c:62:3e:52:40:86:c6", "a3:63:59:e4:28:c8:7f:b3:45:2c:d7:e0:ca:d6:70:3e"}, + } + + for _, tt := range tests { + challenge := wiresharkToChallenge(tt.challenge) + a := ClientAuthVNC{tt.pw} + if err := a.encode(&challenge); err != nil { + t.Errorf("ClientAuthVNC.encode() failed: key=%v, err=%v", tt.pw, err) + } + response := wiresharkToChallenge(tt.response) + if challenge != response { + t.Errorf("ClientAuthVNC.encode() failed: key=%v got=%v, want=%v", tt.pw, challenge, response) + } + } +} diff --git a/client_auth_vnc.go b/client_auth_vnc.go new file mode 100644 index 0000000..9d7495f --- /dev/null +++ b/client_auth_vnc.go @@ -0,0 +1,72 @@ +/* +ClientAuthVNC implements the ClientAuth interface to provide support for +VNC Authentication. + +See http://tools.ietf.org/html/rfc6143#section-7.2.2 for more info. +*/ +package vnc + +import ( + "crypto/des" + "encoding/binary" + "net" +) + +// ClientAuthVNC is the standard password authentication +type ClientAuthVNC struct { + Password string +} + +func (*ClientAuthVNC) SecurityType() uint8 { + return secTypeVNCAuth +} + +// 7.2.2. VNC Authentication uses a 16-byte challenge. +const vncAuthChallengeSize = 16 + +func (auth *ClientAuthVNC) Handshake(conn net.Conn) error { + + if auth.Password == "" { + return NewVNCError("securityHandshake: handshake failed; no password provided for VNCAuth.") + } + + // Read challenge block + var challenge [vncAuthChallengeSize]byte + if err := binary.Read(conn, binary.BigEndian, &challenge); err != nil { + return err + } + + auth.encode(&challenge) + + // Send the encrypted challenge back to server + if err := binary.Write(conn, binary.BigEndian, challenge); err != nil { + return err + } + + return nil +} + +func (auth *ClientAuthVNC) encode(c *[vncAuthChallengeSize]byte) error { + // Copy password string to 8 byte 0-padded slice + key := make([]byte, 8) + copy(key, auth.Password) + + // Each byte of the password needs to be reversed. This is a + // non RFC-documented behaviour of VNC clients and servers + for i := range key { + key[i] = (key[i]&0x55)<<1 | (key[i]&0xAA)>>1 // Swap adjacent bits + key[i] = (key[i]&0x33)<<2 | (key[i]&0xCC)>>2 // Swap adjacent pairs + key[i] = (key[i]&0x0F)<<4 | (key[i]&0xF0)>>4 // Swap the 2 halves + } + + // Encrypt challenge with key. + cipher, err := des.NewCipher(key) + if err != nil { + return err + } + for i := 0; i < vncAuthChallengeSize; i += cipher.BlockSize() { + cipher.Encrypt(c[i:i+cipher.BlockSize()], c[i:i+cipher.BlockSize()]) + } + + return nil +} diff --git a/client_test.go b/client_test.go index 31591b4..6423542 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,14 @@ package vnc import ( + "bytes" + "encoding/binary" "fmt" + "io" "net" + "reflect" "testing" + "time" ) func newMockServer(t *testing.T, version string) string { @@ -39,14 +44,15 @@ func TestClient_LowMajorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported major version, less than 3: 2" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } func TestClient_LowMinorVersion(t *testing.T) { - nc, err := net.Dial("tcp", newMockServer(t, "003.007")) + nc, err := net.Dial("tcp", newMockServer(t, "003.002")) if err != nil { t.Fatalf("error connecting to mock server: %s", err) } @@ -55,9 +61,10 @@ func TestClient_LowMinorVersion(t *testing.T) { if err == nil { t.Fatal("error expected") } - - if err.Error() != "unsupported minor version, less than 8: 7" { - t.Fatalf("unexpected error: %s", err) + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("Client() unexpected %v error: %v", reflect.TypeOf(err), verr) + } } } @@ -82,6 +89,7 @@ func TestParseProtocolVersion(t *testing.T) { if err != nil && !tt.isErr { t.Fatalf("parseProtocolVersion(%v) unexpected error %v", tt.proto, err) } + // TODO(kward): validate VNCError thrown. if err == nil && tt.isErr { t.Fatalf("parseProtocolVersion(%v) expected error", tt.proto) } @@ -93,3 +101,397 @@ func TestParseProtocolVersion(t *testing.T) { } } } + +func TestProtocolVersionHandshake(t *testing.T) { + tests := []struct { + server string + client string + ok bool + }{ + // Supported versions. + {"RFB 003.003\n", "RFB 003.003\n", true}, + {"RFB 003.006\n", "RFB 003.003\n", true}, + {"RFB 003.008\n", "RFB 003.008\n", true}, + {"RFB 003.389\n", "RFB 003.008\n", true}, + // Unsupported versions. + {server: "RFB 002.009\n", ok: false}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + + // Validate server message handling. + err := conn.protocolVersionHandshake() + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() expected error for server protocol version %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("protocolVersionHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + + // Validate client response. + var client [pvLen]byte + err = binary.Read(conn.c, binary.BigEndian, &client) + if err == nil && !tt.ok { + t.Fatalf("protocolVersionHandshake() unexpected error: %v", err) + } + if string(client[:]) != tt.client && tt.ok { + t.Errorf("protocolVersionHandshake() client version: got = %v, want = %v", string(client[:]), tt.client) + } + } +} + +func writeVNCAuthChallenge(w io.Writer) error { + var c [vncAuthChallengeSize]uint8 + for i := 0; i < vncAuthChallengeSize; i++ { + c[i] = uint8(i) + } + if err := binary.Write(w, binary.BigEndian, c); err != nil { + return err + } + return nil +} + +func readVNCAuthChallenge(r io.Reader) error { + var c [vncAuthChallengeSize]uint8 + if err := binary.Read(r, binary.BigEndian, &c); err != nil { + return fmt.Errorf("error reading back VNCAuth challenge") + } + return nil +} + +func TestSecurityHandshake33(t *testing.T) { + tests := []struct { + server uint32 + ok bool + reason string + }{ + //-- Supported security types. -- + // Server supports None. + {secTypeNone, true, ""}, + // Server supports VNCAuth. + {secTypeVNCAuth, true, ""}, + //-- Unsupported security types. -- + {secTypeInvalid, false, "some reason"}, + {255, false, ""}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: NewClientConfig("."), + protocolVersion: PROTO_VERS_3_3, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.server); err != nil { + t.Fatal(err) + } + if len(tt.reason) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + } + if tt.server == secTypeVNCAuth { + if err := writeVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + + // Validate server message handling. + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + if !tt.ok { + continue + } + + // Validate client response. + if tt.server == secTypeVNCAuth { + if err := readVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + } +} + +func TestSecurityHandshake38(t *testing.T) { + tests := []struct { + server []uint8 + client []ClientAuth + secType uint8 + ok bool + reason string + }{ + //-- Supported security types. -- + // Server and client support None. + {[]uint8{secTypeNone}, []ClientAuth{&ClientAuthNone{}}, secTypeNone, true, ""}, + // Server and client support VNCAuth. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthVNC{"."}}, secTypeVNCAuth, true, ""}, + // Server and client both support VNCAuth and None. + {[]uint8{secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthVNC{"."}, &ClientAuthNone{}}, secTypeVNCAuth, true, ""}, + // Server supports unknown #255, VNCAuth and None. + {[]uint8{255, secTypeVNCAuth, secTypeNone}, []ClientAuth{&ClientAuthVNC{"."}, &ClientAuthNone{}}, secTypeVNCAuth, true, ""}, + //-- Unsupported security types. -- + // Server provided no valid security types. + {[]uint8{secTypeInvalid}, []ClientAuth{}, secTypeInvalid, false, "some reason"}, + // Client and server don't support same security types. + {[]uint8{secTypeVNCAuth}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false, ""}, + // Server supports only unknown #255. + {[]uint8{255}, []ClientAuth{&ClientAuthNone{}}, secTypeInvalid, false, ""}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + protocolVersion: PROTO_VERS_3_8, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, uint8(len(tt.server))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.server)); err != nil { + t.Fatal(err) + } + if len(tt.reason) > 0 { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + } + if tt.secType == secTypeVNCAuth { + if err := writeVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + conn.config.Auth = tt.client + + // Validate server message handling. + err := conn.securityHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityHandshake() expected error for server auth %v", tt.server) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + if !tt.ok { + continue + } + + // Validate client response. + var secType uint8 + err = binary.Read(conn.c, binary.BigEndian, &secType) + if secType != tt.secType { + t.Errorf("securityHandshake() secType: got = %v, want = %v", secType, tt.secType) + } + if tt.secType == secTypeVNCAuth { + if err := readVNCAuthChallenge(conn.c); err != nil { + t.Fatal(err) + } + } + } +} + +func TestSecurityResultHandshake(t *testing.T) { + tests := []struct { + result uint32 + ok bool + reason string + }{ + {0, true, ""}, + {1, false, "SecurityResult error"}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if err := binary.Write(conn.c, binary.BigEndian, tt.result); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.reason))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.reason)); err != nil { + t.Fatal(err) + } + + // Validate server message handling. + err := conn.securityResultHandshake() + if err == nil && !tt.ok { + t.Fatalf("securityResultHandshake() expected error for result %v", tt.result) + } + if err != nil { + if verr, ok := err.(*VNCError); !ok { + t.Errorf("securityResultHandshake() unexpected %v error: %v", reflect.TypeOf(err), verr) + } + } + } +} + +func TestClientInit(t *testing.T) { + tests := []struct { + exclusive bool + shared uint8 + }{ + {true, 0}, + {false, 1}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + conn.config.Exclusive = tt.exclusive + + // Validate client response. + err := conn.clientInit() + if err != nil { + t.Fatalf("clientInit() unexpected error %v", err) + } + var shared uint8 + err = binary.Read(conn.c, binary.BigEndian, &shared) + if shared != tt.shared { + t.Errorf("clientInit() shared: got = %v, want = %v", shared, tt.shared) + } + } +} + +func TestServerInit(t *testing.T) { + const ( + none = iota + fbw + fbh + pf + dn + ) + tests := []struct { + eof int + fbWidth, fbHeight uint16 + pixelFormat [16]byte // TODO(kward): replace with PixelFormat + desktopName string + }{ + // Valid protocol. + {dn, 100, 200, [16]byte{}, "foo"}, + // Invalid protocol (missing fields). + {eof: none}, + {eof: fbw, fbWidth: 1}, + {eof: fbh, fbWidth: 2, fbHeight: 1}, + {eof: pf, fbWidth: 3, fbHeight: 2, pixelFormat: [16]byte{}}, + } + + mockConn := &MockConn{} + conn := &ClientConn{ + c: mockConn, + config: &ClientConfig{}, + } + + for _, tt := range tests { + mockConn.Reset() + if tt.eof >= fbw { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbWidth); err != nil { + t.Fatal(err) + } + } + if tt.eof >= fbh { + if err := binary.Write(conn.c, binary.BigEndian, tt.fbHeight); err != nil { + t.Fatal(err) + } + } + if tt.eof >= pf { + if err := binary.Write(conn.c, binary.BigEndian, tt.pixelFormat); err != nil { + t.Fatal(err) + } + } + if tt.eof >= dn { + if err := binary.Write(conn.c, binary.BigEndian, uint32(len(tt.desktopName))); err != nil { + t.Fatal(err) + } + if err := binary.Write(conn.c, binary.BigEndian, []byte(tt.desktopName)); err != nil { + t.Fatal(err) + } + } + + // Validate server message handling. + err := conn.serverInit() + if tt.eof < dn && err == nil { + t.Fatalf("serverInit() expected error") + } + if tt.eof < dn { + // The protocol was incomplete; no point in checking values. + continue + } + if err != nil { + t.Fatalf("serverInit() error %v", err) + } + if conn.FrameBufferWidth != tt.fbWidth { + t.Errorf("serverInit() FrameBufferWidth: got = %v, want = %v", conn.FrameBufferWidth, tt.fbWidth) + } + if conn.FrameBufferHeight != tt.fbHeight { + t.Errorf("serverInit() FrameBufferHeight: got = %v, want = %v", conn.FrameBufferHeight, tt.fbHeight) + } + // TODO(kward): add test for PixelFormat. + if conn.DesktopName != tt.desktopName { + t.Errorf("serverInit() DesktopName: got = %v, want = %v", conn.DesktopName, tt.desktopName) + } + } +} + +// MockConn implements the net.Conn interface. +type MockConn struct { + b bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (int, error) { + return m.b.Read(b) +} +func (m *MockConn) Write(b []byte) (int, error) { + return m.b.Write(b) +} +func (m *MockConn) Close() error { return nil } +func (m *MockConn) LocalAddr() net.Addr { return nil } +func (m *MockConn) RemoteAddr() net.Addr { return nil } +func (m *MockConn) SetDeadline(t time.Time) error { return nil } +func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } + +// Implement additional buffer.Buffer functions. +func (m *MockConn) Reset() { + m.b.Reset() +}