Skip to content

Commit 00e93a6

Browse files
committed
Make cuFFT plan handle cache thread-safe
1 parent 3c195de commit 00e93a6

File tree

7 files changed

+86
-33
lines changed

7 files changed

+86
-33
lines changed

accelerate-fft.cabal

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ library
8585
accelerate-llvm >= 1.3
8686
, accelerate-llvm-ptx >= 1.3
8787
, containers >= 0.5
88+
, exceptions >= 0.10
8889
, hashable >= 1.0
8990
, unordered-containers >= 0.2
9091
, cuda >= 0.5

src/Data/Array/Accelerate/Math/FFT/LLVM/Native.hs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
{-# LANGUAGE GADTs #-}
22
{-# LANGUAGE PatternGuards #-}
3+
{-# LANGUAGE OverloadedStrings #-}
34
{-# LANGUAGE ScopedTypeVariables #-}
45
{-# LANGUAGE TypeApplications #-}
56
{-# LANGUAGE TypeFamilies #-}

src/Data/Array/Accelerate/Math/FFT/LLVM/Native/Ix.hs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
1+
{-# LANGUAGE OverloadedStrings #-}
12
{-# LANGUAGE ScopedTypeVariables #-}
23
{-# LANGUAGE TypeApplications #-}
34
{-# LANGUAGE TypeFamilies #-}
5+
{-# LANGUAGE TypeOperators #-}
46
-- |
57
-- Module : Data.Array.Accelerate.Math.FFT.LLVM.Native.Ix
68
-- Copyright : [2017..2020] The Accelerate Team

src/Data/Array/Accelerate/Math/FFT/LLVM/PTX.hs

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE FlexibleContexts #-}
22
{-# LANGUAGE GADTs #-}
33
{-# LANGUAGE PatternGuards #-}
4+
{-# LANGUAGE OverloadedStrings #-}
45
{-# LANGUAGE ScopedTypeVariables #-}
56
{-# LANGUAGE TupleSections #-}
67
{-# LANGUAGE TypeApplications #-}
@@ -92,14 +93,14 @@ fft' plans mode shR eR =
9293
aout <- allocateRemote aR sh
9394
stream <- asks ptxStream
9495
future <- new
95-
liftPar $
96-
withArray eR ain stream $ \d_in -> do
97-
withArray eR aout stream $ \d_out -> do
98-
withPlan plans (sh,t) $ \h -> do
99-
liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out)
100-
--
101-
put future aout
102-
return future
96+
withPlan plans (sh,t) $ \h -> do
97+
liftPar $
98+
withArray eR ain stream $ \d_in -> do
99+
withArray eR aout stream $ \d_out -> do
100+
liftIO $ cuFFT eR h mode stream (castDevPtr d_in) (castDevPtr d_out)
101+
--
102+
put future aout
103+
return future
103104
in
104105
case eR of
105106
NumericRfloat32 -> go (ArrayR shR (eltR @(Complex Float)))

src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Base.hs

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616
module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
1717
where
1818

19+
import Control.Concurrent.MVar
20+
import Control.Exception (evaluate)
21+
import Control.Monad.Catch
22+
import Control.Monad.IO.Class
1923
import Data.Array.Accelerate.Math.FFT.Type
2024

2125
import Data.Array.Accelerate.Array.Data
@@ -57,9 +61,17 @@ withArrayData NumericRfloat64 ad s k =
5761
return (Just e, r)
5862

5963
{-# INLINE withLifetime' #-}
60-
withLifetime' :: Lifetime a -> (a -> LLVM PTX b) -> LLVM PTX b
64+
withLifetime' :: MonadIO m => Lifetime a -> (a -> m b) -> m b
6165
withLifetime' l k = do
6266
r <- k (unsafeGetValue l)
6367
liftIO $ touchLifetime l
6468
return r
6569

70+
{-# INLINE modifyMVar' #-}
71+
modifyMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m (a,b)) -> m b
72+
modifyMVar' m io =
73+
mask $ \restore -> do
74+
a <- liftIO (takeMVar m)
75+
(a',b) <- restore (io a >>= liftIO . evaluate) `onException` liftIO (putMVar m a)
76+
liftIO (putMVar m a')
77+
return b

src/Data/Array/Accelerate/Math/FFT/LLVM/PTX/Plans.hs

Lines changed: 59 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
{-# LANGUAGE MagicHash #-}
22
{-# LANGUAGE RecordWildCards #-}
3+
{-# LANGUAGE LambdaCase #-}
4+
{-# LANGUAGE TupleSections #-}
35
-- |
46
-- Module : Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans
57
-- Copyright : [2017..2020] The Accelerate Team
@@ -19,26 +21,32 @@ module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans (
1921
) where
2022

2123
import Data.Array.Accelerate.Lifetime
22-
import Data.Array.Accelerate.LLVM.PTX
24+
import Data.Array.Accelerate.LLVM.PTX hiding (stream, poll)
2325
import Data.Array.Accelerate.LLVM.PTX.Foreign
2426

2527
import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
2628

2729
import Control.Concurrent.MVar
30+
import Control.Monad.Catch
2831
import Control.Monad.State
29-
import Data.HashMap.Strict
32+
import Data.HashMap.Strict hiding (map, update)
3033
import qualified Data.HashMap.Strict as Map
3134

3235
import qualified Foreign.CUDA.Driver.Context as CUDA
36+
import qualified Foreign.CUDA.Driver.Stream as CUDA
3337
import qualified Foreign.CUDA.FFT as FFT
3438

3539
import GHC.Ptr
3640
import GHC.Base
37-
import Prelude hiding ( lookup )
41+
import Prelude hiding ( lookup, mapM )
42+
import Data.Maybe
43+
import Control.Arrow (second)
44+
import Data.Function ((&))
45+
import Control.Monad.Reader (asks)
3846

3947

4048
data Plans a = Plans
41-
{ plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) (Lifetime FFT.Handle)))
49+
{ plans :: {-# UNPACK #-} !(MVar ( HashMap (Int, Int) [(Lifetime FFT.Handle, Maybe (Par PTX Bool, CUDA.Stream))]))
4250
, create :: a -> IO FFT.Handle
4351
, hash :: a -> Int
4452
}
@@ -62,30 +70,57 @@ createPlan via mix =
6270
--
6371
-- <http://docs.nvidia.com/cuda/cufft/index.html#thread-safety>
6472
--
73+
-- TODO: Determine if this handle is used in the same stream.
6574
{-# INLINE withPlan #-}
66-
withPlan :: Plans a -> a -> (FFT.Handle -> LLVM PTX b) -> LLVM PTX b
75+
withPlan :: Plans a -> a -> (FFT.Handle -> Par PTX (Future b)) -> Par PTX (Future b)
6776
withPlan Plans{..} a k = do
6877
lc <- gets (deviceContext . ptxContext)
69-
h <- liftIO $
70-
withLifetime lc $ \ctx ->
71-
modifyMVar plans $ \pm ->
72-
let key = (toKey ctx, hash a) in
73-
case Map.lookup key pm of
74-
-- handle does not exist yet; create it and add to the global
75-
-- state for reuse
76-
Nothing -> do
77-
h <- create a
78-
l <- newLifetime h
79-
addFinalizer lc $ modifyMVar plans (\pm' -> return (Map.delete key pm', ()))
80-
addFinalizer l $ FFT.destroy h
81-
return ( Map.insert key l pm, l )
82-
83-
-- return existing handle
84-
Just h -> return (pm, h)
85-
--
86-
withLifetime' h k
78+
ls <- asks ptxStream
79+
withLifetime' ls $ \stream ->
80+
withLifetime' lc $ \ctx -> do
81+
let key = (toKey ctx, hash a)
82+
-- Extract an existing cuFFT plan handle from our plan cache that isn't busy;
83+
-- if one cannot be found, create a new cuFFT handle.
84+
h <- modifyMVar' plans $ \pm -> do
85+
let maybeHandles = pm !? key
86+
handles = fromMaybe [] maybeHandles
87+
88+
update Nothing = pure Nothing
89+
update orig@(Just (isReady, _)) = isReady >>= \case
90+
True -> pure Nothing
91+
False -> pure orig
92+
93+
updatedHandles <- zip (map fst handles) <$> mapM (update . snd) handles
94+
95+
-- Extract the first handle that is either entirely ready or still in use but on the same stream
96+
let extractFirstReady [] = (Nothing, [])
97+
extractFirstReady (x@(_, Nothing):xs) = (Just x, xs)
98+
extractFirstReady (x@(_, Just (_, s)):xs) | stream == s = (Just x, xs)
99+
extractFirstReady (x@(_, Just _):xs) = second (x:) $ extractFirstReady xs
100+
101+
(maybeReadyHandle, otherHandles) = extractFirstReady updatedHandles
102+
103+
newHandle = liftIO $ do
104+
h <- create a
105+
l <- newLifetime h
106+
addFinalizer l $ FFT.destroy h
107+
when (isNothing maybeHandles) $
108+
addFinalizer lc $ modifyMVar_ plans $ pure . Map.delete key
109+
pure l
110+
111+
maybeReadyHandle & maybe newHandle (pure . fst)
112+
& fmap (Map.insert key otherHandles pm,)
113+
-- Ensure the handle is returned to the plan cache even if the user-provided function throws
114+
let returnHandle = liftIO $ modifyMVar_ plans $ pure . Map.adjust ((h, Nothing):) key
115+
flip onException returnHandle $ do
116+
-- Invoke user-provided function with cuFFT handle
117+
future <- withLifetime' h k
118+
-- Push the cuFFT plan handle onto the list of plan handles with equal settings,
119+
-- with a callback to check whether the cuFFT handle is ready to use again.
120+
planHandleEntry <- (h,) . Just . (,stream) . fmap isJust . poll <$> statusHandle future
121+
liftIO $ modifyMVar_ plans $ pure . Map.adjust (planHandleEntry:) key
122+
pure future
87123

88124
{-# INLINE toKey #-}
89125
toKey :: CUDA.Context -> Int
90126
toKey (CUDA.Context (Ptr addr#)) = I# (addr2Int# addr#)
91-

src/Data/Array/Accelerate/Math/FFT/Type.hs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
{-# LANGUAGE FlexibleInstances #-}
33
{-# LANGUAGE GADTs #-}
44
{-# LANGUAGE NoImplicitPrelude #-}
5+
{-# LANGUAGE TypeOperators #-}
56
{-# OPTIONS_HADDOCK hide #-}
67
-- |
78
-- Module : Data.Array.Accelerate.Math.FFT.Type

0 commit comments

Comments
 (0)