-------------------------------------------------------------------------------
-- minimal example of a much bigger program (which is my excuse for strange
-- looking code)
--
-- compile with
--
--   ghc --make -O2 -threaded Pi.hs -o pi -rtsopts -fforce-recomp
--
-- and execute it (for example, on a 2 core system) with
--
--   pi +RTS -N2
--
-- For infinite runs, try
--
--   i=0; while true;do printf "%4d\n" $((i=$i+1));pi +RTS -N2;done
--
-- on a system with bash.
-------------------------------------------------------------------------------

{-# LANGUAGE FlexibleContexts, MultiParamTypeClasses, FlexibleInstances#-}
module Main where
import Control.Monad
import Control.Concurrent
import Control.Parallel.Strategies
import System.Environment
import Data.Array.IO
import Control.Concurrent.STM
import System.Posix.Signals
import Data.Maybe
import GHC.Conc
import System.IO
import System.IO.Unsafe
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Time.Format
import System.Locale
import Data.Time
import Data.Time.Clock.POSIX
import Text.Printf


-- | Entry point: fill a pool with 64 dummy tasks, start one worker per
-- capability, block until the pool is drained and every worker is idle,
-- then dump the debug log and print "SHOULD EXIT".
main :: IO ()
main = do
    installLogger
    queue <- newSTMPQueue 16
    forM_ (replicate 64 1) (putSTMPQueue queue)
    forM_ [0 .. numCapabilities - 1] $ \worker ->
        forkIO (thread worker calcPi queue)
    waitSTMPQueue queue
    debug "__________________________________"
    writeLog_ True
    putStrLn "SHOULD EXIT"


-- | Worker loop for worker number @n@: repeatedly fetch a task from @pool@
-- and run @f pool task@ on it, stopping as soon as 'getSTMPQueue' answers
-- 'Nothing'.
thread :: Show a => Int -> (STMPQueue a -> a -> IO b) -> STMPQueue a -> IO ()
thread n f pool = getSTMPQueue n pool >>= maybe (return ()) runThenContinue
  where
    runThenContinue task = f pool task >> thread n f pool


--- pi calculation functions for number crunching ----------------------------
-- The number crunching is disabled so that the example compiles with GHC HEAD;
-- the bug still occurs without it.

-- calcPiPure :: Int -> Int
-- calcPiPure digits = showCReal (fromEnum digits) pi `pseq` 1


-- | Stand-in for a real work item: logs the start and the end of a task
-- for @n@ digits.  The pool argument is ignored.
calcPi :: t -> Int -> IO ()
calcPi _ n = do
    debug ("T calc " ++ show n)
    -- the actual computation (calcPiPure, commented out above) is disabled
    debug "T finished"
    


--- STM based global queue with an additional private part ------------------
-- | A work pool made of a shared, counted queue plus a private per-worker
-- buffer.  A worker takes a batch from the shared queue, runs the first
-- task and parks the rest in its own row of 'stmPrivate'.
--
-- The former @Show a =>@ datatype context has been removed.  Datatype
-- contexts are not part of Haskell 2010; GHC >= 7.2 only accepts them with
-- the deprecated @DatatypeContexts@ extension.  Such a context never makes
-- a constraint available to users of the type; it only demands one.  Every
-- function here that needs 'Show' already asks for it in its own signature.
data STMPQueue a = STMPQueue {
      stmChan     :: STMCQueue a          -- ^ shared queue of pending tasks
    , stmState    :: TVar STMState        -- ^ 'SPut' normally, 'SWait' while the producer waits for the pool to drain
    , stmFinished :: TChan ()             -- ^ created by 'newSTMPQueue'; not otherwise used in the visible code
    , stmWorking  :: TVar (Set ThreadId)  -- ^ workers that currently hold a batch

    -- for the private queue
    , stmPrivate  :: TArray (Int,Int) a     -- ^ (worker, slot) -> parked task; currently a bit slow...
    , stmIndex    :: TArray Int (Int, Int)  -- ^ worker -> (next slot to read, number of parked tasks)
    , stmSize     :: Int                    -- ^ private slots per worker
}

-- | Phase of an 'STMPQueue'.  In 'SPut' a worker finding the shared queue
-- empty keeps waiting for more tasks.  In 'SWait' (set while the producer
-- waits for the pool to drain) it gets 'Nothing' once no worker holds a
-- batch any more.
data STMState = SPut | SWait
    deriving (Show, Eq)


-- | Create an empty pool in the 'SPut' phase whose workers may each park up
-- to @size@ tasks privately.  One private row is allocated per capability,
-- so worker indices used with 'getSTMPQueue' must lie in
-- @[0, numCapabilities)@.
newSTMPQueue :: Show a => Int -> IO (STMPQueue a)
newSTMPQueue size = do
    shared <- newSTMCQueue
    phase  <- newTVarIO SPut
    done   <- newTChanIO
    busy   <- newTVarIO Set.empty
    let lastWorker = numCapabilities - 1
    (slots, cursors) <- atomically $
        liftM2 (,) (newArray_ ((0, 0), (lastWorker, size - 1)))
                   (newArray (0, lastWorker) (0, 0))
    return STMPQueue { stmChan     = shared
                     , stmState    = phase
                     , stmFinished = done
                     , stmWorking  = busy
                     , stmPrivate  = slots
                     , stmIndex    = cursors
                     , stmSize     = size
                     }


-- | Append one task to the pool's shared queue.
putSTMPQueue :: Show a => STMPQueue a -> a -> IO ()
putSTMPQueue pool task = atomically (writeSTMCQueue (stmChan pool) task)


-- | Block until the pool is drained.  This takes three steps:
--
-- 1. Switch the pool to 'SWait', so idle workers may receive 'Nothing'.
-- 2. Wait until the shared queue is empty and no worker holds a batch.
-- 3. Switch back to 'SPut'.
--
-- NOTE(review): a worker that has not yet observed 'SWait' when the phase
-- flips back to 'SPut' keeps blocking until new tasks arrive.
waitSTMPQueue :: Show a => STMPQueue a -> IO ()
waitSTMPQueue pool = do
    atomically $ writeTVar (stmState pool) SWait
    atomically $ do
        idle    <- fmap Set.null (readTVar (stmWorking pool))
        drained <- isEmptySTMCQueue (stmChan pool)
        unless (idle && drained) retry
        writeTVar (stmState pool) SPut


-- | Fetch the next task for worker @idx@, an index into the private
-- buffers in @[0, numCapabilities)@.
--
-- Tasks come from the worker's private buffer first.  When that is empty,
-- the worker does the following:
--
-- * It leaves the 'stmWorking' set.
-- * It takes a batch of up to @size + 1@ tasks from the shared queue.
-- * It returns the first task and parks the rest in its private row.
--
-- If the shared queue is empty, the call blocks while the pool is in
-- 'SPut'.  In 'SWait' it returns 'Nothing' once no worker holds a batch.
--
-- All logging happens outside 'atomically'.  The previous version called
-- 'debug' via 'unsafeIOToSTM'.  'debug' performs MVar-based IO
-- ('swapMVar', 'writeChan'), and GHC may abort an invalid transaction in
-- the middle of such IO without running exception handlers.  That can
-- leave an MVar empty and deadlock every later 'debug' call.
getSTMPQueue :: Show a => Int -> STMPQueue a -> IO (Maybe a)
getSTMPQueue idx (STMPQueue chan state _ working slots cursors size) = do
    (curidx, midx) <- atomically $ readArray cursors idx
    if curidx == midx
        then do
            debug "private queue empty"
            refill
        else atomically $ do
            -- read the parked task and advance the cursor in one step
            task <- readArray slots (idx, curidx)
            writeArray cursors idx (curidx + 1, midx)
            return (Just task)
  where
    refill = do
        tid <- myThreadId
        -- our previous batch is used up: we no longer count as working
        atomically $ do
            work <- Set.delete tid `fmap` readTVar working
            writeTVar working $! work
        taken <- atomically $ do
            empty <- isEmptySTMCQueue chan
            work  <- readTVar working
            op    <- readTVar state
            if not empty
                then do
                    batch <- readSTMCQueue chan (size + 1)
                    case batch of
                        [] -> error ("getSTMPQueue: empty batch from a non-empty "
                                     ++ "queue (negative private size?)")
                        (task : rest) -> do
                            forM_ (zip [0 ..] rest) $ \(col, v) ->
                                writeArray slots (idx, col) v
                            writeArray cursors idx $! (0, length rest)
                            writeTVar working $! Set.insert tid work
                            return (Just (task, batch))
                else case op of
                    SPut  -> retry
                    SWait -> if Set.null work then return Nothing else retry
        case taken of
            Nothing -> return Nothing
            Just (task, batch) -> do
                -- logged only after the transaction has committed
                debug $ "my tasks: " ++ show batch
                return (Just task)


--- For debugging output ----------------------------------------------------
-- | Process-wide buffer of formatted debug lines, drained by 'writeLog_'.
--
-- It is a global created with 'unsafePerformIO', so it must not be
-- inlined.  Without the NOINLINE pragma GHC (especially at -O2) may
-- duplicate the 'newChan' call at each use site.  Writers and readers
-- would then see different channels.
debugChan :: Chan String
debugChan = unsafePerformIO newChan
{-# NOINLINE debugChan #-}

-- | Time stamp of the most recent debug message.  'debugStr' uses it to
-- print the time elapsed since the previous message.  Like every
-- 'unsafePerformIO' global it needs NOINLINE; otherwise each use site may
-- get its own, freshly initialised MVar.
debugTime :: MVar UTCTime
debugTime = unsafePerformIO (newMVar =<< getCurrentTime)
{-# NOINLINE debugTime #-}

-- | Format @msg@ with 'debugStr' and queue it on 'debugChan'.  It is
-- printed later, when 'writeLog_' drains the channel.
debug :: String -> IO ()
debug msg = debugStr msg >>= writeChan debugChan

-- | Like 'debug', but also echo the line to stderr immediately.
--
-- The line is formatted once, and the same text goes to stderr and to
-- 'debugChan'.  The previous version called 'debugStr' twice.  Each call
-- resets 'debugTime', so the logged copy always showed an elapsed time of
-- roughly zero instead of the real gap since the previous message.
debug_ :: String -> IO ()
debug_ msg = do
    s <- debugStr msg
    hPutStrLn stderr s
    writeChan debugChan s

-- | Prefix @msg@ with two fields: the time elapsed since the previous debug
-- message, and the numeric id of the calling thread.  It also stores "now"
-- in 'debugTime' as the new reference point.
--
-- The elapsed time is left-aligned in a 10-character column.  A negative
-- difference (e.g. if the wall clock stepped backwards) is shown as zero.
debugStr :: String -> IO String
debugStr msg = do
    tid  <- (drop 9 . show) `fmap` myThreadId   -- "ThreadId 42" -> "42"
    t    <- getCurrentTime
    told <- swapMVar debugTime t
    let td = diffUTCTime t told
        ts = init $ show td                     -- drop the trailing unit 's'
        -- compare the NominalDiffTime directly instead of re-parsing its
        -- 'show' output with the partial 'read'
        tl = if td < 0
                 then "0.000000  "
                 else printf "%-10s" ts
    return $ tl ++ " " ++ tid ++ " " ++ msg


-- | Write the buffered debug log to stdout, but only if an environment
-- variable named @LOG@ is set.
writeLog :: IO ()
writeLog = writeLog_ onlyWhenLogIsSet
  where
    onlyWhenLogIsSet = False


-- | Print every buffered line from 'debugChan' to stdout, oldest first.
-- With @True@ the log is always written.  With @False@ it is written only
-- when an environment variable named @LOG@ is set.
--
-- NOTE(review): 'isEmptyChan' is deprecated because it can block against
-- concurrent readers.  This loop assumes no other thread is reading
-- 'debugChan' at the same time.
writeLog_ :: Bool -> IO ()
writeLog_ checkEnv = do
    names <- fmap (map fst) getEnvironment
    when (checkEnv || "LOG" `elem` names) drain
  where
    drain = do
        isEmpty <- isEmptyChan debugChan
        if isEmpty
            then return ()
            else do
                readChan debugChan >>= putStrLn
                drain

-- | Read the environment variable @ss@ as an 'Int'.
--
-- Throws an 'IOError' naming the variable when it is unset or is not an
-- integer.  The previous version used @fromJust@ and @read@.  Those failed
-- lazily, only when the result was first forced, and with uninformative
-- messages ("Maybe.fromJust: Nothing" / "Prelude.read: no parse").
numEnv :: String -> IO Int
numEnv ss = do
    env <- getEnvironment
    case lookup ss env of
        Nothing  -> failWith "is not set"
        Just raw -> case reads raw of
            -- like 'read', accept surrounding whitespace only
            [(n, rest)] | null (words rest) -> return n
            _ -> failWith ("is not an integer: " ++ show raw)
  where
    failWith why = ioError $ userError $
        "numEnv: environment variable " ++ ss ++ " " ++ why
    

-- | Run @f@ only if an environment variable named @ss@ is set.  Its value
-- is irrelevant.
whenEnv :: String -> IO () -> IO ()
whenEnv ss f = do
    isSet <- fmap (any ((== ss) . fst)) getEnvironment
    when isSet f
    

-- | Install a one-shot SIGINT handler.  On the first Ctrl-C it dumps the
-- buffered debug log and then re-raises SIGINT.  The handler is
-- 'CatchOnce', so the re-raised signal gets the default disposition,
-- which normally terminates the process.
installLogger :: IO ()
installLogger = do
    _previous <- installHandler sigINT (CatchOnce onInterrupt) Nothing
    return ()
  where
    onInterrupt = do
        putStrLn "" -- start a fresh line so the echoed ^C is not on the first log line
        debug "SIGINT caught."
        writeLog_ True
        raiseSignal sigINT


--- Counting Chan using STM -------------------------------------------------
-- | A 'TChan' paired with a transactional element counter, so emptiness
-- and size can be queried without consuming elements.  The
-- @*STMCQueue@ functions keep the counter in step with the channel.
data STMCQueue a = STMCQueue
    { scqList :: TChan a   -- ^ queued elements, first in first out
    , scqSize :: TVar Int  -- ^ number of elements currently in 'scqList'
    }


-- | Create an empty counted queue: a fresh channel and a counter of 0.
newSTMCQueue :: IO (STMCQueue a)
newSTMCQueue = liftM2 STMCQueue newTChanIO (newTVarIO 0)


-- | True iff the queue's element counter is zero.  Only the counter is
-- consulted; the channel itself is left untouched.
isEmptySTMCQueue :: STMCQueue a -> STM Bool
isEmptySTMCQueue queue =
    readTVar (scqSize queue) >>= \count -> return $! count == 0


-- | Enqueue one element and bump the element counter by one.
writeSTMCQueue :: STMCQueue a -> a -> STM ()
writeSTMCQueue queue value = do
    let counter = scqSize queue
    before <- readTVar counter
    writeTChan (scqList queue) value
    writeTVar counter $! before + 1


-- | Enqueue all @values@ in list order and raise the counter by their
-- number.
writeList2STMCQueue :: STMCQueue a -> [a] -> STM ()
writeList2STMCQueue queue values = do
    before <- readTVar (scqSize queue)
    forM_ values (writeTChan (scqList queue))
    writeTVar (scqSize queue) $! before + length values


-- | Remove and return the first @min n size@ elements, or none when
-- @n <= 0@, and lower the counter accordingly.
--
-- Changes compared with the previous version:
--
-- * The @unsafeIOToSTM (debug ...)@ call (the "# queue" line) is gone.
--   'debug' performs MVar-based IO ('swapMVar', 'writeChan').  GHC may
--   abort an invalid transaction in the middle of such IO without running
--   exception handlers, which can leave an MVar empty and deadlock later
--   logging.  Callers can log the result after 'atomically' returns.
--
-- * A negative @n@ used to make the transaction retry forever, because
--   the @check@ could never succeed.  It is now clamped to 0.  The @check@
--   was otherwise redundant: 'replicateM' with 'readTChan' either yields
--   exactly the requested number of elements or retries.
readSTMCQueue :: STMCQueue a -> Int -> STM [a]
readSTMCQueue (STMCQueue list size) n = do
    ioSize <- readTVar size
    let taken = max 0 (min ioSize n)
    values <- replicateM taken (readTChan list)
    writeTVar size $! ioSize - taken
    return values



