1- use  std:: { env,  net:: SocketAddr ,  sync:: Arc ,  time:: Duration } ; 
1+ use  std:: { env,  io :: Cursor ,   net:: SocketAddr ,  path :: PathBuf ,  sync:: Arc ,  time:: Duration } ; 
22
3+ use  anyhow:: * ; 
4+ use  bytes:: Bytes ; 
35use  futures_util:: { SinkExt ,  StreamExt } ; 
6+ use  serde:: { de:: DeserializeOwned ,  Serialize } ; 
47use  serde_json:: json; 
5- use  tokio:: sync:: Mutex ; 
6- use  tokio_tungstenite:: { connect_async,  tungstenite:: protocol:: Message } ; 
7- use  uuid:: Uuid ; 
8+ use  tokio:: { net:: UnixStream ,  sync:: Mutex } ; 
9+ use  tokio_util:: codec:: { Framed ,  LengthDelimitedCodec } ; 
810use  warp:: Filter ; 
911
1012const  PING_INTERVAL :  Duration  = Duration :: from_secs ( 1 ) ; 
@@ -18,20 +20,20 @@ async fn main() {
1820	} 
1921
2022	// Get manager connection details from env vars 
21- 	let  manager_ip  = env :: var ( "RIVET_MANAGER_IP" ) . expect ( "RIVET_MANAGER_IP not set" ) ; 
22- 	let  manager_port =  env:: var ( "RIVET_MANAGER_PORT " ) . expect ( "RIVET_MANAGER_PORT  not set" ) ; 
23- 	let  manager_addr =  format ! ( "ws://{}:{}" ,  manager_ip ,  manager_port ) ; 
23+ 	let  manager_socket_path  = PathBuf :: from ( 
24+ 		 env:: var ( "RIVET_MANAGER_SOCKET_PATH " ) . expect ( "RIVET_MANAGER_SOCKET_PATH  not set" ) , 
25+ 	) ; 
2426
2527	// Get HTTP server port from env var or use default 
2628	let  http_port = env:: var ( "PORT_MAIN" ) 
2729		. expect ( "PORT_MAIN not set" ) 
2830		. parse :: < u16 > ( ) 
2931		. expect ( "bad PORT_MAIN" ) ; 
3032
31- 	// Spawn the WebSocket  client 
33+ 	// Spawn the unix socket  client 
3234	tokio:: spawn ( async  move  { 
33- 		if  let  Err ( e)  = run_websocket_client ( & manager_addr ) . await  { 
34- 			eprintln ! ( "WebSocket  client error: {}" ,  e) ; 
35+ 		if  let  Err ( e)  = run_socket_client ( manager_socket_path ) . await  { 
36+ 			eprintln ! ( "Socket  client error: {}" ,  e) ; 
3537		} 
3638	} ) ; 
3739
@@ -53,25 +55,28 @@ async fn main() {
5355	warp:: serve ( echo) . run ( http_addr) . await ; 
5456} 
5557
56- async  fn  run_websocket_client ( url :   & str )  -> Result < ( ) ,   Box < dyn  std :: error :: Error > >  { 
57- 	println ! ( "Connecting to WebSocket  at {}" ,  url ) ; 
58+ async  fn  run_socket_client ( socket_path :   PathBuf )  -> Result < ( ) >  { 
59+ 	println ! ( "Connecting to socket  at {}" ,  socket_path . display ( ) ) ; 
5860
59- 	// Connect to the WebSocket  server 
60- 	let  ( ws_stream ,  _ )  =  connect_async ( url ) . await ?; 
61- 	println ! ( "WebSocket  connection established" ) ; 
61+ 	// Connect to the socket  server 
62+ 	let  stream =  UnixStream :: connect ( socket_path ) . await ?; 
63+ 	println ! ( "Socket  connection established" ) ; 
6264
63- 	// Split the stream 
64- 	let  ( mut  write,  mut  read)  = ws_stream. split ( ) ; 
65+ 	let  codec = LengthDelimitedCodec :: builder ( ) 
66+ 		. length_field_type :: < u32 > ( ) 
67+ 		. length_field_length ( 4 ) 
68+ 		// No offset 
69+ 		. length_field_offset ( 0 ) 
70+ 		// Header length is not included in the length calculation 
71+ 		. length_adjustment ( 4 ) 
72+ 		// header is included in the returned bytes 
73+ 		. num_skip ( 0 ) 
74+ 		. new_codec ( ) ; 
6575
66- 	let  payload = json ! ( { 
67- 		"init" :  { 
68- 			"access_token" :  env:: var( "RIVET_ACCESS_TOKEN" ) . expect( "RIVET_ACCESS_TOKEN not set" ) , 
69- 		} , 
70- 	} ) ; 
76+ 	let  framed = Framed :: new ( stream,  codec) ; 
7177
72- 	let  data = serde_json:: to_vec ( & payload) ?; 
73- 	write. send ( Message :: Binary ( data) ) . await ?; 
74- 	println ! ( "Sent init message" ) ; 
78+ 	// Split the stream 
79+ 	let  ( write,  mut  read)  = framed. split ( ) ; 
7580
7681	// Ping thread 
7782	let  write = Arc :: new ( Mutex :: new ( write) ) ; 
@@ -80,10 +85,14 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
8085		loop  { 
8186			tokio:: time:: sleep ( PING_INTERVAL ) . await ; 
8287
88+ 			let  payload = json ! ( { 
89+ 				"ping" :  { } 
90+ 			} ) ; 
91+ 
8392			if  write2
8493				. lock ( ) 
8594				. await 
86- 				. send ( Message :: Ping ( Vec :: new ( ) ) ) 
95+ 				. send ( encode_frame ( & payload ) . unwrap ( ) ) 
8796				. await 
8897				. is_err ( ) 
8998			{ 
@@ -93,53 +102,61 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
93102	} ) ; 
94103
95104	// Process incoming messages 
96- 	while  let  Some ( message)  = read. next ( ) . await  { 
97- 		match  message { 
98- 			Ok ( msg)  => match  msg { 
99- 				Message :: Pong ( _)  => { } 
100- 				Message :: Binary ( buf)  => { 
101- 					let  packet = serde_json:: from_slice :: < serde_json:: Value > ( & buf) ?; 
102- 					println ! ( "Received packet: {packet:?}" ) ; 
103- 
104- 					if  let  Some ( packet)  = packet. get ( "start_actor" )  { 
105- 						let  payload = json ! ( { 
106- 							"actor_state_update" :  { 
107- 								"actor_id" :  packet[ "actor_id" ] , 
108- 								"generation" :  packet[ "generation" ] , 
109- 								"state" :  { 
110- 									"running" :  null, 
111- 								} , 
112- 							} , 
113- 						} ) ; 
114- 
115- 						let  data = serde_json:: to_vec ( & payload) ?; 
116- 						write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?; 
117- 					}  else  if  let  Some ( packet)  = packet. get ( "signal_actor" )  { 
118- 						let  payload = json ! ( { 
119- 							"actor_state_update" :  { 
120- 								"actor_id" :  packet[ "actor_id" ] , 
121- 								"generation" :  packet[ "generation" ] , 
122- 								"state" :  { 
123- 									"exited" :  { 
124- 										"exit_code" :  null, 
125- 									} , 
126- 								} , 
127- 							} , 
128- 						} ) ; 
129- 
130- 						let  data = serde_json:: to_vec ( & payload) ?; 
131- 						write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?; 
132- 					} 
133- 				} 
134- 				msg => eprintln ! ( "Unexpected message: {msg:?}" ) , 
135- 			} , 
136- 			Err ( e)  => { 
137- 				eprintln ! ( "Error reading message: {}" ,  e) ; 
138- 				break ; 
139- 			} 
105+ 	while  let  Some ( frame)  = read. next ( ) . await . transpose ( ) ? { 
106+ 		let  ( _,  packet)  = decode_frame :: < serde_json:: Value > ( & frame. freeze ( ) ) ?; 
107+ 		println ! ( "Received packet: {packet:?}" ) ; 
108+ 
109+ 		if  let  Some ( packet)  = packet. get ( "start_actor" )  { 
110+ 			let  payload = json ! ( { 
111+ 				"actor_state_update" :  { 
112+ 					"actor_id" :  packet[ "actor_id" ] , 
113+ 					"generation" :  packet[ "generation" ] , 
114+ 					"state" :  { 
115+ 						"running" :  null, 
116+ 					} , 
117+ 				} , 
118+ 			} ) ; 
119+ 
120+ 			write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?; 
121+ 		}  else  if  let  Some ( packet)  = packet. get ( "signal_actor" )  { 
122+ 			let  payload = json ! ( { 
123+ 				"actor_state_update" :  { 
124+ 					"actor_id" :  packet[ "actor_id" ] , 
125+ 					"generation" :  packet[ "generation" ] , 
126+ 					"state" :  { 
127+ 						"exited" :  { 
128+ 							"exit_code" :  null, 
129+ 						} , 
130+ 					} , 
131+ 				} , 
132+ 			} ) ; 
133+ 
134+ 			write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?; 
140135		} 
141136	} 
142137
143- 	println ! ( "WebSocket  connection closed" ) ; 
138+ 	println ! ( "Socket  connection closed" ) ; 
144139	Ok ( ( ) ) 
145140} 
141+ 
142+ fn  decode_frame < T :  DeserializeOwned > ( frame :  & Bytes )  -> Result < ( [ u8 ;  4 ] ,  T ) >  { 
143+ 	ensure ! ( frame. len( )  >= 4 ,  "Frame too short" ) ; 
144+ 
145+ 	// Extract the header (first 4 bytes) 
146+ 	let  header = [ frame[ 0 ] ,  frame[ 1 ] ,  frame[ 2 ] ,  frame[ 3 ] ] ; 
147+ 
148+ 	// Deserialize the rest of the frame (payload after the header) 
149+ 	let  payload = serde_json:: from_slice ( & frame[ 4 ..] ) ?; 
150+ 
151+ 	Ok ( ( header,  payload) ) 
152+ } 
153+ 
154+ fn  encode_frame < T :  Serialize > ( payload :  & T )  -> Result < Bytes >  { 
155+ 	let  mut  buf = Vec :: with_capacity ( 4 ) ; 
156+ 	buf. extend_from_slice ( & [ 0u8 ;  4 ] ) ;  // header (currently unused) 
157+ 
158+ 	let  mut  cursor = Cursor :: new ( & mut  buf) ; 
159+ 	serde_json:: to_writer ( & mut  cursor,  payload) ?; 
160+ 
161+ 	Ok ( buf. into ( ) ) 
162+ } 
0 commit comments