{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Dataflow.Primitives (
Dataflow(..),
DataflowState,
Vertex(..),
initDataflowState,
duplicateDataflowState,
StateRef,
newState,
readState,
writeState,
modifyState,
Edge,
Timestamp(..),
registerVertex,
registerFinalizer,
incrementEpoch,
input,
send,
finalize
) where
import Control.Arrow ((>>>))
import Control.Monad (forM, (>=>))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State.Strict (StateT, get, gets, modify)
import Control.Monad.Trans (lift)
import Data.Hashable (Hashable (..))
import Data.IORef (IORef, atomicModifyIORef',
atomicWriteIORef, newIORef,
readIORef)
import Data.Vector (Vector, empty, snoc, (!))
import Numeric.Natural (Natural)
import Prelude
import Unsafe.Coerce (unsafeCoerce)
-- | Identifier of a registered vertex. The wrapped 'Int' is used directly as
-- an index into 'dfsVertices'.
newtype VertexID = VertexID Int
  deriving (Eq, Ord, Show)

-- | Identifier of a piece of mutable state. The wrapped 'Int' is used directly
-- as an index into 'dfsStates'.
newtype StateID = StateID Int
  deriving (Eq, Ord, Show)

-- | Counter that is advanced once per call to 'incrementEpoch'.
newtype Epoch = Epoch Natural
  deriving (Eq, Ord, Hashable, Show)

-- | Tag attached to every datum that is sent through the graph; it wraps the
-- 'Epoch' in which the datum was introduced.
newtype Timestamp = Timestamp Epoch
  deriving (Eq, Ord, Hashable, Show)

-- | Handle to a registered vertex. The type parameter is a phantom that
-- records the input type accepted by that vertex.
newtype Edge a = Edge VertexID
-- | Counters that can be advanced by exactly one step.
class Incrementable a where
  -- | Produce the value that follows the argument.
  inc :: a -> a

instance Incrementable VertexID where
  inc (VertexID current) = VertexID (current + 1)

instance Incrementable StateID where
  inc (StateID current) = StateID (current + 1)

instance Incrementable Epoch where
  inc (Epoch current) = Epoch (current + 1)
-- | A value whose static type has been forgotten. Used so that vertices and
-- states of different types can live in the same 'Vector'.
data ErasedType = forall i. EraseType i

-- | Recover a value from an 'ErasedType' at whatever type the caller asks for.
--
-- This is implemented with 'unsafeCoerce' and performs no check at all: the
-- result is only meaningful when the requested type is the same type that was
-- originally wrapped with 'EraseType'. Within this module that agreement is
-- carried by the phantom parameters of 'Edge' and 'StateRef'.
unEraseType :: ErasedType -> a
unEraseType erased =
  case erased of
    EraseType payload -> unsafeCoerce payload
-- | Everything the 'Dataflow' monad threads through a computation.
data DataflowState = DataflowState
  { dfsVertices :: Vector ErasedType
    -- ^ Registered vertices, type-erased, indexed by 'VertexID'.
  , dfsStates :: Vector (IORef ErasedType)
    -- ^ Mutable state cells, type-erased, indexed by 'StateID'.
  , dfsFinalizers :: [Timestamp -> Dataflow ()]
    -- ^ Callbacks run by 'finalize'; new entries are added at the front.
  , dfsLastVertexID :: VertexID
    -- ^ The most recently assigned vertex ID.
  , dfsLastStateID :: StateID
    -- ^ The most recently assigned state ID.
  , dfsLastInputEpoch :: Epoch
    -- ^ The epoch most recently returned by 'incrementEpoch'.
  }
-- | The monad in which dataflow graphs are built and executed: a strict
-- 'StateT' over 'IO' carrying the 'DataflowState'. Use 'runDataflow' to
-- expose the underlying transformer.
newtype Dataflow a = Dataflow
  { runDataflow :: StateT DataflowState IO a
  }
  deriving (Functor, Applicative, Monad)
-- | The empty starting state: no vertices, no state cells, no finalizers.
--
-- Both ID counters start at @-1@ so that the first increment yields index @0@,
-- and the epoch counter starts at @0@ so that the first input epoch is @1@.
initDataflowState :: DataflowState
initDataflowState =
  DataflowState
    { dfsVertices       = empty
    , dfsStates         = empty
    , dfsFinalizers     = []
    , dfsLastVertexID   = VertexID (-1)
    , dfsLastStateID    = StateID (-1)
    , dfsLastInputEpoch = Epoch 0
    }
-- | Return a copy of the current 'DataflowState' in which every state cell has
-- been replaced by a fresh 'IORef' holding the cell's current contents.
--
-- All other fields are carried over unchanged, and the state of the running
-- computation itself is not modified.
duplicateDataflowState :: Dataflow DataflowState
duplicateDataflowState = Dataflow $ do
  current <- get
  -- Read each existing cell and allocate a new one with the same contents.
  copiedCells <- liftIO $ forM (dfsStates current) (readIORef >=> newIORef)
  return current { dfsStates = copiedCells }
-- | Advance the input epoch by one, store it, and return the new value.
incrementEpoch :: Dataflow Epoch
incrementEpoch = Dataflow $ do
  nextEpoch <- gets (dfsLastInputEpoch >>> inc)
  nextEpoch <$ modify (\st -> st { dfsLastInputEpoch = nextEpoch })
-- | A processing node that consumes inputs of type @i@.
--
-- * 'StatefulVertex' pairs a 'StateRef' with a callback that receives that
--   reference along with the timestamp and the datum. The state type is
--   existentially hidden.
--
-- * 'StatelessVertex' is just a callback over the timestamp and the datum.
data Vertex i
  = forall s. StatefulVertex
      (StateRef s)
      (StateRef s -> Timestamp -> i -> Dataflow ())
  | StatelessVertex
      (Timestamp -> i -> Dataflow ())
-- | Fetch the vertex an 'Edge' points at.
--
-- The lookup uses the partial indexing operator '!', so it fails at runtime if
-- the edge's index is not present in the current 'dfsVertices'. The stored
-- value is un-erased at the type dictated by the edge's phantom parameter.
lookupVertex :: Edge i -> Dataflow (Vertex i)
lookupVertex (Edge (VertexID ix)) =
  Dataflow $ gets (\st -> unEraseType (dfsVertices st ! ix))
-- | Add a vertex to the graph and return an 'Edge' that can be used to send
-- data to it.
--
-- The vertex is appended to 'dfsVertices' and receives the next 'VertexID'.
registerVertex :: Vertex i -> Dataflow (Edge i)
registerVertex vertex = Dataflow $ do
  nextID <- gets (dfsLastVertexID >>> inc)
  modify $ \st ->
    st
      { dfsVertices     = dfsVertices st `snoc` EraseType vertex
      , dfsLastVertexID = nextID
      }
  return (Edge nextID)
-- | Register a callback to be run by 'finalize'.
--
-- The callback is placed at the front of 'dfsFinalizers', so later
-- registrations come before earlier ones in that list.
registerFinalizer :: (Timestamp -> Dataflow ()) -> Dataflow ()
registerFinalizer finalizer = Dataflow (modify prepend)
  where
    prepend st = st { dfsFinalizers = finalizer : dfsFinalizers st }
-- | Handle to a mutable state cell. The type parameter is a phantom that
-- records the type of the value stored in the cell.
newtype StateRef a = StateRef StateID
-- | Allocate a new state cell holding the given initial value and return a
-- typed reference to it.
--
-- The cell is appended to 'dfsStates' and receives the next 'StateID'.
newState :: a -> Dataflow (StateRef a)
newState a = Dataflow $ do
  nextID <- gets (dfsLastStateID >>> inc)
  cell <- lift $ newIORef (EraseType a)
  modify $ \st ->
    st
      { dfsStates      = dfsStates st `snoc` cell
      , dfsLastStateID = nextID
      }
  return (StateRef nextID)
-- | Fetch the underlying 'IORef' for a state reference.
--
-- The lookup uses the partial indexing operator '!', so it fails at runtime if
-- the reference's index is not present in the current 'dfsStates'.
lookupStateRef :: StateRef s -> Dataflow (IORef ErasedType)
lookupStateRef (StateRef (StateID ix)) =
  Dataflow $ gets (\st -> dfsStates st ! ix)
-- | Read the current value of a state cell.
readState :: StateRef a -> Dataflow a
readState sref = lookupStateRef sref >>= fetch
  where
    fetch cell = Dataflow (lift (fmap unEraseType (readIORef cell)))
-- | Overwrite the value of a state cell using 'atomicWriteIORef'.
writeState :: StateRef a -> a -> Dataflow ()
writeState sref x = lookupStateRef sref >>= store
  where
    store cell = Dataflow (lift (atomicWriteIORef cell (EraseType x)))
-- | Apply a function to the value of a state cell, atomically.
--
-- The updated value is forced to weak head normal form before it is stored.
-- 'atomicModifyIORef'' on its own would only force the 'EraseType' wrapper,
-- leaving @op@ unapplied inside it; repeated calls would then accumulate a
-- chain of thunks in the cell until the state is finally read.
modifyState :: StateRef a -> (a -> a) -> Dataflow ()
modifyState sref op = do
  ioref <- lookupStateRef sref
  Dataflow $ lift $ atomicModifyIORef' ioref step
  where
    step erased =
      let updated = op (unEraseType erased)
      in  updated `seq` (EraseType updated, ())
{-# INLINEABLE input #-}
-- | Feed a batch of inputs into the graph as one new epoch.
--
-- A fresh 'Timestamp' is obtained from 'incrementEpoch', every element of the
-- container is sent to the given edge with that timestamp, and then 'finalize'
-- is run for the same timestamp.
--
-- Only 'Foldable' is required of the container because the elements are merely
-- visited in order; this also admits containers that are not 'Traversable'.
input :: Foldable t => t i -> Edge i -> Dataflow ()
input inputs next = do
  timestamp <- Timestamp <$> incrementEpoch
  mapM_ (send next timestamp) inputs
  finalize timestamp
{-# INLINE send #-}
-- | Deliver a single datum, tagged with a timestamp, to the vertex behind an
-- edge by invoking that vertex's callback.
send :: Edge input -> Timestamp -> input -> Dataflow ()
send e t i = do
  vertex <- lookupVertex e
  case vertex of
    -- A stateful vertex's callback additionally receives its own state reference.
    StatefulVertex sref callback -> callback sref t i
    StatelessVertex callback     -> callback t i
-- | Run every registered finalizer with the given timestamp, in the order in
-- which they appear in 'dfsFinalizers'.
finalize :: Timestamp -> Dataflow ()
finalize t = Dataflow (gets dfsFinalizers) >>= mapM_ ($ t)