{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ExistentialQuantification #-}
module Control.Concurrent.STM.Twilight
           (STM, TVar, Region, RTwiVar, WTwiVar,
            newTVar, newTVarIO, readTVar, writeTVar,  readTVarR, writeTVarR,
            atomically,retry, Atomic, Twi, Safe,
            gbind, gret,twilight,
            isInconsistent,
            reload,safeTwiIO, rewriteTVar, rereadTVar,unsafeTwiIO,
            newRegion,tryCommit,ignoreAllConflicts,markTVar) where

-- for ghc
import Prelude hiding (mapM_)
import Control.Concurrent.STM.Counter
import Control.Concurrent.MVar
import Data.IORef
import qualified Data.IntMap as Map
import qualified Data.HashTable as Table
import GHC.Int
import Data.Maybe
import Unsafe.Coerce
import Monad hiding (mapM_)
import Data.Foldable (mapM_)

import Control.Concurrent.STM.Monadish
-----------------------

-- The STM interface --
-----------------------

-- | Identifier of a 'TVar'; also the key of read/write-set entries.
type Id = Int
-- NOTE(review): 'Idreg' is not referenced anywhere in this module.
type Idreg = Int

-- | Logical time taken from the global clock (see 'sampleClock').
type Timestamp = Int
-- NOTE(review): 'AccessFlag' is not referenced anywhere in this module.
data AccessFlag = W | WR | R deriving (Show,Eq)
-- | Phase index: inside the twilight zone (result phase of 'twilight').
data Twi = Twi
-- | Phase index: ordinary transactional code (start phase of 'twilight').
data Atomic = Atomic
-- | Phase index: result phase of 'reload', 'tryCommit' and 'ignoreAllConflicts'.
data Safe = Safe
-- | Validation status of a transaction. 'Ok' and 'Bad' carry the actions
-- that release the write-set locks (as returned by 'enterTwilight');
-- 'NotChecked' means no twilight validation has been recorded yet.
data Status = Ok [IO ()] | Bad [IO ()] | NotChecked

-- The STM monad

-- | A transaction: a state-passing action over the transaction-local
-- 'StmState'. The state index @r@ is quantified by 'atomically' (rank-2
-- type). @p@ and @q@ do not occur on the right-hand side: they are phantom
-- phase indices ('Atomic', 'Twi', 'Safe') that only restrict which
-- operations may be sequenced.
data STM r p q a = STM ((StmState r) -> IO (STMResult r a))

-- | Sequencing within a single phase: run the first action and, on
-- 'Success', feed its result and the updated state to the continuation.
-- 'Retry' and 'Error' short-circuit the rest of the transaction.
instance Monad (STM r p q) where
    return x = STM $ \st -> return (Success st x)
    (STM first) >>= k = STM $ \st -> first st >>= continue
      where
        continue (Success st' a) = let STM next = k a in next st'
        continue (Retry st')     = return (Retry st')
        continue (Error msg)     = return (Error msg)


-- parametrized/indexed monad we use to distinguish between atomic and twilight operations

-- | Indexed-monad view of 'STM'. 'gbind' presumably has the shape
-- @m p q a -> (a -> m q s b) -> m p s b@ (see
-- "Control.Concurrent.STM.Monadish"). Because the phase indices of 'STM'
-- are phantom, the bind is written out directly and typechecks for any
-- index combination. No 'unsafeCoerce' of '>>=' is needed.
instance Monadish (STM r) where
    gret = return
    gbind (STM run) k = STM (\state -> do
        outcome <- run state
        case outcome of
          Success state' a -> let STM next = k a in next state'
          Retry state'     -> return (Retry state')
          Error msg        -> return (Error msg))

-- | Transaction-local state threaded through every 'STM' action.
data StmState x = StmState {
                         tstamp :: Timestamp,   -- start time taken in 'initialState'; a TVar stamped later makes reads retry
                         wset :: Heap x,        -- write set: buffered writes, keyed by TVar id
                         rset :: Heap x,        -- read set: values and timestamps observed when first read
                         status :: Status }     -- twilight validation result plus the held write-lock releases

-- | Outcome of running an 'STM' action. On 'Retry', 'atomically'' restarts
-- the transaction. On 'Error', it aborts with 'error'.
-- NOTE(review): 'Error' carries no state, so locks recorded in 'status'
-- cannot be released when it is produced.
data STMResult x a = Retry (StmState x)
                   | Success (StmState x) a
                   | Error String


-- | Run a transaction to completion in 'IO'. The rank-2 argument type keeps
-- the state index @s@ abstract, presumably (as with @runST@) so that
-- transaction-bound values cannot leak into another transaction.
atomically :: (forall s. STM s p q a) -> IO a
atomically transaction = atomically' transaction

-- | Run a transaction, restarting it until it commits.
--
-- Every attempt starts from 'initialState' (fresh start timestamp, empty
-- read/write sets). When an attempt finishes:
--
--   * 'Error': the message is raised with 'error'.
--   * 'Retry': locks recorded in 'status' are released and the transaction
--     starts over.
--   * 'Success' with 'NotChecked' (no twilight block ran): the write set is
--     locked and the read set validated. If validation passes, the writes are
--     published. Otherwise the locks are released and the transaction starts
--     over.
--   * 'Success' with 'Bad': the locks are released and it starts over.
--   * 'Success' with 'Ok': the writes are published.
--
-- NOTE(review): 'Error' carries no 'StmState', so locks taken by a
-- transaction that fails after 'twilight' are not released on that path.
atomically' :: STM s p q a -> IO a
atomically' action = do
  state <- initialState
  atomically'' action state
  where
    -- One attempt. Restarts go back through 'atomically'' (not this helper)
    -- so that every attempt samples a new start timestamp.
    atomically'' :: STM s p q a  -> StmState s -> IO a
    atomically'' action state = do
      stmResult <- runSTM action state
      case stmResult of
        Error s -> do
          error s
        Retry state -> do

          -- Release the write-set locks if a twilight block already took them.
          case status state of
                    NotChecked  -> return ()
                    Bad unlocks -> unlockLocks unlocks
                    Ok unlocks  -> unlockLocks unlocks

          --safePutStrLn "Transaction issued a retry"
          atomically' action
        Success newState res -> do

          case status newState of
                    NotChecked ->  do     -- Transaction has no Twilight block
                              -- Commit without a twilight block: lock the write
                              -- set first, then validate the reads against it.
                              --safePutStrLn $ "Write set: "++ show (wset newState) ++"\nRead set:  "++show (rset newState)
                              --safePutStrLn "entering Twilight"
                              wlocks <- enterTwilight newState
                              --safePutStrLn "continuing Twilight"

                              valid <- validateReadSetStrict $ newState
                              --safePutStrLn "validated Read set"

                              if valid
                               then do
                                        --safePutStrLn "Transaction succeeded"
                                        publishAndUnlock $ wset newState   --- successful commit
                                        --safePutStrLn "released wlocks"
                                        return res
                               else do
                                       unlockLocks $ wlocks
                                       --safePutStrLn "Transaction failed validation\n"
                                       atomically' action       --- rollback

                    Bad unlocks -> do
                              --safePutStrLn "Transaction failed validation"
                              unlockLocks unlocks
                              atomically' action                -- Transaction executed Twilight block and failed -> rollback
                                                                -- TODO Shouldn't this trigger an error?

                    Ok _ ->  do                                   -- Transaction executed Twilight block and succeeded
                               -- 'publishAndUnlock' resets every written lock to
                               -- 'Free', so the stored unlock actions are not run.
                               --safePutStrLn "Transaction succeeded"
                               publishAndUnlock $ wset newState
                               --safePutStrLn "Transaction published"
                               return res

-- | Fresh transaction state: a newly sampled start timestamp, empty read
-- and write sets, and no twilight validation recorded.
initialState :: IO (StmState r)
initialState = fmap fresh sampleClock
  where
    fresh start = StmState { tstamp = start
                           , wset = Map.empty
                           , rset = Map.empty
                           , status = NotChecked }


-- | Apply a transaction to a state, yielding its result in 'IO'.
runSTM :: STM r p q a -> (StmState r) -> IO (STMResult r a)
runSTM (STM run) st = run st

-- | Abandon the current attempt; 'atomically' restarts the transaction.
retry :: STM r p q  a
retry = STM (return . Retry)


-- | Allocate a 'TVar' from within a transaction. It is stamped with the
-- transaction's start time and starts 'Free'.
-- NOTE(review): the allocation runs immediately in IO and is not undone if
-- the transaction retries.
newTVar :: a -> STM r p p (TVar a)
newTVar x = STM $ \st -> do
  valueRef <- newIORef x
  tvarId   <- getUniqueId
  lockRef  <- newIORef (tstamp st, Free)
  return $ Success st (TVar valueRef tvarId lockRef)

-- | Buffer a write in the write set and return a handle for 'rewriteTVar'.
writeTVarR :: TVar a -> a -> STM r Atomic Atomic (WTwiVar a)
writeTVarR tvar@(TVar _ ident _ ) x = STM $ \st ->
  let updated = modifyWSet (wset st) tvar x (tstamp st)
  in return (Success st { wset = updated } (WTwiVar ident))

-- | Buffer a write in the write set. Shared memory is only updated when the
-- transaction commits ('publishAndUnlock').
writeTVar :: TVar a -> a -> STM r Atomic Atomic ()
writeTVar tvar x = STM $ \st ->
  return $ Success st { wset = modifyWSet (wset st) tvar x (tstamp st) } ()


-- | Record @x@ as the buffered value of @tvar@, stamped with @tstamp@.
-- A previous write-set entry for the same TVar is replaced outright.
modifyWSet :: Heap x -> TVar a -> a -> Timestamp -> Heap x
modifyWSet wset tvar@(TVar _ ident _ ) x tstamp =
  Map.insert ident (MkHeapEntry tvar x ident tstamp Nothing) wset

-- | Add @tvar@ to a region (inspected later by 'isInconsistent').
-- 'Region' hides its element type, hence the 'unsafeCoerce'.
markTVar :: TVar a -> Region r -> STM r Atomic Atomic ()
markTVar tvar (Region members) = STM $ \st -> do
  modifyIORef members (unsafeCoerce tvar :)
  return (Success st ())

-- | Read a 'TVar' exactly like 'readTVar', but first record it in region
-- @reg@ (see 'isInconsistent'). Also returns a handle for 'rereadTVar'.
--
-- Delegates to 'markTVar' and 'readTVar' instead of duplicating the
-- read-validation protocol, so both read paths cannot drift apart.
readTVarR :: TVar a -> Region r -> STM r Atomic Atomic (a, RTwiVar a)
readTVarR tvar@(TVar _ ident _) reg = do
  markTVar tvar reg
  value <- readTVar tvar
  return (value, RTwiVar ident)


-- | Read a 'TVar' inside a transaction.
--
-- A value already in the write set (own write) or the read set (earlier
-- read) is returned from there. Otherwise the shared value is read between
-- two samples of the TVar's timestamp/lock. The transaction retries if
-- either sample shows the TVar 'Lcked' or stamped after the transaction
-- started. On success, the value and the second timestamp are added to the
-- read set.
readTVar :: TVar a -> STM r Atomic Atomic a
readTVar tvar@(TVar content ident lockRef) = STM fromLocalOrShared
  where
    fromLocalOrShared st =
      case Map.lookup ident (wset st) of
        Just (MkHeapEntry _ v _ _ _) -> return (Success st (unsafeCoerce v))
        Nothing ->
          case Map.lookup ident (rset st) of
            Just (MkHeapEntry _ v _ _ _) -> return (Success st (unsafeCoerce v))
            Nothing -> fromShared st

    -- Unusable: being published right now, or committed after we started.
    stale st (stamp, lock) = lock == Lcked || tstamp st < stamp

    fromShared st = do
      before <- readIORef lockRef
      if stale st before
        then return (Retry st)
        else do
          v <- readIORef content
          after@(stampAfter, _) <- readIORef lockRef
          if stale st after
            then return (Retry st)
            else do
              let entry = MkHeapEntry tvar v ident stampAfter Nothing
              return (Success st { rset = Map.insert ident entry (rset st) } v)

-- | Whether a read handle refers to the given 'TVar' (same identifier).
-- NOTE(review): not referenced elsewhere in this module and not exported.
is :: RTwiVar a -> TVar a -> Bool
is (RTwiVar handleId) (TVar _ tvarId _ ) = tvarId == handleId


-- Twilight extra API

-- | Enter the twilight zone. Reserve the write-set locks, then validate the
-- read set. Record the outcome in 'status': 'Ok' or 'Bad', both holding the
-- lock releases. Returns whether validation succeeded.
twilight :: STM r Atomic Twi Bool
twilight = STM $ \st -> do
  locks <- enterTwilight st
  consistent <- validateReadSetStrict st
  let verdict = if consistent then Ok locks else Bad locks
  return (Success st { status = verdict } consistent)

-- | Replace the buffered value of a TVar written with 'writeTVarR', keeping
-- the entry's timestamps. Produces 'Error' if the TVar is not in the write
-- set.
rewriteTVar :: WTwiVar a -> a -> STM r p p ()
rewriteTVar (WTwiVar ident) v = STM $ \stmState ->
  return $ case Map.lookup ident (wset stmState) of
    Just (MkHeapEntry tv _ entryId t t') ->
      let replaced = MkHeapEntry (unsafeCoerce tv) v entryId t t'
      in Success stmState { wset = Map.insert entryId replaced (wset stmState) } ()
    Nothing ->
      Error "You tried to rewrite a TVar which has not been written in the atomic before"


-- | Return the read-set value of a TVar read via 'readTVarR'. After
-- 'reload', this is the reloaded value.
--
-- 'readTVarR' also hands out an 'RTwiVar' when the value came from the
-- write set, and then no read-set entry exists. In that case the write-set
-- value is returned instead of failing. 'Error' is produced only if the
-- TVar is in neither set.
rereadTVar :: RTwiVar a -> STM r p p a
rereadTVar (RTwiVar ident) = STM (\stmState ->
          return $ case Map.lookup ident (rset stmState) of
             Just (MkHeapEntry _ x _ _ _ ) -> Success stmState (unsafeCoerce x)
             Nothing -> case Map.lookup ident (wset stmState) of
               Just (MkHeapEntry _ x _ _ _ ) -> Success stmState (unsafeCoerce x)
               Nothing -> Error "You tried to read a TVar which has not been read in the atomic before"
          )


-- | Re-read every read-set entry from shared memory. Spins until one pass
-- completes without meeting a locked or concurrently changing TVar.
-- Returns whether every entry still had its original timestamp, together
-- with the refreshed read set.
reloadRset :: Heap r -> IO (Bool, Heap r)
reloadRset rset = do
     -- No clock sample here: the pass needs no timestamp. 'sampleClock'
     -- increments the global counter on every call, which would add
     -- contention with committing transactions on each spin.
     result <- rereadCons rset
     case result of
          (_,False,_) -> reloadRset rset
          (cons,True, rset') -> return (cons,rset')

-- | One reload pass over a read set, returning
-- @(unchanged, complete, refreshedReadSet)@.
--
-- * @complete@ is 'False' if some TVar was 'Lcked', or changed while being
--   re-read. The refreshed set is then meaningless and the caller retries.
-- * @unchanged@ is 'True' when every TVar still carries the timestamp
--   recorded at first read.
--
-- Refreshed entries keep the original read timestamp and store the
-- timestamp found now in the 'Maybe' field. As in 'readTVar', the lock is
-- sampled before and after reading the value, so a value published in
-- between is never paired with a stale timestamp.
rereadCons :: Heap r -> IO (Bool, Bool, Heap r)
rereadCons rset =
      foldM (\(cons,valid,rset') (_,MkHeapEntry tvar@(TVar mv _ tlck) val ident t _ ) -> do
                               let keepOld = Map.insert ident (MkHeapEntry tvar val ident t Nothing ) rset'
                               (globTime,l) <- readIORef tlck
                               if l == Lcked
                                  then return (False,False,keepOld)
                                  else do
                                     newval <- readIORef mv
                                     (globTime',l') <- readIORef tlck
                                     if l' == Lcked || globTime' /= globTime
                                        then return (False,False,keepOld)
                                        else return (cons && (globTime == t), valid, Map.insert ident (MkHeapEntry tvar newval ident t (Just globTime) ) rset')
          )
       (True,True,Map.empty)
       (Map.toList rset)



-- | If twilight validation failed ('Bad'), re-read all read TVars (see
-- 'reloadRset') and mark the transaction 'Ok', keeping the held write-set
-- locks. With 'Ok' there is nothing to reload. 'NotChecked' (no twilight
-- validation recorded) is also left unchanged; previously it hit a
-- non-exhaustive pattern. In that case 'atomically'' locks and validates
-- the transaction at commit time.
reload :: STM r Twi Safe ()
reload = STM (\stmState -> do
          case status stmState of
               Bad l -> do
                        (_,newRset) <- reloadRset $ rset stmState
                        return $ Success stmState{rset=newRset,status=Ok l} ()
               Ok  _ -> return $ Success stmState ()
               NotChecked -> return $ Success stmState () )

-- | Treat the transaction as consistent regardless of read conflicts.
--
-- If no twilight validation has happened yet ('NotChecked'), the write-set
-- locks are acquired first and stored in the 'Ok' status. Without them, the
-- 'Ok' branch of 'atomically'' would publish with no reservation held. It
-- could then overwrite the lock of a concurrent transaction that has
-- reserved the same TVar ('publishAndUnlock' sets locks unconditionally).
-- Read conflicts are still ignored, since the read set is never validated.
ignoreAllConflicts :: STM r a Safe ()
ignoreAllConflicts = STM (\stmState -> do
          case status stmState of
               Ok  _ -> return $ Success stmState ()
               NotChecked -> do
                        wlocks <- enterTwilight stmState
                        return $ Success stmState{status=Ok wlocks} ()
               Bad l -> return $ Success stmState{status=Ok l} () )


-- | Run an IO action and pair its result with the unchanged state.
-- NOTE(review): not referenced elsewhere in this module and not exported.
exposed :: IO a -> StmState r -> IO (StmState r,a)
exposed action stmState = fmap ((,) stmState) action



-- | Leave the twilight zone, continuing only if validation succeeded. 'Ok'
-- continues. 'Bad' produces 'Retry', on which 'atomically'' releases the
-- held write locks and restarts.
--
-- 'NotChecked' (no twilight validation recorded) continues as well;
-- previously it hit a non-exhaustive pattern. 'atomically'' then locks and
-- validates the transaction itself at commit time.
tryCommit :: STM r Twi Safe ()
tryCommit = STM (\stmState -> do
              case status stmState of
               Ok  _ -> return $ Success stmState ()
               Bad _ -> return $ Retry stmState
               NotChecked -> return $ Success stmState () )

-- | Run an IO action in the 'Safe' phase, leaving the state unchanged.
safeTwiIO :: IO a -> STM r Safe Safe a
safeTwiIO action = STM (\stmState -> fmap (Success stmState) action)

-- | Run an IO action in any phase, leaving the state unchanged. The action
-- runs immediately and is not undone if the transaction later retries.
unsafeTwiIO :: IO a -> STM r p p a
unsafeTwiIO io = STM $ \st -> io >>= return . Success st

-- | Whether any TVar marked in the region has changed since this
-- transaction read it. A TVar counts as changed when its current timestamp
-- differs from the one in the read set and it is not currently 'Lcked'.
-- Stops at the first changed TVar.
--
-- Region members without a read-set entry are skipped instead of failing a
-- partial lookup; they have no recorded read timestamp. This happens for
-- TVars added via 'markTVar', or found in the write set by 'readTVarR'.
--
-- NOTE(review): an 'Lcked' TVar counts as consistent here, whereas
-- 'validateReadSetStrict' treats 'Lcked' as invalid. Confirm intent.
isInconsistent :: Region r -> STM r p p Bool
isInconsistent (Region r) = STM (\stmState -> do
                reglist <- readIORef r
                incons <- anyChanged (rset stmState) reglist
                return $ Success stmState incons)
  where
    anyChanged :: Heap x -> [TVar b] -> IO Bool
    anyChanged _ [] = return False
    anyChanged readSet (TVar _ ident tlck : rest) =
      case Map.lookup ident readSet of
        Nothing -> anyChanged readSet rest
        Just (MkHeapEntry _ _ _ t _) -> do
          (currt, l) <- readIORef tlck
          if currt == t || l == Lcked
            then anyChanged readSet rest
            else return True


-- Auxiliary functions: not exported code

{- Entering the twilight zone: reserve every lock of the transaction's
write set, spinning until all are held (see 'lockWriteSet'). Returns the
actions that release them again. -}
enterTwilight :: StmState r -> IO [IO()]
enterTwilight = lockWriteSet . wset


{- Repeat 'tryLockWriteset' until it succeeds. This busy-waits: there is no
backoff or blocking. Returns the unlock actions. -}
lockWriteSet :: Heap r -> IO [IO ()]
lockWriteSet wset = tryLockWriteset wset >>= maybe (lockWriteSet wset) return


{- Try once to reserve every lock of the write set, in ascending TVar id
order. On success, returns the release actions, most recently acquired
first. If some lock is not 'Free', releases the locks acquired so far and
returns 'Nothing'. -}
tryLockWriteset :: Heap r -> IO (Maybe [IO ()])
tryLockWriteset wset = acquire [] (Map.elems wset)
  where
    acquire :: [IO ()] -> [HeapEntry x] -> IO (Maybe [IO ()])
    acquire held [] = return (Just held)
    acquire held (MkHeapEntry (TVar _ _ tlck) _ _ _ _ : rest) = do
      reserved <- atomicModifyIORef tlck reserve
      if reserved
        then acquire (release tlck : held) rest
        else do
          sequence_ held
          return Nothing

    -- Free -> Reserved and report success; any other state is left as is.
    reserve (t, Free) = ((t, Reserved), True)
    reserve other     = (other, False)

    release lockRef = atomicModifyIORef lockRef (\(t, _) -> ((t, Free), ()))


{- Validate the read set; called after the write set has been locked.
An entry is valid when its TVar is not 'Lcked' and its timestamp still
equals the one recorded at read time. A 'Reserved' TVar is only acceptable
when it is in this transaction's own write set, i.e. reserved by us.
Stops at the first invalid entry.

The lock of a TVar reserved by us cannot change while we hold it, so it
is read only once (the previous version re-read it redundantly). -}
validateReadSetStrict :: StmState r -> IO Bool
validateReadSetStrict state = allValid (Map.elems (rset state))
  where
    allValid :: [HeapEntry x] -> IO Bool
    allValid [] = return True
    allValid (MkHeapEntry (TVar _ _ tlck) _ ident t _ : rest) = do
      (globTime, l) <- readIORef tlck
      let ok = case l of
                 Lcked    -> False
                 Reserved -> Map.member ident (wset state) && globTime == t
                 Free     -> globTime == t
      if ok then allValid rest else return False

{- Commit the write set to shared memory in three steps:
1. Switch every entry's lock to 'Lcked'.
2. Sample a single commit timestamp.
3. Write each buffered value, then reset its lock to 'Free' stamped with the
   commit timestamp. This also releases the reservation. -}
publishAndUnlock :: Heap r ->  IO ()
publishAndUnlock wset = do
    mapM_ markLocked wset
    commitTime <- sampleClock
    mapM_ (publish commitTime) wset
  where
    markLocked :: HeapEntry x -> IO ()
    markLocked (MkHeapEntry (TVar _ _ tlck) _ _ _ _) =
      atomicModifyIORef tlck (\(t, _) -> ((t, Lcked), ()))

    publish :: Timestamp -> HeapEntry x -> IO ()
    publish commitTime (MkHeapEntry (TVar valueRef _ tlck) val _ _ _) = do
      writeIORef valueRef val
      writeIORef tlck (commitTime, Free)

{- Release a set of locks by running each unlock action, in list order. -}
unlockLocks :: [IO ()] -> IO ()
unlockLocks = sequence_



-- Transactional variables

-- | Lock state of a 'TVar'. 'Reserved' is set by 'tryLockWriteset' for a
-- committing transaction's write set. 'Lcked' is set by 'publishAndUnlock'
-- while new values are written. It is reset to 'Free' on release or after
-- publishing.
data LockState = Free | Reserved | Lcked deriving (Eq,Show)
-- | A transactional variable; equality compares its references and id.
data TVar a = TVar (IORef a)          -- value
                   Id                 -- unique ident, for ordering when locking
                   (IORef (Timestamp,LockState))  -- timestamp (creation or last commit) and lock
               deriving Eq


-- | Shows only the identifier in parentheses, e.g. @(3)@.
instance Show (TVar a) where
  show (TVar _ ident _ ) = concat ["(", show ident, ")"]

-- | Allocate a 'TVar' outside any transaction. It is stamped @-1@,
-- presumably below every timestamp produced by 'sampleClock', so running
-- transactions can read it without retrying. Confirm against the counter's
-- start value.
newTVarIO :: a -> IO (TVar a)
newTVarIO x = do
  valueRef <- newIORef x
  tvarId   <- getUniqueId
  lockRef  <- newIORef (-1, Free)
  return $ TVar valueRef tvarId lockRef

-- | Handle to a TVar read with 'readTVarR', for 'rereadTVar'. Only the id is stored.
data RTwiVar a = RTwiVar !Id deriving (Show, Eq)
-- | Handle to a TVar written with 'writeTVarR', for 'rewriteTVar'. Only the id is stored.
data WTwiVar a = WTwiVar !Id deriving (Show, Eq)

-- Local heap, comprises read and write set

-- | One read-set or write-set entry. The TVar's element type is hidden
-- (existential), which is why the readers recover it with 'unsafeCoerce'.
data HeapEntry r = forall a. MkHeapEntry
                                             (TVar a)   -- corresponding TVar
                                             a          -- value read (read set) or to be written (write set)
                                             Id         -- unique identifier for TVar

                                             Timestamp            -- read set: timestamp seen at read time; write set: transaction start time
                                             (Maybe Timestamp)    -- timestamp found when reloading (set by 'rereadCons'), otherwise Nothing


--hideType :: TVar a -> HeapEntry
--hideType (TVar (MVar x) ident t) = MkHeapEntry x ident t R

--instance Show (HeapEntry r) where
--  show (MkHeapEntry tvar a ident t ) = "(" ++show a ++", "++ show ident ++ ")"

-- | A read or write set, keyed by TVar 'Id' (so traversal is in id order).
type Heap r = Map.IntMap (HeapEntry r)

-- Regions for grouping TVars
-- | A mutable list of TVars, created by 'newRegion', filled by 'markTVar'
-- and 'readTVarR', and inspected by 'isInconsistent'. The element type is
-- hidden, so members are stored via 'unsafeCoerce'.
data Region r = forall a. Region (IORef [TVar a])

-- | Create an empty region.
newRegion :: STM r Atomic Atomic (Region r)
newRegion = STM (\st -> fmap (Success st . Region) (newIORef []))




-- Auxiliaries: Counters

-- for the timestamps
-- NOTE(review): 'uniqueTVarId' is defined as 'getCounter' as well. If
-- 'getCounter' is one shared top-level counter, timestamps and TVar ids draw
-- from the same sequence. Confirm in "Control.Concurrent.STM.Counter".
globalClock :: Counter
globalClock = getCounter

-- | Sample the global clock via 'getAndIncr'. Judging by the name, each call
-- returns the current value and advances the counter. NOTE(review): confirm
-- this and its atomicity in "Control.Concurrent.STM.Counter".
sampleClock :: IO Int
sampleClock = getAndIncr globalClock

--semaphore :: Semaphore
--semaphore = getSemaphore

-- for unique Ids
-- Counter backing TVar identifiers (see the NOTE on 'globalClock').
uniqueTVarId :: Counter
uniqueTVarId = getCounter

-- | Draw a TVar identifier via 'getAndIncr'; presumably unique per call.
-- Confirm in "Control.Concurrent.STM.Counter".
getUniqueId :: IO Int
getUniqueId = getAndIncr uniqueTVarId

--- Tests
-- | Single-threaded smoke test of the API. Each step prints the value it
-- obtained next to the expected one; nothing is asserted.
-- NOTE(review): this module is not @Main@, so 'main' only runs from GHCi or
-- with @-main-is@.
main :: IO ()
main = do

  atomically $ gret ()
  putStrLn "Empty transaction works\n"
  -- test read
  let val = 89

  y <- atomically $ newTVar val `gbind` \a -> gret a
  putStrLn "Creating new TVars works\n"

  x <- atomically $ readTVar y  `gbind` \u-> gret u
  putStrLn $ show x ++ " should be " ++ show val
  --- test write

  let new_val = 25
  z <- atomically $ writeTVarR y new_val `gbind` \_ -> gret "done"
  putStrLn z

  -- read-your-own-write: readTVarR must return the buffered write
  let new_val = 34
  z <- atomically $ newRegion `gbind` \r -> writeTVarR y new_val `gbind` \_ -> readTVarR y r `gbind` \(u,_) -> gret u
  putStrLn $ show z ++ " should be " ++ show new_val

  --- test write
  -- same as above, but the transaction also passes through 'twilight'
  let new_val = 78
  z <- atomically $ newRegion `gbind` \r -> writeTVarR y new_val `gbind` \_ -> readTVarR y r `gbind` \(u,_) -> twilight `gbind` \_ -> gret u
  putStrLn $ show z ++ " should be " ++ show new_val
{--
  --- test rewrite
  let v = 7
  w <- atomically $ newRegion `gbind` \r -> writeTVar y new_val `gbind` \y' -> readTVar y r `gbind` \u -> twilight `gbind` \_ -> rewriteTVar y' v `gbind` \_ -> gret v
  putStrLn $ show w ++ " should be " ++ show v

  --}

 --  reg <- atomically $ newRegion

 -- a <- atomically $ readTVar y reg

 -- b <- atomically $ readTVar y reg


  return ()