@@ -7,6 +7,7 @@ package tunnel
77import (
88 "context"
99 "crypto/tls"
10+ "crypto/x509"
1011 "encoding/json"
1112 "errors"
1213 "fmt"
@@ -25,6 +26,20 @@ import (
2526 "github.com/inconshreveable/go-vhost"
2627)
2728
29+ // A set of listeners to manage subscribers
30+ type SubscriptionListener interface {
31+ // Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
32+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
33+ // verified chain, else the presented peer certificate.
34+ CanSubscribe (id id.ID , chain []* x509.Certificate ) bool
35+ // Invoked when the client has been subscribed.
36+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
37+ // verified chain, else the presented peer certificate.
38+ Subscribed (id id.ID , tlsConn * tls.Conn , chain []* x509.Certificate )
39+ // Invoked before the client is unsubscribed.
40+ Unsubscribed (id id.ID )
41+ }
42+
2843// ServerConfig defines configuration for the Server.
2944type ServerConfig struct {
3045 // Addr is TCP address to listen for client connections. If empty ":0" is used.
@@ -43,6 +58,8 @@ type ServerConfig struct {
4358 KeepAlive connection.KeepAliveConfig
4459 // How long should a disconnected message been hold before sending it to the log
4560 Debounce Debounced
61+ // Optional listener to manage subscribers
62+ SubscriptionListener SubscriptionListener
4663}
4764
4865// Server is responsible for proxying public connections to the client over a
@@ -274,6 +291,7 @@ func (s *Server) handleClient(conn net.Conn) {
274291 ok bool
275292
276293 inConnPool bool
294+ certs []* x509.Certificate
277295
278296 remainingIDs []id.ID
279297 found bool
@@ -301,14 +319,26 @@ func (s *Server) handleClient(conn net.Conn) {
301319
302320 logger = logger .With ("identifier" , identifier )
303321
322+ certs = tlsConn .ConnectionState ().PeerCertificates
323+ if tlsConn .ConnectionState ().VerifiedChains != nil && len (tlsConn .ConnectionState ().VerifiedChains ) > 0 {
324+ certs = tlsConn .ConnectionState ().VerifiedChains [0 ]
325+ }
304326 if s .config .AutoSubscribe {
305327 s .Subscribe (identifier )
328+ if s .config .SubscriptionListener != nil {
329+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
330+ }
306331 } else if ! s .IsSubscribed (identifier ) {
307- logger .Log (
308- "level" , 2 ,
309- "msg" , "unknown client" ,
310- )
311- goto reject
332+ if s .config .SubscriptionListener != nil && s .config .SubscriptionListener .CanSubscribe (identifier , certs ) {
333+ s .Subscribe (identifier )
334+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
335+ } else {
336+ logger .Log (
337+ "level" , 2 ,
338+ "msg" , "unknown client" ,
339+ )
340+ goto reject
341+ }
312342 }
313343
314344 if err = conn .SetDeadline (time.Time {}); err != nil {
@@ -555,9 +585,12 @@ rollback:
555585 return err
556586}
557587
558- // Disconnect removes client from registry, disconnects client if already
588+ // Unsubscribe removes client from registry, disconnects client if already
559589// connected and returns it's RegistryItem.
560- func (s * Server ) Disconnect (identifier id.ID ) * RegistryItem {
590+ func (s * Server ) Unsubscribe (identifier id.ID ) * RegistryItem {
591+ if s .config .SubscriptionListener != nil {
592+ s .config .SubscriptionListener .Unsubscribed (identifier )
593+ }
561594 s .connPool .DeleteConn (identifier )
562595 return s .registry .Unsubscribe (identifier )
563596}
@@ -639,6 +672,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
639672 }
640673}
641674
675+ func (s * Server ) Upgrade (identifier id.ID , conn net.Conn , requestBytes []byte ) error {
676+
677+ var err error
678+
679+ msg := & proto.ControlMessage {
680+ Action : proto .ActionProxy ,
681+ ForwardedProto : "https" ,
682+ }
683+
684+ tlsConn , ok := conn .(* tls.Conn )
685+ if ok {
686+ msg .ForwardedHost = tlsConn .ConnectionState ().ServerName
687+ err = s .config .KeepAlive .Set (tlsConn .NetConn ())
688+
689+ } else {
690+ msg .ForwardedHost = conn .RemoteAddr ().String ()
691+ err = s .config .KeepAlive .Set (conn )
692+ }
693+
694+ if err != nil {
695+ s .logger .Log (
696+ "level" , 1 ,
697+ "msg" , "TCP keepalive for tunneled connection failed" ,
698+ "identifier" , identifier ,
699+ "ctrlMsg" , msg ,
700+ "err" , err ,
701+ )
702+ }
703+
704+ go func () {
705+ if err := s .proxyConnUpgraded (identifier , conn , msg , requestBytes ); err != nil {
706+ s .logger .Log (
707+ "level" , 0 ,
708+ "msg" , "proxy error" ,
709+ "identifier" , identifier ,
710+ "ctrlMsg" , msg ,
711+ "err" , err ,
712+ )
713+ }
714+ }()
715+
716+ return nil
717+ }
718+
642719// ServeHTTP proxies http connection to the client.
643720func (s * Server ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
644721 resp , err := s .RoundTrip (r )
@@ -724,6 +801,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
724801 return s .proxyHTTP (identifier , outr , msg )
725802}
726803
804+ func (s * Server ) proxyConnUpgraded (identifier id.ID , conn net.Conn , msg * proto.ControlMessage , requestBytes []byte ) error {
805+ s .logger .Log (
806+ "level" , 2 ,
807+ "action" , "proxy conn" ,
808+ "identifier" , identifier ,
809+ "ctrlMsg" , msg ,
810+ )
811+
812+ defer conn .Close ()
813+
814+ pr , pw := io .Pipe ()
815+ defer pr .Close ()
816+ defer pw .Close ()
817+
818+ continueChan := make (chan int )
819+
820+ go func () {
821+ pw .Write (requestBytes )
822+ continueChan <- 1
823+ }()
824+
825+ req , err := s .connectRequest (identifier , msg , pr )
826+ if err != nil {
827+ return err
828+ }
829+
830+ ctx , cancel := context .WithCancel (context .Background ())
831+ req = req .WithContext (ctx )
832+
833+ done := make (chan struct {})
834+ go func () {
835+ <- continueChan
836+ transfer (pw , conn , log .NewContext (s .logger ).With (
837+ "dir" , "user to client" ,
838+ "dst" , identifier ,
839+ "src" , conn .RemoteAddr (),
840+ ))
841+ cancel ()
842+ close (done )
843+ }()
844+
845+ resp , err := s .httpClient .Do (req )
846+ if err != nil {
847+ return fmt .Errorf ("io error: %s" , err )
848+ }
849+ defer resp .Body .Close ()
850+
851+ transfer (conn , resp .Body , log .NewContext (s .logger ).With (
852+ "dir" , "client to user" ,
853+ "dst" , conn .RemoteAddr (),
854+ "src" , identifier ,
855+ ))
856+
857+ select {
858+ case <- done :
859+ case <- time .After (DefaultTimeout ):
860+ }
861+
862+ s .logger .Log (
863+ "level" , 2 ,
864+ "action" , "proxy conn done" ,
865+ "identifier" , identifier ,
866+ "ctrlMsg" , msg ,
867+ )
868+
869+ return nil
870+ }
871+
727872func (s * Server ) proxyConn (identifier id.ID , conn net.Conn , msg * proto.ControlMessage ) error {
728873 s .logger .Log (
729874 "level" , 2 ,
0 commit comments