{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS -fglasgow-exts #-}
module STM
           (STM, TVar, newTVar, newTVarIO, readTVar, writeTVar,
            atomically,retry, Atomic, Twi, Safe, gbind, gret,twilight,getInconsistencies, isInconsistent,
            reload,safeTwiIO, rewriteTVar, rereadTVar,unsafeTwiIO,newRegion,tryCommit) where


-- for ghc
import Prelude hiding (mapM_)
import Counter
import Semaphore
import STMHelpers
import Control.Concurrent.MVar
import Data.IORef
import qualified Data.IntMap as Map 
import Data.Maybe
import Unsafe.Coerce
import Monad hiding (mapM_)
import Data.Foldable (mapM_)

import Control.VarStateM hiding (unlock)
-----------------------
-- The STM interface --
-----------------------

-- Identifier of a TVar or of a region. For TVars it is also the key of the
-- read/write-set maps, so ascending Id order is the lock-acquisition order.
type Id = Int
-- Logical time taken from the global clock ('sampleClock').
type Timestamp = Int
-- Access-kind flags. NOTE(review): not referenced anywhere else in this module.
data AccessFlag = W | WR | R deriving (Show,Eq)
-- Phase indices of the indexed STM monad: 'Atomic' for the ordinary
-- transactional phase, 'Twi' after 'twilight', 'Safe' after 'reload' or
-- 'tryCommit'. Only the types are used; the constructors never appear in code.
data Twi = Twi 
data Atomic = Atomic 
data Safe = Safe
-- Outcome of the twilight check. 'Ok' / 'Bad' mean the read set did / did not
-- validate; both carry the actions that release the write locks the
-- transaction holds at that point. 'NotChecked' means no twilight check ran.
data Status = Ok [IO ()] | Bad [IO ()] | NotChecked
-- Consistency marker stored in a 'Region'. NOTE(review): only 'Cons' is ever
-- constructed in this module and the field is never inspected.
data Checked = Incons | Cons 

-- The STM monad 
-- A transaction is a state-passing IO function over the transaction-local
-- 'StmState'. The index r is universally quantified by the rank-2 type of
-- 'atomically', so values indexed by it (regions, heaps) cannot leave the
-- transaction. p and q are phantom phase indices (entry / exit phase);
-- presumably they are chained by 'gbind' from the Monadish class in
-- Control.VarStateM - confirm against that module.
data STM r p q a = STM ((StmState r) -> IO (STMResult r a))

-- Sequencing threads the transaction state from one action to the next;
-- 'Retry' and 'Error' short-circuit the remainder of the transaction.
instance Monad (STM r p q) where
    return x = STM (\st -> return (Success st x))
    STM runFirst >>= next = STM step
      where
        step st = runFirst st >>= continue
        continue (Success st' a) = case next a of STM runNext -> runNext st'
        continue (Retry st')     = return (Retry st')
        continue (Error msg)     = return (Error msg)


-- Parameterised ("indexed") monad interface, used to distinguish atomic and
-- twilight operations at the type level.
--
-- The phase indices p and q of STM are phantom (they do not occur in the
-- wrapped function), so gbind is implemented directly, exactly like the Monad
-- bind, and type-checks at every index combination without 'unsafeCoerce'.
instance Monadish (STM r) where
    gret = return
    gbind (STM tr1) k = STM (\state -> do
                          stmRes <- tr1 state
                          case stmRes of
                            Success newState a ->
                               let (STM tr2) = k a in
                                   tr2 newState
                            Retry newState -> return (Retry newState)
                            Error msg -> return (Error msg)
                        )

-- Transaction-local state threaded through every STM action.
data StmState x = StmState { 
                         tstamp :: Timestamp,   -- start time, sampled from the global clock in 'initialState'
                         wset :: Heap x,        -- buffered writes keyed by TVar Id; published on commit
                         rset :: Heap x,        -- values read from shared TVars, keyed by TVar Id
                         regCounter :: Id,      -- next region Id handed out by 'newRegion'
                         status :: Status }     -- twilight outcome plus releasers of the held write locks

-- Outcome of running an STM action against a state:
--   Retry   - restart the transaction; the state is kept so that 'atomically'
--             can release the locks recorded in its status,
--   Success - continue with the updated state and a result,
--   Error   - abort with a message, which 'atomically' raises via 'error'.
data STMResult x a = Retry (StmState x)
                   | Success (StmState x) a  
                   | Error String


-- Run a transaction to completion, restarting it as often as needed.
-- The rank-2 argument type quantifies the transaction index, so values
-- indexed by it cannot be returned. The definition is kept eta-expanded: the
-- polymorphic action is passed explicitly and instantiated for atomically'.
atomically :: (forall s. STM s p q a) -> IO a
atomically action = atomically' action

-- Worker for 'atomically'. Runs one attempt from a fresh state, then:
--   * Error msg: aborts via 'error'.
--   * Retry: releases the write locks recorded in the status (if any) and
--     restarts from scratch, i.e. with a new start timestamp.
--   * Success, no twilight block ran: locks the write set, validates the read
--     set, then either publishes (commit) or unlocks and restarts (rollback).
--   * Success after a failed twilight check ('Bad'): unlocks and restarts.
--   * Success after a passed twilight check ('Ok'): publishes the write set.
atomically' :: STM s p q a -> IO a
atomically' action = initialState >>= attempt
  where
    -- Every restart goes back through atomically' to get a fresh state.
    restart = atomically' action

    attempt st = do
      outcome <- runSTM action st
      case outcome of
        Error msg       -> error msg
        Retry st'       -> releaseHeld (status st') >> restart
        Success st' res -> finish st' res

    finish st' res =
      case status st' of
        NotChecked -> do
          -- No twilight block: lock and validate now, at commit time.
          wlocks <- enterTwilight st'
          (valid, _) <- validateReadSet st'
          if valid
            then publishAndUnlock (wset st') >> return res   -- successful commit
            else unlockLocks wlocks >> restart               -- rollback
        -- TODO (carried over): should a failed twilight check at commit time
        -- trigger an error instead of a silent rollback?
        Bad unlocks -> unlockLocks unlocks >> restart
        Ok _        -> publishAndUnlock (wset st') >> return res

    releaseHeld NotChecked    = return ()
    releaseHeld (Bad unlocks) = unlockLocks unlocks
    releaseHeld (Ok unlocks)  = unlockLocks unlocks
          
-- Fresh transaction state: the start timestamp is sampled from the global
-- clock, both heaps are empty, region Ids start at 0 and no twilight check
-- has happened yet.
initialState :: IO (StmState r)
initialState = sampleClock >>= \now ->
  return StmState { tstamp     = now
                  , wset       = Map.empty
                  , rset       = Map.empty
                  , regCounter = 0
                  , status     = NotChecked }


-- Unwrap an STM action and run it against the given transaction state.
runSTM :: STM r p q a -> (StmState r) -> IO (STMResult r a)
runSTM (STM body) st = body st

-- Abandon the current attempt; 'atomically' releases any held locks and
-- restarts the transaction.
retry :: STM r p q  a
retry = STM (return . Retry)


-- Create a TVar from inside a transaction. Allocation happens immediately
-- (it is not buffered in the write set, so a restart does not undo it). The
-- timestamp starts at the transaction's start time and the lock starts
-- released (full MVar).
newTVar :: a -> STM r p p (TVar a)
newTVar x = STM $ \stmState -> do
    valueRef <- newIORef x
    uid      <- getUniqueId
    stampRef <- newIORef (tstamp stmState)
    lock     <- newMVar ()
    return (Success stmState (TVar valueRef uid stampRef lock))

-- Buffer a write of x to tvar in the transaction-local write set. Nothing is
-- visible to other transactions until commit ('publishAndUnlock').
-- A second write to the same TVar replaces the earlier entry wholesale. The
-- region and timestamp fields of write-set entries are not read anywhere in
-- this module, so no metadata needs to be carried over from the old entry.
-- Returns a handle that 'rewriteTVar' accepts in later phases.
writeTVar :: Show a => TVar a -> a -> STM r Atomic Atomic (WTwiVar a)
writeTVar tvar@(TVar _ ident _ _ ) x = STM (\stmState -> do
   let entry = MkHeapEntry tvar x ident Nothing (tstamp stmState) Nothing
       wset' = Map.insert ident entry (wset stmState)
   return $ Success stmState{wset = wset'} (WTwiVar ident))

-- Transactionally read a TVar, recording it in the read set under region reg.
-- Lookup order: the write set (the transaction's own pending write wins),
-- then the read set; only on a miss is the shared TVar consulted. The shared
-- read happens under the TVar's lock, taken without blocking. If the lock is
-- busy, or the TVar was committed after this transaction started, the attempt
-- is retried.
-- NOTE(review): a write-set hit is not added to the read set, so
-- 'rereadTVar' on the returned handle reports an error in that case.
readTVar :: Show a => TVar a -> Region r a -> STM r Atomic Atomic (a, RTwiVar a)
readTVar tvar@(TVar valueRef ident stampRef lock) reg = STM (\stmState ->
   case Map.lookup ident (wset stmState) `mplus` Map.lookup ident (rset stmState) of
     -- Local hit. The entry's value type is existential; the Id identifies this
     -- very TVar, so coercing the value back to type a is sound.
     Just (MkHeapEntry _ cached _ _ _ _) ->
       return $ Success stmState (unsafeCoerce cached, RTwiVar ident)
     Nothing -> do
       acquired <- tryTakeMVar lock
       case acquired of
         Nothing -> return $ Retry stmState          -- lock busy
         Just _  -> do
           stamp <- readIORef stampRef
           value <- readIORef valueRef
           putMVar lock ()
           if stamp > tstamp stmState
             then return $ Retry stmState            -- committed after our start
             else do
               let entry = MkHeapEntry tvar value ident (Just reg) stamp Nothing
               return $ Success stmState{rset = Map.insert ident entry (rset stmState)} (value, RTwiVar ident)
   )

-- Does the read handle refer to this TVar? Compares Ids only.
-- NOTE(review): neither exported nor used in this module.
is :: RTwiVar a -> TVar a -> Bool 
is (RTwiVar handleId) (TVar _ varId _ _) = handleId == varId


-- Twilight extra API

-- Enter the twilight zone: acquire all write-set locks ('enterTwilight') and
-- validate the read set. The verdict is stored in the status ('Ok' / 'Bad',
-- each carrying the lock releasers). The read set is replaced by the validated
-- copy, which records the timestamps seen now. The validity flag is returned.
twilight :: STM r Atomic Twi Bool
twilight = STM $ \stmState -> do
  wlocks <- enterTwilight stmState
  (valid, checkedReads) <- validateReadSet stmState
  let verdict | valid     = Ok wlocks
              | otherwise = Bad wlocks
  return (Success stmState { status = verdict, rset = checkedReads } valid)

-- Replace the value this transaction will write to a TVar it already wrote in
-- the atomic phase (handle from 'writeTVar'). The entry's other fields are
-- kept. Yields an Error result if that TVar is not in the write set.
rewriteTVar :: Show a => WTwiVar a -> a -> STM r p p ()
rewriteTVar (WTwiVar ident) v = STM (\stmState ->
          return $ case Map.lookup ident (wset stmState) of
            Nothing ->
                Error "You tried to rewrite a TVar which has not been written in the atomic before"
            Just (MkHeapEntry oldVar _ entryId oldReg t t') ->
                -- Same Id, hence the same TVar: coercing its type to a is sound.
                let updated = MkHeapEntry (unsafeCoerce oldVar) v entryId (unsafeCoerce oldReg) t t'
                in Success stmState{wset = Map.insert entryId updated (wset stmState)} ()
          )


-- Return the transaction-local value of a TVar previously read with
-- 'readTVar', as currently held in the read set ('reload' may have refreshed
-- it). Yields an Error result if that TVar is not in the read set.
rereadTVar :: RTwiVar a -> STM r p p a
rereadTVar (RTwiVar ident) = STM (\stmState ->
          return $ maybe (Error "You tried to read a TVar which has not been read in the atomic before")
                         (\(MkHeapEntry _ x _ _ _ _) -> Success stmState (unsafeCoerce x))
                         (Map.lookup ident (rset stmState))
          )

-- Lock every TVar of the write set and of the read set, blocking on each lock
-- ('takeMVar'), in ascending Id order over the union of both sets.
-- Returns (write-lock releasers, read-lock releasers). A TVar present in both
-- sets appears once in the union, so it is locked once and classified as a
-- write lock. Each list is in reverse acquisition order.
lockReadAndWriteSet :: StmState r -> IO ([IO()], [IO ()])
lockReadAndWriteSet state = do
          let writeset = wset state
          let readset = rset state
          -- NOTE(review): the fixed ascending order presumably prevents
          -- lock-order deadlocks with other transactions blocking here - confirm.
          foldM (\(wlocks,rlocks) (ident,MkHeapEntry (TVar mv _ mt lck) _ _ _ _ _) -> do
                     takeMVar lck 
                     if Map.member ident writeset 
                        then return (putMVar lck ():wlocks,rlocks)
                        else return (wlocks,putMVar lck ():rlocks) 
                 ) 
                ([],[])
                (Map.toAscList $ Map.union writeset readset)

-- Refresh every read-set entry with its TVar's current value and timestamp.
-- The new timestamp is stored as the "seen" timestamp; the original read
-- timestamp is kept, so getInconsistencies can detect the change. Afterwards
-- the given read locks are released. In this module it is called right after
-- 'lockReadAndWriteSet', so the TVars are locked while being re-read.
rereadAndUnlock :: [IO ()] -> Heap r -> IO (Heap r)
rereadAndUnlock rlocks rset = do
    refreshed <- mapM refresh (Map.toList rset)
    unlockLocks rlocks       -- careful: only read locks may be dropped here, never write locks
    return (Map.fromList refreshed)
  where
    refresh (key, MkHeapEntry tvar@(TVar valueRef varId stampRef _) _ _ reg readTime _) = do
      current  <- readIORef valueRef
      seenTime <- readIORef stampRef
      return (key, MkHeapEntry tvar current varId reg readTime (Just seenTime))

{- Run an IO action while the transaction's write locks are released, then
re-establish a state in which the transaction may continue.

Steps, as implemented below:
  1. getPos on the global semaphore; putPos is called on both exit paths.
     NOTE(review): presumably this excludes concurrent enterTwilight calls,
     which use getNeg / putNeg - confirm in the Semaphore module.
  2. Release the write locks recorded in the status, then run the action.
  3. Re-validate the read set (without holding the write locks).
     * Consistent: re-acquire the write locks by spinning (lockWriteSet); the
       read set is left unchanged.
     * Inconsistent: block on the locks of the whole write and read set
       (ascending Id order), re-read every read-set TVar, then drop the read
       locks and keep the write locks.
  In both cases the returned status is Ok with the write-lock releasers.

NOTE(review): the status case has no NotChecked alternative; the only caller
in this module (reload) passes a state whose status is Bad. -}
exposed :: IO a -> StmState r -> IO (StmState r,a) 
exposed action stmState = do
    --safePutStrLn $ "starting exposed"    
    
    getPos semaphore
    --safePutStrLn $ "finished exposed"    
                
    -- Drop the write locks while the (possibly long-running) action runs.
    case status stmState of
          Ok unlocks -> unlockLocks unlocks
          Bad unlocks -> unlockLocks unlocks
    result <- action

    (consistent,_) <- validateReadSet $ stmState
    if consistent
       then do 
          -- Read set still valid: only the write locks need to come back.
          writelocks <- lockWriteSet $ wset stmState
          --safePutStrLn $ show consistent ++ "Blub!"    
    
          putPos semaphore          
              
          return (stmState{status = Ok writelocks},result)

       else do
          -- Read set stale: lock everything, refresh the reads, keep only
          -- the write locks.
          (writelocks,readlocks) <- lockReadAndWriteSet stmState
          putPos semaphore          
          rset' <- rereadAndUnlock readlocks (rset stmState)
          --unlockLocks readlocks
          return (stmState{rset = rset', status = Ok writelocks}, result)
    

-- Make a transaction that went through the twilight check safe to continue.
-- If the check passed ('Ok') nothing changes. If it failed ('Bad'), 'exposed'
-- releases and re-acquires the write locks and refreshes the read set when it
-- is still inconsistent; the status becomes 'Ok'.
-- Reaching this without a twilight check ('NotChecked' - possible because
-- 'atomically' does not fix the starting phase) yields an Error result
-- instead of a non-exhaustive pattern-match failure.
reload :: STM r Twi Safe ()
reload = STM (\stmState ->
          case status stmState of
               Ok  _ -> return $ Success stmState ()
               Bad _ -> do
                           (newState,_) <- exposed (return ()) stmState
                           return $ Success newState ()
               NotChecked ->
                           return $ Error "You tried to reload a transaction which has not entered the twilight zone")

-- Commit point after twilight: continue in the Safe phase if the twilight
-- check passed ('Ok'); otherwise ('Bad') retry the whole transaction.
-- 'NotChecked' (no twilight check ran) yields an Error result instead of a
-- non-exhaustive pattern-match failure.
tryCommit :: STM r Twi Safe ()
tryCommit = STM (\stmState ->
              case status stmState of
               Ok  _ -> return $ Success stmState ()
               Bad _ -> return $ Retry stmState
               NotChecked -> return $ Error "You tried to commit a transaction which has not entered the twilight zone")


-- Run an IO action in the Safe phase (after 'reload' / 'tryCommit'); the
-- transaction state passes through unchanged.
safeTwiIO :: IO a -> STM r Safe Safe a
safeTwiIO action = STM (\stmState -> fmap (Success stmState) action)
      
-- Run an IO action in any phase (the type leaves the index unchanged, p to p),
-- unlike 'safeTwiIO', which is restricted to the Safe phase. The transaction
-- state passes through unchanged.
unsafeTwiIO :: IO a -> STM r p p a
unsafeTwiIO action = STM (\stmState -> action >>= \a -> return (Success stmState a))

-- List the TVars read into the given region whose timestamp, as observed by
-- the latest validation or re-read, differs from the timestamp seen when
-- they were read. Each is returned as a read handle, paired with a write
-- handle when the transaction also wrote that TVar. Entries never re-checked
-- (no observed timestamp) or carrying no region are skipped; previously an
-- entry without a region made the region filter fail with a pattern-match
-- error.
getInconsistencies :: Region r a -> STM r Safe Safe [(RTwiVar a,Maybe (WTwiVar a))]
getInconsistencies (Region regionId _) = STM (\stmState -> do
     -- inRegion: was this read-set entry read into the queried region?
     -- changed:  did validation / re-read see a different timestamp?
     let inRegion (MkHeapEntry _ _ _ (Just (Region entryRegion _)) _ _) = entryRegion == regionId
         inRegion _                                                      = False
         changed (MkHeapEntry _ _ _ _ readTime (Just seenTime)) = readTime /= seenTime
         changed _                                              = False
         staleIds = [ key | (key, entry) <- Map.toList (rset stmState)
                          , inRegion entry, changed entry ]
         writeHandle key
           | Map.member key (wset stmState) = Just (WTwiVar key)
           | otherwise                      = Nothing
     return $ Success stmState [ (RTwiVar key, writeHandle key) | key <- staleIds ]
   )


-- True iff some TVar read into the given region has an observed timestamp
-- (from the latest validation or re-read) different from the timestamp seen
-- when it was read. Usable in any phase. Entries never re-checked or carrying
-- no region count as consistent; previously an entry without a region made
-- the region filter fail with a pattern-match error.
isInconsistent :: Region r a -> STM r p p Bool
isInconsistent (Region regionId _) = STM (\stmState -> do
     -- inRegion: was this read-set entry read into the queried region?
     -- changed:  did validation / re-read see a different timestamp?
     let inRegion (MkHeapEntry _ _ _ (Just (Region entryRegion _)) _ _) = entryRegion == regionId
         inRegion _                                                      = False
         changed (MkHeapEntry _ _ _ _ readTime (Just seenTime)) = readTime /= seenTime
         changed _                                              = False
         stale (_, entry) = inRegion entry && changed entry
     return $ Success stmState (any stale (Map.toList (rset stmState)))
   )


-- Auxiliary functions: not exported code

{- Entering the twilight zone:
Repeatedly try to lock the whole write set while holding the "neg" side of
the global semaphore (getNeg before, putNeg after each attempt). Spins until
every write lock is held and returns the actions that release them. -}
enterTwilight :: StmState r -> IO [IO()] 
enterTwilight stmState = attempt
  where
    attempt = do
      getNeg semaphore
      outcome <- tryLockWriteset (wset stmState)
      putNeg semaphore
      maybe attempt return outcome


{- Lock the write set without involving the semaphore. Spin locking!
Keeps calling tryLockWriteset (which gives back partial progress on failure)
until every write lock is held; returns the release actions. -}
lockWriteSet :: Heap r -> IO [IO ()]
lockWriteSet wset = tryLockWriteset wset >>= maybe (lockWriteSet wset) return


{- Try to lock every TVar of the write set without blocking, using the
dedicated lock MVar in each TVar.
Locks are attempted with tryTakeMVar in ascending Id order. At the first lock
that is already taken, every lock acquired so far is released again and
Nothing is returned; no further locks are attempted. On success the release
actions of all locks are returned, most recently acquired first. -}
tryLockWriteset :: Heap r -> IO (Maybe [IO ()])
tryLockWriteset wset = acquire [] (Map.elems wset)
  where
    acquire held [] = return (Just held)
    acquire held (MkHeapEntry (TVar _ _ _ lock) _ _ _ _ _ : rest) = do
      got <- tryTakeMVar lock
      case got of
        Nothing -> do
          unlockLocks held   -- back off: give up the partial lock set
          return Nothing
        Just _  -> acquire (putMVar lock () : held) rest


{- Validate the read set.
A TVar that is also in this transaction's write set has its timestamp read
directly. This assumes it is locked by us, which holds after enterTwilight;
note that exposed calls this after releasing its write locks. Any other TVar
is probed with tryTakeMVar. If it is locked by someone else, the set is
invalid and no timestamp is recorded. Otherwise its current timestamp is
compared with the one seen at read time. Returns the overall verdict and a
copy of the read set in which each entry records the timestamp seen now. -}
validateReadSet :: StmState r -> IO (Bool,Heap r)
validateReadSet state = do
    verdicts <- mapM check (Map.toAscList (rset state))
    return (and (map fst verdicts), Map.fromList (map snd verdicts))
  where
    check (_, MkHeapEntry tvar@(TVar _ _ stampRef lock) val ident reg readTime _)
      | ident `Map.member` wset state = do
          now <- readIORef stampRef
          return (now == readTime, (ident, MkHeapEntry tvar val ident reg readTime (Just now)))
      | otherwise = do
          probe <- tryTakeMVar lock
          case probe of
            Nothing -> return (False, (ident, MkHeapEntry tvar val ident reg readTime Nothing))
            Just _  -> do
              now <- readIORef stampRef
              putMVar lock ()
              return (now == readTime, (ident, MkHeapEntry tvar val ident reg readTime (Just now)))


{- Write each new value and the commit timestamp to shared memory, then
release that TVar's lock. All entries share one commit timestamp sampled from
the global clock and are processed in ascending Id order. The write locks
must already be held; otherwise putMVar would block on the full MVar. -}
publishAndUnlock :: Heap r ->  IO ()
publishAndUnlock wset = do
  commitTime <- sampleClock
  let publish (MkHeapEntry (TVar valueRef _ stampRef lock) newValue _ _ _ _) = do
        writeIORef valueRef newValue
        writeIORef stampRef commitTime
        putMVar lock ()
  mapM_ publish (Map.elems wset)
  
{- Release the given set of locks by running the corresponding IO actions
in list order. -}
unlockLocks :: [IO ()] -> IO ()
unlockLocks unlocks = sequence_ unlocks
  


-- Transactional variables
-- A shared transactional variable. The timestamp holds the commit time of
-- the last writer ('publishAndUnlock') or the creation time. The lock is an
-- MVar () that is full while the TVar is unlocked: takeMVar / tryTakeMVar
-- lock it, putMVar releases it.
data TVar a = TVar (IORef a)          -- value
                   Id                 -- unique ident, for ordering when locking     
                   (IORef Timestamp)  -- timestamp
                   (MVar ())          -- lock

-- Show a TVar by its Id only, e.g. "(3)".
instance Show (TVar a) where
  show (TVar _ uid _ _) = "(" ++ shows uid ")"

-- Create a TVar outside any transaction; its timestamp starts at -1.
-- NOTE(review): presumably -1 predates every transaction start time, so
-- 'readTVar' never retries on it before its first commit. Confirm that the
-- clock never yields negative values.
newTVarIO :: a -> IO (TVar a)
newTVarIO x = do
  valueRef <- newIORef x
  uid      <- getUniqueId
  stampRef <- newIORef (-1)
  lock     <- newMVar ()
  return $ TVar valueRef uid stampRef lock

-- Handles, by TVar Id, to a TVar read ('readTVar') or written ('writeTVar')
-- in the atomic phase; consumed by 'rereadTVar' / 'rewriteTVar'. The type
-- parameter is phantom and records the TVar's value type.
data RTwiVar a = RTwiVar Id
data WTwiVar a = WTwiVar Id
                
-- Local heap: holds both the read set and the write set.

-- One read-set or write-set entry. The value type is hidden existentially;
-- readers recover it with 'unsafeCoerce', relying on the Id to identify the
-- TVar (and hence its type). Fields, in order:
--   * the TVar itself
--   * the transaction-local copy of the value
--   * the TVar's Id
--   * the region the TVar was read into ('Nothing' for 'writeTVar' entries)
--   * the timestamp seen when the entry was created (the TVar's timestamp for
--     reads, the transaction start time for writes)
--   * the TVar timestamp seen by the latest validation or re-read ('Nothing'
--     if not checked yet, or if the TVar was locked during validation)
data HeapEntry r = forall a. Show a => MkHeapEntry (TVar a) a Id (Maybe (Region r a)) Timestamp (Maybe Timestamp)

--hideType :: TVar a -> HeapEntry 
--hideType (TVar (MVar x) ident t) = MkHeapEntry x ident t R

--instance Show (HeapEntry r) where
--  show (MkHeapEntry tvar a ident t ) = "(" ++show a ++", "++ show ident ++ ")"

-- A read or write set: entries keyed by TVar Id. Traversals such as
-- toAscList / elems visit them in ascending Id order.
type Heap r = Map.IntMap (HeapEntry r)

-- Regions for grouping TVars 
-- A region groups the TVars read into it; 'getInconsistencies' and
-- 'isInconsistent' query per region. The Id comes from the transaction's
-- 'regCounter'; the r index ties the region to one transaction.
data Region r a = Region Id Checked

-- Allocate a fresh region of this transaction. Ids come from the
-- transaction-local 'regCounter' (starting at 0), and every new region is
-- marked 'Cons'.
newRegion :: STM r Atomic Atomic (Region r a)
--newRegion :: (forall x . Region x a -> STM Atomic Atomic b) -> STM Atomic Atomic b
newRegion = STM $ \stmState ->
  let fresh = regCounter stmState
  in return (Success stmState { regCounter = fresh + 1 } (Region fresh Cons))




-- Auxiliaries: Counters 

-- for the timestamps
-- Global logical clock, sampled for transaction start times ('initialState')
-- and commit times ('publishAndUnlock').
-- NOTE(review): 'uniqueTVarId' is defined as the same 'getCounter'. If that
-- is one shared global counter, timestamps and TVar Ids are drawn from a
-- single sequence - confirm against the Counter module.
globalClock :: Counter 
globalClock = getCounter

-- Sample the global clock. NOTE(review): assumes 'getAndIncr' atomically
-- returns the current value and advances the counter - confirm in Counter.
sampleClock :: IO Int
sampleClock = getAndIncr globalClock 

-- Global semaphore shared by 'enterTwilight' (getNeg / putNeg) and 'exposed'
-- (getPos / putPos); its semantics are defined in the Semaphore module.
semaphore :: Semaphore
semaphore = getSemaphore

-- for unique Ids
-- Counter that hands out TVar Ids ('newTVar', 'newTVarIO').
-- NOTE(review): defined as the same 'getCounter' as 'globalClock'; if that is
-- a single shared counter, Ids and timestamps interleave (presumably still
-- unique) - confirm against the Counter module.
uniqueTVarId :: Counter
uniqueTVarId = getCounter

-- Draw a fresh TVar Id.
getUniqueId :: IO Int
getUniqueId = getAndIncr uniqueTVarId

--- Tests
-- Smoke test of the public API; each step prints what it checked.
-- NOTE(review): 'main' is not in the module's export list, so it is only an
-- entry point with -main-is or when loaded in GHCi - confirm.
main :: IO ()
main = do

  -- Empty transaction: just gret.
  q <- atomically $ gret 3
  putStrLn "Empty transaction works\n"
  -- test read
  let val = 89

  -- TVar creation inside a transaction.
  y <- atomically $ newTVar val `gbind` \a -> gret a
  putStrLn "Creating new TVars works\n"
          
  -- Plain read through a fresh region.
  x <- atomically $ newRegion `gbind` \r -> readTVar y r  `gbind` \(u,_) -> gret u
  putStrLn $ show x ++ " should be " ++ show val
  --- test write
 
  let new_val = 25
  z <- atomically $ writeTVar y new_val `gbind` \_ -> gret "done"
  putStrLn z
  
  -- Write then read in one transaction: the read is served from the write set.
  -- (new_val and z are shadowed deliberately for each step.)
  let new_val = 34
  z <- atomically $ newRegion `gbind` \r -> writeTVar y new_val `gbind` \_ -> readTVar y r `gbind` \(u,_) -> gret u
  putStrLn $ show z ++ " should be " ++ show new_val

  --- test write
  -- Same as above, but committing through a twilight block.
  let new_val = 78
  z <- atomically $ newRegion `gbind` \r -> writeTVar y new_val `gbind` \_ -> readTVar y r `gbind` \(u,_) -> twilight `gbind` \_ -> gret u
  putStrLn $ show z ++ " should be " ++ show new_val
{--
  --- test rewrite
  let v = 7
  w <- atomically $ newRegion `gbind` \r -> writeTVar y new_val `gbind` \y' -> readTVar y r `gbind` \u -> twilight `gbind` \_ -> rewriteTVar y' v `gbind` \_ -> gret v 
  putStrLn $ show w ++ " should be " ++ show v
  
  --}
 
 --  reg <- atomically $ newRegion

 -- a <- atomically $ readTVar y reg

 -- b <- atomically $ readTVar y reg


  return ()