{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RecordWildCards            #-}

module Dataflow.Primitives (
  Dataflow(..),
  DataflowState,
  Vertex(..),
  initDataflowState,
  duplicateDataflowState,
  StateRef,
  newState,
  readState,
  writeState,
  modifyState,
  Edge,
  Timestamp(..),
  registerVertex,
  registerFinalizer,
  incrementEpoch,
  input,
  send,
  finalize
) where

import           Control.Arrow              ((>>>))
import           Control.Monad              (forM, (>=>))
import           Control.Monad.IO.Class     (liftIO)
import           Control.Monad.State.Strict (StateT, get, gets, modify)
import           Control.Monad.Trans        (lift)
import           Data.Hashable              (Hashable (..))
import           Data.IORef                 (IORef, atomicModifyIORef',
                                             atomicWriteIORef, newIORef,
                                             readIORef)
import           Data.Vector                (Vector, empty, snoc, (!))
import           Numeric.Natural            (Natural)
import           Prelude
import           Unsafe.Coerce              (unsafeCoerce)


-- | Position of a vertex inside 'dfsVertices'; handed out by 'registerVertex'.
newtype VertexID    = VertexID        Int deriving (Eq, Ord, Show)
-- | Position of a state cell inside 'dfsStates'; handed out by 'newState'.
newtype StateID     = StateID         Int deriving (Eq, Ord, Show)
-- | Input counter advanced by 'incrementEpoch', which 'input' calls once per
-- invocation.
newtype Epoch       = Epoch       Natural deriving (Eq, Ord, Hashable, Show)

-- | 'Timestamp's represent instants in the causal timeline.
--
-- A 'Timestamp' wraps a single 'Epoch'. 'input' builds one per call from the
-- epoch returned by 'incrementEpoch' and passes it to every 'send' and to the
-- closing 'finalize'.
--
-- @since 0.1.0.0
newtype Timestamp   = Timestamp     Epoch deriving (Eq, Ord, Hashable, Show)

-- | An 'Edge' is a typed reference to a computational vertex that
-- takes 'a's as its input.
--
-- The type parameter is phantom: the only thing stored is the 'VertexID'
-- assigned by 'registerVertex'. 'lookupVertex' uses the parameter to decide
-- which 'Vertex' type the type-erased entry is cast back to.
--
-- @since 0.1.0.0
newtype Edge a      = Edge       VertexID

-- | Types whose values have a well-defined "next" value, one step after the
-- current one.
class Incrementable a where
  -- | Advance the value by exactly one step.
  inc :: a -> a

-- | The next vertex identifier is the next 'Int'.
instance Incrementable VertexID where
  inc (VertexID ix) = VertexID (ix + 1)

-- | The next state identifier is the next 'Int'.
instance Incrementable StateID where
  inc (StateID ix) = StateID (ix + 1)

-- | The next epoch is the next 'Natural'.
instance Incrementable Epoch where
  inc (Epoch e) = Epoch (e + 1)


-- | 'ErasedType' hides the type of the value it wraps. Nothing about the
-- hidden type is recorded, so it cannot be inspected or checked later.
data ErasedType = forall i. EraseType i

-- | Cast the wrapped value back to whatever type the caller asks for.
--
-- This is an unchecked cast ('unsafeCoerce'): it is only sound when the
-- requested type is the one the value was wrapped at. Within this module the
-- requested type comes from the phantom parameter of the 'Edge' or 'StateRef'
-- that was returned when the value was stored.
unEraseType :: ErasedType -> a
unEraseType (EraseType payload) = unsafeCoerce payload


-- | The bookkeeping threaded through every 'Dataflow' action: the registered
-- vertices, the mutable state cells, the finalizers, and the counters used to
-- hand out new identifiers.
data DataflowState = DataflowState {
  -- | Registered vertices, type-erased. A 'VertexID' is an index into this
  -- vector.
  dfsVertices       :: Vector ErasedType,
  -- | One mutable cell per 'StateRef'. A 'StateID' is an index into this
  -- vector.
  dfsStates         :: Vector (IORef ErasedType),
  -- | Callbacks run by 'finalize'. 'registerFinalizer' prepends, so the most
  -- recently registered finalizer is first.
  dfsFinalizers     :: [Timestamp -> Dataflow ()],
  -- | The identifier given to the most recently registered vertex.
  dfsLastVertexID   :: VertexID,
  -- | The identifier given to the most recently created state cell.
  dfsLastStateID    :: StateID,
  -- | The epoch most recently returned by 'incrementEpoch'.
  dfsLastInputEpoch :: Epoch
}

-- | `Dataflow` is the type of all dataflow operations.
--
-- It is a newtype over @'StateT' 'DataflowState' 'IO'@: every operation can
-- read and update the 'DataflowState' and perform 'IO'. Only 'Functor',
-- 'Applicative' and 'Monad' are derived, so code in this module wraps and
-- unwraps the constructor explicitly to reach the underlying 'StateT'.
--
-- @since 0.1.0.0
newtype Dataflow a = Dataflow { runDataflow :: StateT DataflowState IO a }
  deriving (Functor, Applicative, Monad)

-- | The empty 'DataflowState': no vertices, no state cells and no finalizers.
--
-- The vertex and state counters start at @-1@ so that the first identifier
-- handed out (after incrementing) is @0@, i.e. the first slot of the
-- corresponding vector. The epoch counter starts at @0@.
initDataflowState :: DataflowState
initDataflowState =
  DataflowState
    { dfsVertices       = empty
    , dfsStates         = empty
    , dfsFinalizers     = []
    , dfsLastVertexID   = VertexID (-1)
    , dfsLastStateID    = StateID (-1)
    , dfsLastInputEpoch = Epoch 0
    }

-- | Return a copy of the current 'DataflowState' whose state cells are
-- independent of the original's.
--
-- Every 'IORef' in 'dfsStates' is replaced by a newly allocated one holding
-- the value the original currently contains; all other fields are carried
-- over unchanged. The state of the running computation is not modified.
duplicateDataflowState :: Dataflow DataflowState
duplicateDataflowState = Dataflow $ do
  current <- get

  -- Read each cell and allocate a fresh one with the same contents, so later
  -- writes through one copy are not visible through the other.
  freshCells <- liftIO $ forM (dfsStates current) (readIORef >=> newIORef)

  return current { dfsStates = freshCells }

-- | Advance the input epoch by one, record it as the latest input epoch, and
-- return the new value.
incrementEpoch :: Dataflow Epoch
incrementEpoch = Dataflow $ do
  next <- gets (inc . dfsLastInputEpoch)
  modify (\st -> st { dfsLastInputEpoch = next })
  return next


-- | A unit of computation that consumes inputs of type @i@.
--
-- * 'StatefulVertex' pairs a 'StateRef' with a callback; 'send' passes that
--   same 'StateRef' to the callback together with the 'Timestamp' and the
--   input. The state type @s@ is existentially quantified, so it is not
--   visible in the type of the vertex.
--
-- * 'StatelessVertex' holds only a callback, which 'send' invokes with the
--   'Timestamp' and the input.
data Vertex i = forall s.
    StatefulVertex
      (StateRef s)
      (StateRef s -> Timestamp -> i -> Dataflow ())
  | StatelessVertex
      (Timestamp -> i -> Dataflow ())

-- | Fetch the vertex an 'Edge' points at.
--
-- The edge's 'VertexID' is used as an index into 'dfsVertices' and the
-- type-erased entry is cast back to @'Vertex' i@, with @i@ taken from the
-- edge's type parameter. Indexing is unchecked ('!'), so the edge must refer
-- to a vertex present in the current 'DataflowState'.
lookupVertex :: Edge i -> Dataflow (Vertex i)
lookupVertex (Edge (VertexID ix)) =
  Dataflow $ gets (unEraseType . (! ix) . dfsVertices)

-- | Add a vertex to the dataflow and return the 'Edge' that addresses it.
--
-- The vertex receives the identifier following the last one handed out and
-- is appended (type-erased) to the end of 'dfsVertices'.
registerVertex :: Vertex i -> Dataflow (Edge i)
registerVertex vertex = Dataflow $ do
  nextID <- gets (dfsLastVertexID >>> inc)

  -- Append the vertex and remember the identifier that was just used.
  modify $ \st -> st
    { dfsVertices     = dfsVertices st `snoc` EraseType vertex
    , dfsLastVertexID = nextID
    }

  return (Edge nextID)

-- | Remember a callback to be run by 'finalize'.
--
-- The callback is put at the front of the finalizer list, so finalizers
-- registered later run before those registered earlier.
registerFinalizer :: (Timestamp -> Dataflow ()) -> Dataflow ()
registerFinalizer finalizer = Dataflow (modify prepend)
  where
    prepend st = st { dfsFinalizers = finalizer : dfsFinalizers st }

-- | Mutable state that holds an `a`.
--
-- The type parameter is phantom: a 'StateRef' stores only the 'StateID' of
-- its cell in 'dfsStates'. 'readState' and 'modifyState' use the parameter to
-- decide which type the type-erased contents are cast back to.
--
-- @since 0.1.0.0
newtype StateRef a = StateRef StateID

-- | Allocate a new mutable cell holding the given value and return the
-- `StateRef` that addresses it.
--
-- The cell receives the identifier following the last one handed out and is
-- appended to the end of 'dfsStates'.
--
-- @since 0.1.0.0
newState :: a -> Dataflow (StateRef a)
newState initial = Dataflow $ do
  nextID <- gets (dfsLastStateID >>> inc)
  cell   <- lift (newIORef (EraseType initial))

  -- Append the cell and remember the identifier that was just used.
  modify $ \st -> st
    { dfsStates      = dfsStates st `snoc` cell
    , dfsLastStateID = nextID
    }

  return (StateRef nextID)

-- | Fetch the 'IORef' backing a 'StateRef'.
--
-- The reference's 'StateID' is used as an index into 'dfsStates'. Indexing is
-- unchecked ('!'), so the reference must refer to a cell present in the
-- current 'DataflowState'.
lookupStateRef :: StateRef s -> Dataflow (IORef ErasedType)
lookupStateRef (StateRef (StateID ix)) =
  Dataflow $ gets ((! ix) . dfsStates)

-- | Return the value currently held by the `StateRef`.
--
-- @since 0.1.0.0
readState :: StateRef a -> Dataflow a
readState sref =
  lookupStateRef sref >>= \cell ->
    Dataflow . lift $ unEraseType <$> readIORef cell

-- | Replace the contents of the `StateRef` with the given value, using an
-- atomic write on the underlying 'IORef'.
--
-- @since 0.1.0.0
writeState :: StateRef a -> a -> Dataflow ()
writeState sref x =
  lookupStateRef sref >>= \cell ->
    Dataflow . lift $ atomicWriteIORef cell (EraseType x)

-- | Update the value stored in `StateRef` by applying a function to it.
--
-- The update is performed atomically on the underlying 'IORef'. The new value
-- is evaluated to weak head normal form as part of the update, so repeated
-- modifications (for example incrementing a counter) do not accumulate a
-- chain of unevaluated thunks inside the reference.
--
-- @since 0.1.0.0
modifyState :: StateRef a -> (a -> a) -> Dataflow ()
modifyState sref op = do
  ioref <- lookupStateRef sref
  Dataflow $ lift $ atomicModifyIORef' ioref update
  where
    -- 'atomicModifyIORef'' forces only the outermost constructor of what it
    -- stores. Here that is the 'EraseType' wrapper, not the payload, so the
    -- payload is forced by hand before the pair is handed back.
    update erased =
      let new = op (unEraseType erased)
      in  new `seq` (EraseType new, ())

{-# INLINABLE input #-}
-- | Feed a batch of items into the dataflow through the given edge.
--
-- A new epoch is allocated for the batch with 'incrementEpoch'. Every item is
-- then sent to @next@ stamped with that epoch and, once all items have been
-- sent, 'finalize' is run for the same 'Timestamp'.
--
-- Any 'Foldable' container is accepted, since the items are only iterated
-- over.
input :: Foldable t => t i -> Edge i -> Dataflow ()
input inputs next = do
  timestamp <- Timestamp <$> incrementEpoch

  mapM_ (send next timestamp) inputs

  finalize timestamp

{-# INLINE send #-}
-- | Deliver a single `input` item to the vertex behind the given edge.
--
-- The vertex is looked up and its callback is run as part of this action. A
-- stateful vertex's callback additionally receives the 'StateRef' the vertex
-- was built with.
--
-- @since 0.1.0.0
send :: Edge input -> Timestamp -> input -> Dataflow ()
send e t i = do
  vertex <- lookupVertex e
  case vertex of
    StatefulVertex sref callback -> callback sref t i
    StatelessVertex callback     -> callback t i

-- | Signal that no more input is coming for the given `Timestamp` by running
-- every registered finalizer with it.
--
-- Finalizers run in the order they appear in the finalizer list.
--
-- @since 0.1.0.0
finalize :: Timestamp -> Dataflow ()
finalize t = Dataflow (gets dfsFinalizers) >>= mapM_ ($ t)