diff --git a/async/websocket_async.ml b/async/websocket_async.ml index 229a7d6..cf19d95 100644 --- a/async/websocket_async.ml +++ b/async/websocket_async.ml @@ -210,7 +210,9 @@ let src = let server ?(name = "websocket.server") ?(check_request = fun _ -> Deferred.return true) - ?(select_protocol = fun _ -> None) ~reader ~writer ~app_to_ws ~ws_to_app () + ?(select_protocol = fun _ -> None) + ?max_frame_length + ~reader ~writer ~app_to_ws ~ws_to_app () = let handshake r w = (Request.read r >>= function @@ -276,7 +278,9 @@ let server ?(name = "websocket.server") handshake reader writer) |> Deferred.Or_error.bind ~f:(fun () -> set_tcp_nodelay writer; - let read_frame = make_read_frame ~mode:Server reader writer in + let read_frame = + make_read_frame ?max_len:max_frame_length ~mode:Server reader writer + in let rec loop () = read_frame () >>= Pipe.write ws_to_app >>= loop in let transfer_end = let buf = Buffer.create 128 in @@ -303,7 +307,9 @@ let server ?(name = "websocket.server") >>= Deferred.Or_error.return) let upgrade_connection ?(select_protocol = fun _ -> None) - ?(ping_interval = Time_ns.Span.of_int_sec 50) ~app_to_ws ~ws_to_app ~f + ?(ping_interval = Time_ns.Span.of_int_sec 50) + ?max_frame_length + ~app_to_ws ~ws_to_app ~f request = let headers = Cohttp.Request.headers request in let key = @@ -328,7 +334,9 @@ let upgrade_connection ?(select_protocol = fun _ -> None) () in let handler reader writer = - let read_frame = make_read_frame ~mode:Server reader writer in + let read_frame = + make_read_frame ?max_len:max_frame_length ~mode:Server reader writer + in let rec loop () = try_with read_frame >>= function | Error _ -> Deferred.unit diff --git a/async/websocket_async.mli b/async/websocket_async.mli index 79a4cca..31ab68f 100644 --- a/async/websocket_async.mli +++ b/async/websocket_async.mli @@ -55,13 +55,14 @@ val server : ?name:string -> ?check_request:(Cohttp.Request.t -> bool Deferred.t) -> ?select_protocol:(string -> string option) -> + ?max_frame_length:int -> reader:Reader.t -> writer:Writer.t -> app_to_ws:Frame.t Pipe.Reader.t -> ws_to_app:Frame.t Pipe.Writer.t -> unit -> unit Deferred.Or_error.t -(** [server ?request_cb reader writer app_to_ws +(** [server ?request_cb ?max_frame_length reader writer app_to_ws ws_to_app ()] returns a thread that expects a websocket client connected to [reader]/[writer] and, after performing the handshake, will resp. read outgoing frames from [app_to_ws] and @@ -71,17 +72,19 @@ val server : reception of the client HTTP request, [request_cb] will be called with the request as its argument. If [request_cb] returns true, the connection will proceed, otherwise, the result is immediately - determined to [Error Exit]. *) + determined to [Error Exit]. If [max_frame_length] is specified and + the server receives a frame above this size, the connection is closed. *) val upgrade_connection : ?select_protocol:(string -> string option) -> ?ping_interval:Core.Time_ns.Span.t -> + ?max_frame_length:int -> app_to_ws:Frame.t Pipe.Reader.t -> ws_to_app:Frame.t Pipe.Writer.t -> f:(unit -> unit Deferred.t) -> Cohttp.Request.t -> Cohttp.Response.t * (Reader.t -> Writer.t -> unit Deferred.t) -(** [upgrade_connection ?select_protocol ?ping_interval +(** [upgrade_connection ?select_protocol ?ping_interval ?max_frame_length app_to_ws ws_to_app f request] returns a {!Cohttp_async.Server.response_action}. Just wrap the return value of this function with [`Expert]. diff --git a/core/websocket.ml b/core/websocket.ml index 86ec131..222e577 100644 --- a/core/websocket.ml +++ b/core/websocket.ml @@ -155,6 +155,7 @@ module type S = sig type mode = Client of (int -> string) | Server val make_read_frame : + ?max_len:int -> ?buf:Buffer.t -> mode:mode -> IO.ic -> IO.oc -> unit -> Frame.t IO.t val write_frame_to_buf : mode:mode -> Buffer.t -> Frame.t -> unit @@ -177,6 +178,7 @@ module type S = sig type t val create : + ?max_len:int -> ?read_buf:Buffer.t -> ?write_buf:Buffer.t -> Cohttp.Request.t -> @@ -265,7 +267,7 @@ module Make (IO : Cohttp.S.IO) = struct write_frame_to_buf ~mode buf @@ Frame.close code; write oc @@ Buffer.contents buf - let read_frame ic oc buf mode hdr = + let read_frame ?max_len ic oc buf mode hdr = let hdr_part1 = EndianString.BigEndian.get_int8 hdr 0 in let hdr_part2 = EndianString.BigEndian.get_int8 hdr 1 in let final = is_bit_set 7 hdr_part1 in @@ -291,6 +293,10 @@ module Make (IO : Cohttp.S.IO) = struct else if extension <> 0 then close_with_code mode buf oc 1002 >>= fun () -> proto_error "unsupported extension" + else if (match max_len with Some max -> payload_len > max | None -> false) + then + close_with_code mode buf oc 1009 >>= fun () -> + proto_error "frame payload too big" else if Frame.Opcode.is_ctrl opcode && payload_len > 125 then close_with_code mode buf oc 1002 >>= fun () -> proto_error "control frame too big" @@ -315,11 +321,11 @@ module Make (IO : Cohttp.S.IO) = struct let frame = Frame.of_bytes ~opcode ~extension ~final payload in return frame) - let make_read_frame ?(buf = Buffer.create 128) ~mode ic oc () = + let make_read_frame ?max_len ?(buf = Buffer.create 128) ~mode ic oc () = Buffer.clear buf; read_exactly ic 2 buf >>= function | None -> raise End_of_file - | Some hdr -> read_frame ic oc buf mode hdr + | Some hdr -> read_frame ?max_len ic oc buf mode hdr module Request = Cohttp.Request.Make (IO) module Response = Cohttp.Response.Make (IO) @@ -337,9 +343,9 @@ module Make (IO : Cohttp.S.IO) = struct let source { endp; _ } = endp - let create ?read_buf ?(write_buf = Buffer.create 128) http_request endp ic + let create ?max_len ?read_buf ?(write_buf = Buffer.create 128) http_request endp ic oc = - let read_frame = make_read_frame ?buf:read_buf ~mode:Server ic oc in + let read_frame = make_read_frame ?max_len ?buf:read_buf ~mode:Server ic oc in { buffer = write_buf; endp; diff --git a/core/websocket.mli b/core/websocket.mli index 04f1689..359761b 100644 --- a/core/websocket.mli +++ b/core/websocket.mli @@ -84,6 +84,7 @@ module type S = sig type mode = Client of (int -> string) | Server val make_read_frame : + ?max_len:int -> ?buf:Buffer.t -> mode:mode -> IO.ic -> IO.oc -> unit -> Frame.t IO.t val write_frame_to_buf : mode:mode -> Buffer.t -> Frame.t -> unit @@ -106,6 +107,7 @@ module type S = sig type t val create : + ?max_len:int -> ?read_buf:Buffer.t -> ?write_buf:Buffer.t -> Cohttp.Request.t -> diff --git a/lwt/websocket_cohttp_lwt.ml b/lwt/websocket_cohttp_lwt.ml index 434ae3e..c9de49e 100644 --- a/lwt/websocket_cohttp_lwt.ml +++ b/lwt/websocket_cohttp_lwt.ml @@ -29,12 +29,12 @@ let send_frames stream oc = in Lwt_stream.iter_s send_frame stream -let read_frames ic oc handler_fn = - let read_frame = Lwt_IO.make_read_frame ~mode:Server ic oc in +let read_frames ?max_len ic oc handler_fn = + let read_frame = Lwt_IO.make_read_frame ?max_len ~mode:Server ic oc in let rec inner () = read_frame () >>= Lwt.wrap1 handler_fn >>= inner in inner () -let upgrade_connection request incoming_handler = +let upgrade_connection ?max_frame_length request incoming_handler = let headers = Cohttp.Request.headers request in (match Cohttp.Header.get headers "sec-websocket-key" with | None -> @@ -61,7 +61,7 @@ let upgrade_connection request incoming_handler = [ (* input: data from the client is read from the input channel * of the tcp connection; pass it to handler function *) - read_frames ic oc incoming_handler; + read_frames ?max_len:max_frame_length ic oc incoming_handler; (* output: data for the client is written to the output * channel of the tcp connection *) send_frames frames_out_stream oc; diff --git a/lwt/websocket_cohttp_lwt.mli b/lwt/websocket_cohttp_lwt.mli index b05cf57..3043f0f 100644 --- a/lwt/websocket_cohttp_lwt.mli +++ b/lwt/websocket_cohttp_lwt.mli @@ -19,6 +19,7 @@ open Websocket val upgrade_connection : + ?max_frame_length:int -> Cohttp.Request.t -> (Frame.t -> unit) -> (Cohttp_lwt_unix.Server.response_action * (Frame.t option -> unit)) Lwt.t diff --git a/lwt/websocket_lwt_unix.ml b/lwt/websocket_lwt_unix.ml index 4080aa7..94943e1 100644 --- a/lwt/websocket_lwt_unix.ml +++ b/lwt/websocket_lwt_unix.ml @@ -101,11 +101,12 @@ let write { write_frame; _ } frame = write_frame frame let close_transport { oc; _ } = Lwt_io.close oc let connect ?(extra_headers = Cohttp.Header.init ()) + ?max_frame_length ?(random_string = Websocket.Rng.init ()) ?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ?buf client url = let nonce = Base64.encode_exn (random_string 16) in connect ctx client url nonce extra_headers >|= fun (ic, oc) -> - let read_frame = make_read_frame ?buf ~mode:(Client random_string) ic oc in + let read_frame = make_read_frame ?max_len:max_frame_length ?buf ~mode:(Client random_string) ic oc in let read_frame () = Lwt.catch read_frame (fun exn -> Lwt.async (fun () -> Lwt_io.close ic); @@ -135,7 +136,7 @@ let write_failed_response oc = let open Response in write ~flush:true (fun writer -> write_body writer body) response oc -let server_fun ?read_buf ?write_buf check_request flow ic oc react = +let server_fun ?max_len ?read_buf ?write_buf check_request flow ic oc react = let read = function | `Ok r -> Lwt.return r | `Eof -> @@ -181,11 +182,11 @@ let server_fun ?read_buf ?write_buf check_request flow ic oc react = in Response.write (fun _writer -> Lwt.return_unit) response oc >>= fun () -> let client = - Connected_client.create ?read_buf ?write_buf request flow ic oc + Connected_client.create ?max_len ?read_buf ?write_buf request flow ic oc in react client -let establish_server ?read_buf ?write_buf ?timeout ?stop +let establish_server ?max_frame_length ?read_buf ?write_buf ?timeout ?stop ?(on_exn = fun exn -> !Lwt.async_exception_hook exn) ?(check_request = check_origin_with_host) ?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ~mode react = @@ -194,7 +195,8 @@ let establish_server ?read_buf ?write_buf ?timeout ?stop set_tcp_nodelay flow; Lwt.catch (fun () -> - server_fun ?read_buf ?write_buf check_request + server_fun ?max_len:max_frame_length ?read_buf ?write_buf + check_request (Conduit_lwt_unix.endp_of_flow flow) ic oc react) (function @@ -211,9 +213,11 @@ let mk_frame_stream recv = in Lwt_stream.from f -let establish_standard_server ?read_buf ?write_buf ?timeout ?stop ?on_exn +let establish_standard_server ?max_frame_length ?read_buf ?write_buf ?timeout + ?stop ?on_exn ?check_request ?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ~mode react = let f client = react (Connected_client.make_standard client) in - establish_server ?read_buf ?write_buf ?timeout ?stop ?on_exn ?check_request + establish_server ?max_frame_length ?read_buf ?write_buf ?timeout ?stop + ?on_exn ?check_request ~ctx ~mode f diff --git a/lwt/websocket_lwt_unix.mli b/lwt/websocket_lwt_unix.mli index 2ad2736..96d9b08 100644 --- a/lwt/websocket_lwt_unix.mli +++ b/lwt/websocket_lwt_unix.mli @@ -39,6 +39,7 @@ val close_transport : conn -> unit Lwt.t val connect : ?extra_headers:Cohttp.Header.t -> + ?max_frame_length:int -> ?random_string:(int -> string) -> ?ctx:Conduit_lwt_unix.ctx -> ?buf:Buffer.t -> @@ -47,6 +48,7 @@ val connect : conn Lwt.t val establish_server : + ?max_frame_length:int -> ?read_buf:Buffer.t -> ?write_buf:Buffer.t -> ?timeout:int -> @@ -71,6 +73,7 @@ val mk_frame_stream : is received, the stream will be closed. *) val establish_standard_server : + ?max_frame_length:int -> ?read_buf:Buffer.t -> ?write_buf:Buffer.t -> ?timeout:int ->