1- use std:: { env, net:: SocketAddr , sync:: Arc , time:: Duration } ;
1+ use std:: { env, io :: Cursor , net:: SocketAddr , path :: PathBuf , sync:: Arc , time:: Duration } ;
22
3+ use anyhow:: * ;
4+ use bytes:: Bytes ;
35use futures_util:: { SinkExt , StreamExt } ;
6+ use serde:: { de:: DeserializeOwned , Serialize } ;
47use serde_json:: json;
5- use tokio:: sync:: Mutex ;
6- use tokio_tungstenite:: { connect_async, tungstenite:: protocol:: Message } ;
7- use uuid:: Uuid ;
8+ use tokio:: { net:: UnixStream , sync:: Mutex } ;
9+ use tokio_util:: codec:: { Framed , LengthDelimitedCodec } ;
810use warp:: Filter ;
911
1012const PING_INTERVAL : Duration = Duration :: from_secs ( 1 ) ;
@@ -18,20 +20,20 @@ async fn main() {
1820 }
1921
2022 // Get manager connection details from env vars
21- let manager_ip = env :: var ( "RIVET_MANAGER_IP" ) . expect ( "RIVET_MANAGER_IP not set" ) ;
22- let manager_port = env:: var ( "RIVET_MANAGER_PORT " ) . expect ( "RIVET_MANAGER_PORT not set" ) ;
23- let manager_addr = format ! ( "ws://{}:{}" , manager_ip , manager_port ) ;
23+ let manager_socket_path = PathBuf :: from (
24+ env:: var ( "RIVET_MANAGER_SOCKET_PATH " ) . expect ( "RIVET_MANAGER_SOCKET_PATH not set" ) ,
25+ ) ;
2426
2527 // Get HTTP server port from env var or use default
2628 let http_port = env:: var ( "PORT_MAIN" )
2729 . expect ( "PORT_MAIN not set" )
2830 . parse :: < u16 > ( )
2931 . expect ( "bad PORT_MAIN" ) ;
3032
31- // Spawn the WebSocket client
33+ // Spawn the unix socket client
3234 tokio:: spawn ( async move {
33- if let Err ( e) = run_websocket_client ( & manager_addr ) . await {
34- eprintln ! ( "WebSocket client error: {}" , e) ;
35+ if let Err ( e) = run_socket_client ( manager_socket_path ) . await {
36+ eprintln ! ( "Socket client error: {}" , e) ;
3537 }
3638 } ) ;
3739
@@ -53,25 +55,28 @@ async fn main() {
5355 warp:: serve ( echo) . run ( http_addr) . await ;
5456}
5557
56- async fn run_websocket_client ( url : & str ) -> Result < ( ) , Box < dyn std :: error :: Error > > {
57- println ! ( "Connecting to WebSocket at {}" , url ) ;
58+ async fn run_socket_client ( socket_path : PathBuf ) -> Result < ( ) > {
59+ println ! ( "Connecting to socket at {}" , socket_path . display ( ) ) ;
5860
59- // Connect to the WebSocket server
60- let ( ws_stream , _ ) = connect_async ( url ) . await ?;
61- println ! ( "WebSocket connection established" ) ;
61+ // Connect to the socket server
62+ let stream = UnixStream :: connect ( socket_path ) . await ?;
63+ println ! ( "Socket connection established" ) ;
6264
63- // Split the stream
64- let ( mut write, mut read) = ws_stream. split ( ) ;
65+ let codec = LengthDelimitedCodec :: builder ( )
66+ . length_field_type :: < u32 > ( )
67+ . length_field_length ( 4 )
68+ // No offset
69+ . length_field_offset ( 0 )
70+ // Header length is not included in the length calculation
71+ . length_adjustment ( 4 )
72+ // header is included in the returned bytes
73+ . num_skip ( 0 )
74+ . new_codec ( ) ;
6575
66- let payload = json ! ( {
67- "init" : {
68- "access_token" : env:: var( "RIVET_ACCESS_TOKEN" ) . expect( "RIVET_ACCESS_TOKEN not set" ) ,
69- } ,
70- } ) ;
76+ let framed = Framed :: new ( stream, codec) ;
7177
72- let data = serde_json:: to_vec ( & payload) ?;
73- write. send ( Message :: Binary ( data) ) . await ?;
74- println ! ( "Sent init message" ) ;
78+ // Split the stream
79+ let ( write, mut read) = framed. split ( ) ;
7580
7681 // Ping thread
7782 let write = Arc :: new ( Mutex :: new ( write) ) ;
@@ -80,10 +85,14 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
8085 loop {
8186 tokio:: time:: sleep ( PING_INTERVAL ) . await ;
8287
88+ let payload = json ! ( {
89+ "ping" : { }
90+ } ) ;
91+
8392 if write2
8493 . lock ( )
8594 . await
86- . send ( Message :: Ping ( Vec :: new ( ) ) )
95+ . send ( encode_frame ( & payload ) . unwrap ( ) )
8796 . await
8897 . is_err ( )
8998 {
@@ -93,53 +102,61 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
93102 } ) ;
94103
95104 // Process incoming messages
96- while let Some ( message) = read. next ( ) . await {
97- match message {
98- Ok ( msg) => match msg {
99- Message :: Pong ( _) => { }
100- Message :: Binary ( buf) => {
101- let packet = serde_json:: from_slice :: < serde_json:: Value > ( & buf) ?;
102- println ! ( "Received packet: {packet:?}" ) ;
103-
104- if let Some ( packet) = packet. get ( "start_actor" ) {
105- let payload = json ! ( {
106- "actor_state_update" : {
107- "actor_id" : packet[ "actor_id" ] ,
108- "generation" : packet[ "generation" ] ,
109- "state" : {
110- "running" : null,
111- } ,
112- } ,
113- } ) ;
114-
115- let data = serde_json:: to_vec ( & payload) ?;
116- write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?;
117- } else if let Some ( packet) = packet. get ( "signal_actor" ) {
118- let payload = json ! ( {
119- "actor_state_update" : {
120- "actor_id" : packet[ "actor_id" ] ,
121- "generation" : packet[ "generation" ] ,
122- "state" : {
123- "exited" : {
124- "exit_code" : null,
125- } ,
126- } ,
127- } ,
128- } ) ;
129-
130- let data = serde_json:: to_vec ( & payload) ?;
131- write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?;
132- }
133- }
134- msg => eprintln ! ( "Unexpected message: {msg:?}" ) ,
135- } ,
136- Err ( e) => {
137- eprintln ! ( "Error reading message: {}" , e) ;
138- break ;
139- }
105+ while let Some ( frame) = read. next ( ) . await . transpose ( ) ? {
106+ let ( _, packet) = decode_frame :: < serde_json:: Value > ( & frame. freeze ( ) ) ?;
107+ println ! ( "Received packet: {packet:?}" ) ;
108+
109+ if let Some ( packet) = packet. get ( "start_actor" ) {
110+ let payload = json ! ( {
111+ "actor_state_update" : {
112+ "actor_id" : packet[ "actor_id" ] ,
113+ "generation" : packet[ "generation" ] ,
114+ "state" : {
115+ "running" : null,
116+ } ,
117+ } ,
118+ } ) ;
119+
120+ write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?;
121+ } else if let Some ( packet) = packet. get ( "signal_actor" ) {
122+ let payload = json ! ( {
123+ "actor_state_update" : {
124+ "actor_id" : packet[ "actor_id" ] ,
125+ "generation" : packet[ "generation" ] ,
126+ "state" : {
127+ "exited" : {
128+ "exit_code" : null,
129+ } ,
130+ } ,
131+ } ,
132+ } ) ;
133+
134+ write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?;
140135 }
141136 }
142137
143- println ! ( "WebSocket connection closed" ) ;
138+ println ! ( "Socket connection closed" ) ;
144139 Ok ( ( ) )
145140}
141+
142+ fn decode_frame < T : DeserializeOwned > ( frame : & Bytes ) -> Result < ( [ u8 ; 4 ] , T ) > {
143+ ensure ! ( frame. len( ) >= 4 , "Frame too short" ) ;
144+
145+ // Extract the header (first 4 bytes)
146+ let header = [ frame[ 0 ] , frame[ 1 ] , frame[ 2 ] , frame[ 3 ] ] ;
147+
148+ // Deserialize the rest of the frame (payload after the header)
149+ let payload = serde_json:: from_slice ( & frame[ 4 ..] ) ?;
150+
151+ Ok ( ( header, payload) )
152+ }
153+
154+ fn encode_frame < T : Serialize > ( payload : & T ) -> Result < Bytes > {
155+ let mut buf = Vec :: with_capacity ( 4 ) ;
156+ buf. extend_from_slice ( & [ 0u8 ; 4 ] ) ; // header (currently unused)
157+
158+ let mut cursor = Cursor :: new ( & mut buf) ;
159+ serde_json:: to_writer ( & mut cursor, payload) ?;
160+
161+ Ok ( buf. into ( ) )
162+ }
0 commit comments