@@ -5,32 +5,28 @@ use gas::prelude::*;
55use hyper:: header:: HeaderName ;
66use rivet_guard_core:: proxy_service:: { RouteConfig , RouteTarget , RoutingOutput , RoutingTimeout } ;
77
8- use super :: SEC_WEBSOCKET_PROTOCOL ;
8+ use super :: { SEC_WEBSOCKET_PROTOCOL , WS_PROTOCOL_ACTOR , WS_PROTOCOL_TOKEN , X_RIVET_TOKEN } ;
99use crate :: { errors, shared_state:: SharedState } ;
1010
1111const ACTOR_READY_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
1212pub const X_RIVET_ACTOR : HeaderName = HeaderName :: from_static ( "x-rivet-actor" ) ;
1313pub const X_RIVET_AMESPACE : HeaderName = HeaderName :: from_static ( "x-rivet-namespace" ) ;
14- const WS_PROTOCOL_ACTOR : & str = "rivet_actor." ;
15- const WS_PROTOCOL_TOKEN : & str = "rivet_token." ;
1614
1715/// Route requests to actor services using path-based routing
1816#[ tracing:: instrument( skip_all) ]
1917pub async fn route_request_path_based (
2018 ctx : & StandaloneCtx ,
2119 shared_state : & SharedState ,
2220 actor_id_str : & str ,
23- _token : Option < & str > ,
21+ token : Option < & str > ,
2422 path : & str ,
2523 _headers : & hyper:: HeaderMap ,
2624 _is_websocket : bool ,
2725) -> Result < Option < RoutingOutput > > {
28- // NOTE: Token validation implemented in EE
29-
3026 // Parse actor ID
3127 let actor_id = Id :: parse ( actor_id_str) . context ( "invalid actor id in path" ) ?;
3228
33- route_request_inner ( ctx, shared_state, actor_id, path) . await
29+ route_request_inner ( ctx, shared_state, actor_id, path, token ) . await
3430}
3531
3632/// Route requests to actor services based on headers
@@ -49,28 +45,39 @@ pub async fn route_request(
4945 return Ok ( None ) ;
5046 }
5147
52- // Extract actor ID from WebSocket protocol or HTTP header
53- let actor_id_str = if is_websocket {
48+ // Extract actor ID and token from WebSocket protocol or HTTP headers
49+ let ( actor_id_str, token ) = if is_websocket {
5450 // For WebSocket, parse the sec-websocket-protocol header
55- headers
51+ let protocols_header = headers
5652 . get ( SEC_WEBSOCKET_PROTOCOL )
5753 . and_then ( |protocols| protocols. to_str ( ) . ok ( ) )
58- . and_then ( |protocols| {
59- // Parse protocols to find actor.{id}
60- protocols
61- . split ( ',' )
62- . map ( |p| p. trim ( ) )
63- . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_ACTOR ) )
64- } )
54+ . ok_or_else ( || {
55+ crate :: errors:: MissingHeader {
56+ header : "sec-websocket-protocol" . to_string ( ) ,
57+ }
58+ . build ( )
59+ } ) ?;
60+
61+ let protocols: Vec < & str > = protocols_header. split ( ',' ) . map ( |p| p. trim ( ) ) . collect ( ) ;
62+
63+ let actor_id = protocols
64+ . iter ( )
65+ . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_ACTOR ) )
6566 . ok_or_else ( || {
6667 crate :: errors:: MissingHeader {
6768 header : "`rivet_actor.*` protocol in sec-websocket-protocol" . to_string ( ) ,
6869 }
6970 . build ( )
70- } ) ?
71+ } ) ?;
72+
73+ let token = protocols
74+ . iter ( )
75+ . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_TOKEN ) ) ;
76+
77+ ( actor_id, token)
7178 } else {
72- // For HTTP, use the x-rivet-actor header
73- headers
79+ // For HTTP, use headers
80+ let actor_id = headers
7481 . get ( X_RIVET_ACTOR )
7582 . map ( |x| x. to_str ( ) )
7683 . transpose ( )
@@ -80,21 +87,32 @@ pub async fn route_request(
8087 header : X_RIVET_ACTOR . to_string ( ) ,
8188 }
8289 . build ( )
83- } ) ?
90+ } ) ?;
91+
92+ let token = headers
93+ . get ( X_RIVET_TOKEN )
94+ . map ( |x| x. to_str ( ) )
95+ . transpose ( )
96+ . context ( "invalid x-rivet-token header" ) ?;
97+
98+ ( actor_id, token)
8499 } ;
85100
86101 // Find actor to route to
87102 let actor_id = Id :: parse ( actor_id_str) . context ( "invalid x-rivet-actor header" ) ?;
88103
89- route_request_inner ( ctx, shared_state, actor_id, path) . await
104+ route_request_inner ( ctx, shared_state, actor_id, path, token ) . await
90105}
91106
92107async fn route_request_inner (
93108 ctx : & StandaloneCtx ,
94109 shared_state : & SharedState ,
95110 actor_id : Id ,
96111 path : & str ,
112+ _token : Option < & str > ,
97113) -> Result < Option < RoutingOutput > > {
114+ // NOTE: Token validation implemented in EE
115+
98116 // Route to peer dc where the actor lives
99117 if actor_id. label ( ) != ctx. config ( ) . dc_label ( ) {
100118 tracing:: debug!( peer_dc_label=?actor_id. label( ) , "re-routing actor to peer dc" ) ;
0 commit comments