Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions core/shared/src/main/scala/fs2/concurrent/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package concurrent

import cats.effect._
import cats.effect.implicits._
import cats.effect.Resource.ExitCase
import cats.syntax.all._

/** Stream aware, multiple producer, single consumer closeable channel.
Expand Down Expand Up @@ -116,6 +117,18 @@ sealed trait Channel[F[_], A] {
*/
def closeWithElement(a: A): F[Either[Channel.Closed, Unit]]

/** Raises an error, closing the channel with an error state.
*
* No-op if the channel is closed, see [[close]] for further info.
*/
def raiseError(e: Throwable): F[Either[Channel.Closed, Unit]]

/** Cancels the channel, closing it with a canceled state.
*
* No-op if the channel is closed, see [[close]] for further info.
*/
def cancel: F[Either[Channel.Closed, Unit]]

/** Returns true if this channel is closed */
def isClosed: F[Boolean]

Expand All @@ -138,80 +151,89 @@ object Channel {
size: Int,
waiting: Option[Deferred[F, Unit]],
producers: List[(A, Deferred[F, Unit])],
closed: Boolean
closed: Option[ExitCase]
)

val open = State(List.empty, 0, None, List.empty, closed = false)
val open = State(List.empty, 0, None, List.empty, closed = None)

def empty(isClosed: Boolean): State =
if (isClosed) State(List.empty, 0, None, List.empty, closed = true)
def empty(close: Option[ExitCase]): State =
if (close.nonEmpty) State(List.empty, 0, None, List.empty, closed = close)
else open

(F.ref(open), F.deferred[Unit]).mapN { (state, closedGate) =>
new Channel[F, A] {

def sendAll: Pipe[F, A, Nothing] = { in =>
(in ++ Stream.exec(close.void))
in.onFinalizeCase(closeWithExitCase(_).void)
.evalMap(send)
.takeWhile(_.isRight)
.drain
}

def sendImpl(a: A, close: Boolean) =
def sendImpl(a: A, close: Option[ExitCase]) =
F.deferred[Unit].flatMap { producer =>
state.flatModifyFull { case (poll, state) =>
state match {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Unit].pure[F])

case State(values, size, waiting, producers, closed @ false) =>
case State(values, size, waiting, producers, None) =>
if (size < capacity)
(
State(a :: values, size + 1, None, producers, close),
signalClosure.whenA(close) *> notifyStream(waiting).as(rightUnit)
signalClosure.whenA(close.nonEmpty) *> notifyStream(waiting).as(rightUnit)
)
else
(
State(values, size, None, (a, producer) :: producers, close),
signalClosure.whenA(close) *>
signalClosure.whenA(close.nonEmpty) *>
notifyStream(waiting).as(rightUnit) <*
waitOnBound(producer, poll).unlessA(close)
waitOnBound(producer, poll).unlessA(close.nonEmpty)
)
}
}
}

def send(a: A) = sendImpl(a, false)
def send(a: A) = sendImpl(a, None)

def closeWithElement(a: A) = sendImpl(a, true)
def closeWithElement(a: A) = sendImpl(a, Some(ExitCase.Succeeded))

def trySend(a: A) =
state.flatModify {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Boolean].pure[F])

case s @ State(values, size, waiting, producers, closed @ false) =>
case s @ State(values, size, waiting, producers, None) =>
if (size < capacity)
(
State(a :: values, size + 1, None, producers, false),
State(a :: values, size + 1, None, producers, None),
notifyStream(waiting).as(rightTrue)
)
else
(s, rightFalse.pure[F])
}

def close =
closeWithExitCase(ExitCase.Succeeded)

def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] =
state.flatModify {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Unit].pure[F])

case State(values, size, waiting, producers, closed @ false) =>
case State(values, size, waiting, producers, None) =>
(
State(values, size, None, producers, true),
State(values, size, None, producers, Some(exitCase)),
notifyStream(waiting).as(rightUnit) <* signalClosure
)
}

def raiseError(e: Throwable): F[Either[Closed, Unit]] =
closeWithExitCase(ExitCase.Errored(e))

def cancel: F[Either[Closed, Unit]] =
closeWithExitCase(ExitCase.Canceled)

def isClosed = closedGate.tryGet.map(_.isDefined)

def closed = closedGate.get
Expand Down Expand Up @@ -250,8 +272,12 @@ object Channel {
unblock.as(Pull.output(toEmit) >> consumeLoop)
} else {
F.pure(
if (closed) Pull.done
else Pull.eval(waiting.get) >> consumeLoop
closed match {
case Some(ExitCase.Succeeded) => Pull.done
case Some(ExitCase.Errored(e)) => Pull.raiseError(e)
case Some(ExitCase.Canceled) => Pull.eval(F.canceled)
case None => Pull.eval(waiting.get) >> consumeLoop
}
)
}
}
Expand Down
18 changes: 16 additions & 2 deletions core/shared/src/main/scala/fs2/concurrent/Topic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package fs2
package concurrent

import cats.effect._
import cats.effect.Resource.ExitCase
import cats.effect.implicits._
import cats.syntax.all._
import scala.collection.immutable.LongMap
Expand Down Expand Up @@ -208,7 +209,8 @@ object Topic {
}

def publish: Pipe[F, A, Nothing] = { in =>
(in ++ Stream.exec(close.void))
in
.onFinalizeCase(closeWithExitCase(_).void)
.evalMap(publish1)
.takeWhile(_.isRight)
.drain
Expand All @@ -223,13 +225,25 @@ object Topic {
def subscribers: Stream[F, Int] = subscriberCount.discrete

def close: F[Either[Topic.Closed, Unit]] =
closeWithExitCase(ExitCase.Succeeded)

def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] =
signalClosure
.complete(())
.flatMap { completedNow =>
val result = if (completedNow) Topic.rightUnit else Topic.closed
val closeChannel =
(channel: Channel[F, A]) =>
exitCase match {
case ExitCase.Succeeded => channel.close.void
case ExitCase.Errored(e) => channel.raiseError(e).void
case ExitCase.Canceled => channel.cancel.void
}

state.get
.flatMap { case (subs, _) => foreach(subs)(_.close.void) }
.flatMap { case (subs, _) =>
foreach(subs)(closeChannel)
}
.as(result)
}
.uncancelable
Expand Down
19 changes: 19 additions & 0 deletions core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import scala.concurrent.duration._

import org.scalacheck.effect.PropF.forAllF

import scala.concurrent.CancellationException

class ChannelSuite extends Fs2Suite {

test("receives some simple elements above capacity and closes") {
Expand Down Expand Up @@ -323,4 +325,21 @@ class ChannelSuite extends Fs2Suite {
racingSendOperations(channel)
}

test("stream should terminate when sendAll is interrupted") {
val program =
Channel
.bounded[IO, Unit](1)
.flatMap { ch =>
val producer =
Stream
.eval(IO.canceled)
.through(ch.sendAll)

ch.stream.concurrently(producer).compile.drain
}

TestControl
.executeEmbed(program)
.intercept[CancellationException]
}
}
25 changes: 25 additions & 0 deletions core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ package concurrent
import cats.syntax.all._
import cats.effect.IO
import scala.concurrent.duration._
import scala.concurrent.CancellationException

import cats.effect.testkit.TestControl

class TopicSuite extends Fs2Suite {
Expand Down Expand Up @@ -185,4 +187,27 @@ class TopicSuite extends Fs2Suite {

TestControl.executeEmbed(program) // will fail if program is deadlocked
}

test("publisher cancellation does not deadlock") {
val program =
Topic[IO, String]
.flatMap { topic =>
val publisher =
Stream
.constant("1")
.covary[IO]
.evalTap(_ => IO.canceled)
.through(topic.publish)

Stream
.resource(topic.subscribeAwait(1))
.flatMap(subscriber => subscriber.concurrently(publisher))
.compile
.drain
}

TestControl
.executeEmbed(program)
.intercept[CancellationException]
}
}
Loading