Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async fn get_dcs(ctx: OperationCtx, labels: Vec<u16>) -> GlobalResult<Vec<Datace
FROM db_cluster.datacenters@datacenter_label_idx
WHERE label = ANY($1)
",
labels.into_iter().map(|x| x as i64).collect::<Vec<_>>(),
labels.into_iter().map(|x| x.to_be_bytes()).collect::<Vec<_>>(),
)
.await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ pub mod lz4 {
indoc!(
r#"
echo 'Downloading lz4'
curl -L https://releases.rivet.gg/tools/lz4/1.10.0/debian11-amd64/lz4 -o /usr/local/bin/lz4
curl -Lfo /usr/local/bin/lz4 https://releases.rivet.gg/tools/lz4/1.10.0/debian11-amd64/lz4
chmod +x /usr/local/bin/lz4
"#
)
Expand Down Expand Up @@ -286,7 +286,7 @@ pub mod umoci {
indoc!(
r#"
echo 'Downloading umoci'
curl -Lf -o /usr/bin/umoci "https://github.com/opencontainers/umoci/releases/download/v0.4.7/umoci.amd64"
curl -Lfo /usr/bin/umoci "https://github.com/opencontainers/umoci/releases/download/v0.4.7/umoci.amd64"
chmod +x /usr/bin/umoci
"#
).to_string()
Expand All @@ -300,7 +300,7 @@ pub mod cni {
indoc!(
r#"
echo 'Downloading cnitool'
curl -Lf -o /usr/bin/cnitool "https://github.com/rivet-gg/cni/releases/download/v1.1.2-build3/cnitool"
curl -Lfo /usr/bin/cnitool "https://github.com/rivet-gg/cni/releases/download/v1.1.2-build3/cnitool"
chmod +x /usr/bin/cnitool
"#
).to_string()
Expand Down
4 changes: 3 additions & 1 deletion packages/edge/infra/client/echo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ authors = ["Rivet Gaming, LLC <[email protected]>"]
license = "Apache-2.0"

[dependencies]
anyhow = "1.0"
bytes = "1.0"
futures-util = "0.3"
http = "0.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.40", features = ["full",] }
tokio-tungstenite = "0.23.1"
tokio-util = "0.7"
uuid = { version = "1", features = ["v4", "serde"] }
warp = "0.3.7"
159 changes: 88 additions & 71 deletions packages/edge/infra/client/echo/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::{env, net::SocketAddr, sync::Arc, time::Duration};
use std::{env, io::Cursor, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};

use anyhow::*;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use uuid::Uuid;
use tokio::{net::UnixStream, sync::Mutex};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use warp::Filter;

const PING_INTERVAL: Duration = Duration::from_secs(1);
Expand All @@ -18,20 +20,20 @@ async fn main() {
}

// Get manager connection details from env vars
let manager_ip = env::var("RIVET_MANAGER_IP").expect("RIVET_MANAGER_IP not set");
let manager_port = env::var("RIVET_MANAGER_PORT").expect("RIVET_MANAGER_PORT not set");
let manager_addr = format!("ws://{}:{}", manager_ip, manager_port);
let manager_socket_path = PathBuf::from(
env::var("RIVET_MANAGER_SOCKET_PATH").expect("RIVET_MANAGER_SOCKET_PATH not set"),
);

// Get HTTP server port from env var or use default
let http_port = env::var("PORT_MAIN")
.expect("PORT_MAIN not set")
.parse::<u16>()
.expect("bad PORT_MAIN");

// Spawn the WebSocket client
// Spawn the unix socket client
tokio::spawn(async move {
if let Err(e) = run_websocket_client(&manager_addr).await {
eprintln!("WebSocket client error: {}", e);
if let Err(e) = run_socket_client(manager_socket_path).await {
eprintln!("Socket client error: {}", e);
}
});

Expand All @@ -53,25 +55,28 @@ async fn main() {
warp::serve(echo).run(http_addr).await;
}

async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error>> {
println!("Connecting to WebSocket at {}", url);
async fn run_socket_client(socket_path: PathBuf) -> Result<()> {
println!("Connecting to socket at {}", socket_path.display());

// Connect to the WebSocket server
let (ws_stream, _) = connect_async(url).await?;
println!("WebSocket connection established");
// Connect to the socket server
let stream = UnixStream::connect(socket_path).await?;
println!("Socket connection established");

// Split the stream
let (mut write, mut read) = ws_stream.split();
let codec = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.length_field_length(4)
// No offset
.length_field_offset(0)
// Header length is not included in the length calculation
.length_adjustment(4)
// header is included in the returned bytes
.num_skip(0)
.new_codec();

let payload = json!({
"init": {
"access_token": env::var("RIVET_ACCESS_TOKEN").expect("RIVET_ACCESS_TOKEN not set"),
},
});
let framed = Framed::new(stream, codec);

let data = serde_json::to_vec(&payload)?;
write.send(Message::Binary(data)).await?;
println!("Sent init message");
// Split the stream
let (write, mut read) = framed.split();

// Ping thread
let write = Arc::new(Mutex::new(write));
Expand All @@ -80,10 +85,14 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
loop {
tokio::time::sleep(PING_INTERVAL).await;

let payload = json!({
"ping": {}
});

if write2
.lock()
.await
.send(Message::Ping(Vec::new()))
.send(encode_frame(&payload).unwrap())
.await
.is_err()
{
Expand All @@ -93,53 +102,61 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
});

// Process incoming messages
while let Some(message) = read.next().await {
match message {
Ok(msg) => match msg {
Message::Pong(_) => {}
Message::Binary(buf) => {
let packet = serde_json::from_slice::<serde_json::Value>(&buf)?;
println!("Received packet: {packet:?}");

if let Some(packet) = packet.get("start_actor") {
let payload = json!({
"actor_state_update": {
"actor_id": packet["actor_id"],
"generation": packet["generation"],
"state": {
"running": null,
},
},
});

let data = serde_json::to_vec(&payload)?;
write.lock().await.send(Message::Binary(data)).await?;
} else if let Some(packet) = packet.get("signal_actor") {
let payload = json!({
"actor_state_update": {
"actor_id": packet["actor_id"],
"generation": packet["generation"],
"state": {
"exited": {
"exit_code": null,
},
},
},
});

let data = serde_json::to_vec(&payload)?;
write.lock().await.send(Message::Binary(data)).await?;
}
}
msg => eprintln!("Unexpected message: {msg:?}"),
},
Err(e) => {
eprintln!("Error reading message: {}", e);
break;
}
while let Some(frame) = read.next().await.transpose()? {
let (_, packet) = decode_frame::<serde_json::Value>(&frame.freeze())?;
println!("Received packet: {packet:?}");

if let Some(packet) = packet.get("start_actor") {
let payload = json!({
"actor_state_update": {
"actor_id": packet["actor_id"],
"generation": packet["generation"],
"state": {
"running": null,
},
},
});

write.lock().await.send(encode_frame(&payload)?).await?;
} else if let Some(packet) = packet.get("signal_actor") {
let payload = json!({
"actor_state_update": {
"actor_id": packet["actor_id"],
"generation": packet["generation"],
"state": {
"exited": {
"exit_code": null,
},
},
},
});

write.lock().await.send(encode_frame(&payload)?).await?;
}
}

println!("WebSocket connection closed");
println!("Socket connection closed");
Ok(())
}

fn decode_frame<T: DeserializeOwned>(frame: &Bytes) -> Result<([u8; 4], T)> {
ensure!(frame.len() >= 4, "Frame too short");

// Extract the header (first 4 bytes)
let header = [frame[0], frame[1], frame[2], frame[3]];

// Deserialize the rest of the frame (payload after the header)
let payload = serde_json::from_slice(&frame[4..])?;

Ok((header, payload))
}

fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&[0u8; 4]); // header (currently unused)

let mut cursor = Cursor::new(&mut buf);
serde_json::to_writer(&mut cursor, payload)?;

Ok(buf.into())
}
Comment on lines +154 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The encode_frame function initializes a buffer with 4 zero bytes as a header, but these bytes are never populated with the actual length information. This appears inconsistent with the LengthDelimitedCodec configuration, which expects the first 4 bytes to contain the length of the payload.

Consider updating the header bytes with the appropriate length information before returning the buffer:

fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
    let mut buf = Vec::with_capacity(4);
    buf.extend_from_slice(&[0u8; 4]); // Reserve space for length header
    
    // Write payload after header
    serde_json::to_writer(&mut buf, payload)?;
    
    // Calculate payload length (excluding header) and update header bytes
    let len = (buf.len() - 4) as u32;
    buf[0..4].copy_from_slice(&len.to_be_bytes());
    
    Ok(buf.into())
}
Suggested change
fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&[0u8; 4]); // header (currently unused)
let mut cursor = Cursor::new(&mut buf);
serde_json::to_writer(&mut cursor, payload)?;
Ok(buf.into())
}
fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&[0u8; 4]); // Reserve space for length header
// Write payload after header
serde_json::to_writer(&mut buf, payload)?;
// Calculate payload length (excluding header) and update header bytes
let len = (buf.len() - 4) as u32;
buf[0..4].copy_from_slice(&len.to_be_bytes());
Ok(buf.into())
}

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

Comment on lines +154 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of encode_frame initializes a buffer with 4 bytes for the header, but when writing the JSON payload with serde_json::to_writer, the cursor starts at position 0, which would overwrite the header bytes. To ensure the JSON payload is written after the header, consider setting the cursor position explicitly:

fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
    let mut buf = Vec::with_capacity(4);
    buf.extend_from_slice(&[0u8; 4]); // header (currently unused)

    let mut cursor = Cursor::new(&mut buf);
    cursor.set_position(4); // Position cursor after header
    serde_json::to_writer(&mut cursor, payload)?;

    Ok(buf.into())
}

This ensures the header bytes are preserved and the JSON payload is appended correctly.

Suggested change
fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&[0u8; 4]); // header (currently unused)
let mut cursor = Cursor::new(&mut buf);
serde_json::to_writer(&mut cursor, payload)?;
Ok(buf.into())
}
fn encode_frame<T: Serialize>(payload: &T) -> Result<Bytes> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&[0u8; 4]); // header (currently unused)
let mut cursor = Cursor::new(&mut buf);
cursor.set_position(4); // Position cursor after header
serde_json::to_writer(&mut cursor, payload)?;
Ok(buf.into())
}

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

5 changes: 4 additions & 1 deletion packages/edge/infra/client/manager/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async fn build_sqlite_pool(db_url: &str) -> Result<SqlitePool> {
.busy_timeout(Duration::from_secs(5))
// Enable foreign key constraint enforcement
.foreign_keys(true)
// Increases write performance
.journal_mode(SqliteJournalMode::Wal)
// Enable auto vacuuming and set it to incremental mode for gradual space reclaiming
.auto_vacuum(SqliteAutoVacuum::Incremental)
// Set synchronous mode to NORMAL for performance and data safety balance
Expand Down Expand Up @@ -241,7 +243,8 @@ async fn init_sqlite_schema(pool: &SqlitePool) -> Result<()> {
generation INTEGER NOT NULL,
config BLOB NOT NULL, -- JSONB

runner_id NOT NULL, -- Already exists in `config`, set here for ease of querying
-- Already exists in `config`, set here for ease of querying
runner_id BLOB NOT NULL, -- UUID

start_ts INTEGER NOT NULL,
running_ts INTEGER,
Expand Down
1 change: 0 additions & 1 deletion packages/edge/infra/guard/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ pub mod proxy_service;
pub mod request_context;
mod server;
pub mod types;
pub mod util;

pub use cert_resolver::CertResolverFn;
pub use proxy_service::{MiddlewareFn, ProxyService, ProxyState, RouteTarget, RoutingFn};
Expand Down
23 changes: 0 additions & 23 deletions packages/edge/infra/guard/core/src/util.rs

This file was deleted.

13 changes: 4 additions & 9 deletions packages/edge/infra/guard/server/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,11 @@ pub async fn create_cert_resolver(
}
Ok(None) => {
tracing::warn!(
"Could not build dynamic hostname actor routing regex - pattern will be skipped"
);
"Could not build dynamic hostname actor routing regex - pattern will be skipped"
);
None
}
Err(err) => bail!(
"Failed to build dynamic hostname actor routing regex: {}",
err
),
Err(e) => bail!("Failed to build dynamic hostname actor routing regex: {}", e),
};
let actor_hostname_regex_static =
match build_actor_hostname_and_path_regex(EndpointType::Path, guard_hostname) {
Expand All @@ -178,9 +175,7 @@ pub async fn create_cert_resolver(
);
None
}
Err(e) => {
bail!("Failed to build static path actor routing regex: {}", e);
}
Err(e) => bail!("Failed to build static path actor routing regex: {}", e),
};

// Create resolver function that matches the routing logic
Expand Down
8 changes: 4 additions & 4 deletions packages/edge/services/pegboard/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use regex::Regex;
use crate::types::{EndpointType, GameGuardProtocol};

// Constants for regex patterns
const UUID_PATTERN: &str = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}";
const ID_PATTERN: &str = r"[a-zA-Z0-9-]+";
const PORT_NAME_PATTERN: &str = r"[a-zA-Z0-9-_]+";

pub fn build_actor_hostname_and_path(
Expand Down Expand Up @@ -59,7 +59,7 @@ pub fn build_actor_hostname_and_path_regex(
// server in the subdomain is a convenience
(EndpointType::Hostname, GuardPublicHostname::DnsParent(dns_parent)) => {
let hostname_regex = Regex::new(&format!(
r"^(?P<actor_id>{UUID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})\.actor\.{}$",
r"^(?P<actor_id>{ID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})\.actor\.{}$",
regex::escape(dns_parent.as_str())
))?;
Ok(Some((hostname_regex, None)))
Expand All @@ -81,7 +81,7 @@ pub fn build_actor_hostname_and_path_regex(
))?;

let path_regex = Regex::new(&format!(
r"^/(?P<actor_id>{UUID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})(?:/.*)?$"
r"^/(?P<actor_id>{ID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})(?:/.*)?$"
))?;

Ok(Some((hostname_regex, Some(path_regex))))
Expand All @@ -91,7 +91,7 @@ pub fn build_actor_hostname_and_path_regex(
let hostname_regex = Regex::new(&format!(r"^{}$", regex::escape(static_.as_str())))?;

let path_regex = Regex::new(&format!(
r"^/(?P<actor_id>{UUID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})(?:/.*)?$"
r"^/(?P<actor_id>{ID_PATTERN})-(?P<port_name>{PORT_NAME_PATTERN})(?:/.*)?$"
))?;

Ok(Some((hostname_regex, Some(path_regex))))
Expand Down
Loading