Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {
def drainAkka(stream: AkkaStreams.BinaryStream): Future[Unit] =
stream.runWith(Sink.ignore).map(_ => ())

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, akka-http rejects content-length with transfer-encoding
new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++
new ServerWebSocketTests(
createServerTest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ class ArmeriaCatsServerTest extends TestSuite {
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class ArmeriaFutureServerTest extends TestSuite {
val interpreter = new ArmeriaTestFutureServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class ArmeriaZioServerTest extends TestSuite {
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class JdkHttpServerTest extends TestSuite with EitherValues {
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false).tests()
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, backend rejects content-length with transfer-encoding
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false).tests()
})
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package sttp.tapir.server.netty.cats.internal

import cats.effect.Async
import cats.effect.kernel.{Resource, Sync}
import cats.syntax.all._
import fs2.Chunk
import fs2.interop.reactivestreams.StreamSubscriber
import fs2.io.file.{Files, Path}
import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.multipart.{DefaultHttpDataFactory, HttpData, HttpPostRequestDecoder}
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities.StreamMaxLengthExceededException
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.Part
import sttp.monad.MonadError
import sttp.tapir.TapirFile
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.capabilities.WebSockets
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File


private[cats] class NettyCatsRequestBody[F[_]: Async](
val createFile: ServerRequest => F[TapirFile],
Expand All @@ -21,6 +30,63 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]] = {
fs2.Stream
.resource(
Resource.make(Sync[F].delay(new HttpPostRequestDecoder(NettyCatsRequestBody.multiPartDataFactory, nettyRequest)))(d =>
Sync[F].blocking(d.destroy()) // after the stream finishes or fails, decoder data has to be cleaned up
)
)
.flatMap { decoder =>
fs2.Stream
.eval(StreamSubscriber[F, HttpContent](bufferSize = 1))
.flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s))))
.evalMapAccumulate({
(decoder, 0L)
})({ case ((decoder, processedBytesNum), httpContent) =>
monad
.blocking {
val newProcessedBytes = if (httpContent.content() != null) {
val processedBytesAndContentBytes = processedBytesNum + httpContent.content().readableBytes()
maxBytes.foreach { max =>
if (max < processedBytesAndContentBytes) {
throw new StreamMaxLengthExceededException(max)
}
}
processedBytesAndContentBytes
} else processedBytesNum

// this operation is the one that does potential I/O (writing files)
decoder.offer(httpContent)
val parts = Stream
.continually(if (decoder.hasNext) {
val next = decoder.next()
next
} else null)
.takeWhile(_ != null)
.toVector

(
(decoder, newProcessedBytes),
parts
)
}
})
.map(_._2)
.map(_.flatMap(p => m.partType(p.getName()).map((p, _)).toList))
.evalMap(_.traverse { case (data, partType) => toRawPart(serverRequest, data, partType).map(_.asInstanceOf[Part[Any]]) })
}
.compile
.toVector
.map(_.flatten)
.map(RawValue.fromParts(_))
}

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

Expand All @@ -32,4 +98,13 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
)
.compile
.drain

override def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit] =
fs2.Stream.emits(bytes).through(Files.forAsync[F].writeAll(Path.fromNioPath(file.toPath))).compile.drain

}

private[cats] object NettyCatsRequestBody {
val multiPartDataFactory =
new DefaultHttpDataFactory() // writes to memory, then switches to disk if exceeds MINSIZE (16kB), check other constructors.
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import org.scalatest.matchers.should.Matchers

class NettyCatsServerTest extends TestSuite with EitherValues {
class NettyCatsServerTest extends TestSuite with EitherValues with Matchers {

override def tests: Resource[IO, List[Test]] =
backendResource.flatMap { backend =>
Expand Down Expand Up @@ -41,6 +42,12 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++
new ServerMultipartTests(
createServerTest,
partContentTypeHeaderSupport = false,
partOtherHeaderSupport = false,
multipartResponsesSupport = false
).tests() ++
new ServerWebSocketTests(
createServerTest,
Fs2Streams[IO],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@ import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities
import sttp.monad.{FutureMonad, MonadError}
import sttp.tapir.TapirFile
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.reactivestreams._
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File
import scala.concurrent.{ExecutionContext, Future}

private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext)
extends NettyRequestBody[Future, NoStreams] {

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): Future[RawValue[Seq[RawPart]]] = Future.failed(new UnsupportedOperationException("Multipart requests not supported."))

override def writeBytesToFile(bytes: Array[Byte], file: File): Future[Unit] = Future.failed(new UnsupportedOperationException)

override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import java.io.InputStream
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import sttp.tapir.RawPart
import io.netty.handler.codec.http.multipart.InterfaceHttpData
import sttp.model.Part
import io.netty.handler.codec.http.multipart.HttpData
import io.netty.handler.codec.http.multipart.FileUpload
import java.io.ByteArrayInputStream
import java.io.File

/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {

Expand All @@ -37,6 +46,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]

/** Reads the reactive stream emitting HttpData into a vector of parts. Implementation-specific, as file manipulations and stream
* processing logic can be different for different backends.
*/
def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
* @param serverRequest
Expand All @@ -50,6 +69,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]

def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit]

override def toRaw[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType[RAW],
Expand All @@ -70,8 +91,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException)
case m: RawBodyType.MultipartBody =>
publisherToMultipart(serverRequest.underlying.asInstanceOf[StreamedHttpRequest], serverRequest, m, maxBytes)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
Expand All @@ -96,4 +117,72 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
}
}

protected def toRawPart[R](
serverRequest: ServerRequest,
data: InterfaceHttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val partName = data.getName()
data match {
case httpData: HttpData =>
// TODO filename* attribute is not used by netty. Non-ascii filenames like https://github.com/http4s/http4s/issues/5809 are unsupported.
toRawPartHttpData(partName, serverRequest, httpData, partType)
case unsupportedDataType =>
monad.error(new UnsupportedOperationException(s"Unsupported multipart data type: $unsupportedDataType in part $partName"))
}
}

private def toRawPartHttpData[R](
partName: String,
serverRequest: ServerRequest,
httpData: HttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val fileName = httpData match {
case fileUpload: FileUpload => Option(fileUpload.getFilename())
case _ => None
}
partType match {
case RawBodyType.StringBody(defaultCharset) =>
// TODO otherDispositionParams not supported. They are normally a part of the content-disposition part header, but this header is not directly accessible, they are extracted internally by the decoder.
val charset = if (httpData.getCharset() != null) httpData.getCharset() else defaultCharset
readHttpData(httpData, _.getString(charset)).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteArrayBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteBufferBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, ByteBuffer.wrap(body), fileName = fileName))
case RawBodyType.InputStreamBody =>
(if (httpData.isInMemory())
monad.unit(new ByteArrayInputStream(httpData.get()))
else {
monad.blocking(java.nio.file.Files.newInputStream(httpData.getFile().toPath()))
}).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.InputStreamRangeBody =>
val body = () => {
if (httpData.isInMemory())
new ByteArrayInputStream(httpData.get())
else
java.nio.file.Files.newInputStream(httpData.getFile().toPath())
}
monad.unit(Part(partName, InputStreamRange(body), fileName = fileName))
case RawBodyType.FileBody =>
val fileF: F[File] =
if (httpData.isInMemory())
(for {
file <- createFile(serverRequest)
_ <- writeBytesToFile(httpData.get(), file)
} yield file)
else monad.unit(httpData.getFile())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: Netty decoder creates the file (if it's > 16KB), so we can just get its handle here.

fileF.map(file => Part(partName, FileRange(file), fileName = fileName))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException(s"Nested multipart not supported, part name = $partName"))
}
}

private def readHttpData[T](httpData: HttpData, f: HttpData => T): F[T] =
if (httpData.isInMemory())
monad.unit(f(httpData))
else
monad.blocking(f(httpData))
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import sttp.tapir.model.ServerRequest
import sttp.tapir.server.netty.internal.NettyRequestBody
import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber}
import sttp.tapir.server.netty.sync.*
import sttp.tapir.RawBodyType
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.RawPart
import java.io.File

private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Identity, OxStreams]:

Expand All @@ -20,6 +24,14 @@ private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirF
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): RawValue[Seq[RawPart]] = throw new UnsupportedOperationException("Multipart requests not supported.")
override def writeBytesToFile(bytes: Array[Byte], file: File) = throw new UnsupportedOperationException()

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
serverRequest.underlying match
case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes)
Expand Down
Loading
Loading