Skip to content

Commit e7c0c67

Browse files
committed
chore(guard): move token routing out off inner router for path-based routing
1 parent cc21086 commit e7c0c67

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

engine/packages/guard/src/routing/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub(crate) const X_RIVET_TOKEN: HeaderName = HeaderName::from_static("x-rivet-to
1616
pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
1717
HeaderName::from_static("sec-websocket-protocol");
1818
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";
19+
pub(crate) const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
20+
pub(crate) const WS_PROTOCOL_TOKEN: &str = "rivet_token.";
1921

2022
#[derive(Debug, Clone)]
2123
pub struct ActorPathInfo {

engine/packages/guard/src/routing/pegboard_gateway.rs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,28 @@ use gas::prelude::*;
55
use hyper::header::HeaderName;
66
use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout};
77

8-
use super::SEC_WEBSOCKET_PROTOCOL;
8+
use super::{SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, WS_PROTOCOL_TOKEN, X_RIVET_TOKEN};
99
use crate::{errors, shared_state::SharedState};
1010

1111
const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
1212
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
1313
pub const X_RIVET_AMESPACE: HeaderName = HeaderName::from_static("x-rivet-namespace");
14-
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
15-
const WS_PROTOCOL_TOKEN: &str = "rivet_token.";
1614

1715
/// Route requests to actor services using path-based routing
1816
#[tracing::instrument(skip_all)]
1917
pub async fn route_request_path_based(
2018
ctx: &StandaloneCtx,
2119
shared_state: &SharedState,
2220
actor_id_str: &str,
23-
_token: Option<&str>,
21+
token: Option<&str>,
2422
path: &str,
2523
_headers: &hyper::HeaderMap,
2624
_is_websocket: bool,
2725
) -> Result<Option<RoutingOutput>> {
28-
// NOTE: Token validation implemented in EE
29-
3026
// Parse actor ID
3127
let actor_id = Id::parse(actor_id_str).context("invalid actor id in path")?;
3228

33-
route_request_inner(ctx, shared_state, actor_id, path).await
29+
route_request_inner(ctx, shared_state, actor_id, path, token).await
3430
}
3531

3632
/// Route requests to actor services based on headers
@@ -49,28 +45,39 @@ pub async fn route_request(
4945
return Ok(None);
5046
}
5147

52-
// Extract actor ID from WebSocket protocol or HTTP header
53-
let actor_id_str = if is_websocket {
48+
// Extract actor ID and token from WebSocket protocol or HTTP headers
49+
let (actor_id_str, token) = if is_websocket {
5450
// For WebSocket, parse the sec-websocket-protocol header
55-
headers
51+
let protocols_header = headers
5652
.get(SEC_WEBSOCKET_PROTOCOL)
5753
.and_then(|protocols| protocols.to_str().ok())
58-
.and_then(|protocols| {
59-
// Parse protocols to find actor.{id}
60-
protocols
61-
.split(',')
62-
.map(|p| p.trim())
63-
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
64-
})
54+
.ok_or_else(|| {
55+
crate::errors::MissingHeader {
56+
header: "sec-websocket-protocol".to_string(),
57+
}
58+
.build()
59+
})?;
60+
61+
let protocols: Vec<&str> = protocols_header.split(',').map(|p| p.trim()).collect();
62+
63+
let actor_id = protocols
64+
.iter()
65+
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
6566
.ok_or_else(|| {
6667
crate::errors::MissingHeader {
6768
header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(),
6869
}
6970
.build()
70-
})?
71+
})?;
72+
73+
let token = protocols
74+
.iter()
75+
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN));
76+
77+
(actor_id, token)
7178
} else {
72-
// For HTTP, use the x-rivet-actor header
73-
headers
79+
// For HTTP, use headers
80+
let actor_id = headers
7481
.get(X_RIVET_ACTOR)
7582
.map(|x| x.to_str())
7683
.transpose()
@@ -80,21 +87,32 @@ pub async fn route_request(
8087
header: X_RIVET_ACTOR.to_string(),
8188
}
8289
.build()
83-
})?
90+
})?;
91+
92+
let token = headers
93+
.get(X_RIVET_TOKEN)
94+
.map(|x| x.to_str())
95+
.transpose()
96+
.context("invalid x-rivet-token header")?;
97+
98+
(actor_id, token)
8499
};
85100

86101
// Find actor to route to
87102
let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?;
88103

89-
route_request_inner(ctx, shared_state, actor_id, path).await
104+
route_request_inner(ctx, shared_state, actor_id, path, token).await
90105
}
91106

92107
async fn route_request_inner(
93108
ctx: &StandaloneCtx,
94109
shared_state: &SharedState,
95110
actor_id: Id,
96111
path: &str,
112+
_token: Option<&str>,
97113
) -> Result<Option<RoutingOutput>> {
114+
// NOTE: Token validation implemented in EE
115+
98116
// Route to peer dc where the actor lives
99117
if actor_id.label() != ctx.config().dc_label() {
100118
tracing::debug!(peer_dc_label=?actor_id.label(), "re-routing actor to peer dc");

0 commit comments

Comments
 (0)