{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
module Control.Clock.IO
  ( newClock
  , newClock'
  , Intv(..)
  , interval
  , convClock
  , clockWithIO
  , clockTimerIO
  , voidInput
  , module Control.Clock
  )
where
import qualified System.Time.Monotonic          as T
import           Control.Concurrent             (threadDelay)
import           Control.Concurrent.Async       (async, cancel, link, link2,
                                                 race)
import           Control.Concurrent.STM         (STM, atomically, orElse)
import           Control.Concurrent.STM.TBQueue (TBQueue, isEmptyTBQueue,
                                                 newTBQueueIO, readTBQueue,
                                                 tryPeekTBQueue, tryReadTBQueue,
                                                 writeTBQueue)
import           Control.Monad                  (forever, unless, when)
import           Data.Time.Clock                (DiffTime,
                                                 diffTimeToPicoseconds,
                                                 picosecondsToDiffTime)
import           Data.Void                      (Void)
import           GHC.Stack                      (HasCallStack)
import           Control.Clock
-- | Create a 'Clock' backed by a freshly created monotonic clock.
--
-- Ticks are counted from @start@ plus the time elapsed on that monotonic
-- clock, in units of @intv@. @intv@ must be positive; otherwise operations on
-- the resulting clock fail via 'error' when first used.
newClock :: Tick -> DiffTime -> IO (Clock IO)
newClock start intv = do
  mono <- T.newClock
  pure (convClock start intv mono)
-- | Like 'newClock', with tick counting starting from @0@.
newClock' :: DiffTime -> IO (Clock IO)
newClock' intv = newClock 0 intv
-- | Time units understood by 'interval'.
data Intv
  = Ps -- ^ picoseconds
  | Ns -- ^ nanoseconds
  | Us -- ^ microseconds
  | Ms -- ^ milliseconds
  | S  -- ^ seconds
  deriving (Eq, Ord, Show, Enum, Bounded)
-- | A 'DiffTime' of @i@ units of @u@.
--
-- >>> interval 3 Ms == picosecondsToDiffTime 3000000000
-- True
interval :: Integer -> Intv -> DiffTime
interval i u = picosecondsToDiffTime (i * picosPerUnit u)
  where
    -- number of picoseconds in one unit
    picosPerUnit :: Intv -> Integer
    picosPerUnit Ps = 1
    picosPerUnit Ns = 1000
    picosPerUnit Us = 1000000
    picosPerUnit Ms = 1000000000
    picosPerUnit S  = 1000000000000
-- | Pass the argument through unchanged when it is @>= 0@; otherwise fail
-- via 'error' with a message that shows the offending value.
checkNonNeg :: (HasCallStack, Num a, Ord a, Show a) => a -> a
checkNonNeg n
  | n >= 0    = n
  | otherwise = error ("must be non-negative: " ++ show n)
-- | Pass the argument through unchanged when it is strictly positive;
-- otherwise fail via 'error' with a message that shows the offending value.
checkPos :: (HasCallStack, Num a, Ord a, Show a) => a -> a
checkPos n
  | n > 0     = n
  | otherwise = error ("must be positive: " ++ show n)
-- | Turn a monotonic clock @c@ into a tick-based 'Clock'.
--
-- One tick lasts @intv@, which must be positive. A non-positive @intv@ makes
-- clock operations fail via 'error'; the check is forced lazily, on first
-- use. Tick numbering starts at @start@ and advances with the time elapsed
-- on @c@.
--
-- @'clockDelay' d@ returns immediately for @d <= 0@. Otherwise it sleeps
-- until a fixed fraction (1/15) of a tick past the boundary of tick
-- @now + d@, where @now@ is the current tick.
convClock :: Tick -> DiffTime -> T.Clock -> Clock IO
convClock start intv c =
  let r     = diffTimeToPicoseconds $ checkPos intv
      -- tick offset expressed in picoseconds
      i     = start * r
      -- Extra sleep past the target tick boundary, presumably so the
      -- wake-up lands inside the target tick rather than just before it.
      -- It is a constant fraction of one tick. The former
      -- @r * d * 16 `div` 15@ scaled the slack with @d@, so for @d >= 15@
      -- it overshot into tick @now + d + 1@ or later. For @d = 1@ both
      -- forms give the same value.
      slack = r `div` 15
      c'    = Clock
        { clockNow   = (`div` r) <$> clockNowPico i c
        , clockDelay = \d -> when (d > 0) $ do
            -- Picoseconds already elapsed within the current tick. 'mod'
            -- (not 'rem') agrees with the 'div' used by 'clockNow' when the
            -- picosecond count is negative, so the result is always in
            -- [0, r).
            remain <- (`mod` r) <$> clockNowPico i c
            clockDelayPico (r * fromIntegral d + slack - remain)
        , clockWith  = clockWithIO c'
        , clockTimer = clockTimerIO c'
        }
  in  c'
-- | Time elapsed on @c@ in picoseconds, plus @start@.
--
-- Despite its 'Tick' type, @start@ is added as a picosecond offset; callers
-- pass a tick count already multiplied by the tick length.
clockNowPico :: Tick -> T.Clock -> IO Integer
clockNowPico start c = do
  elapsed <- T.clockGetTime c
  pure (start + diffTimeToPicoseconds elapsed)
-- | Sleep for @d@ picoseconds via the monotonic clock's delay. A negative
-- @d@ fails via 'error'.
clockDelayPico :: Integer -> IO ()
clockDelayPico = T.delay . picosecondsToDiffTime . checkNonNeg
-- | A 'writeTBQueue' that must not block. The queue has to be empty at the
-- time of the write, otherwise the transaction fails via 'error'. Requiring
-- emptiness guarantees the write never retries for any capacity of at
-- least one.
writeTBQueue' :: HasCallStack => TBQueue a -> a -> STM ()
writeTBQueue' q r =
  isEmptyTBQueue q >>= \vacant ->
    if vacant
      then writeTBQueue q r
      else error "failed to assert non-blocking write on TBQueue"
-- | Wrap @action@ so that each call races it against the clock's ticks.
--
-- @action@ runs repeatedly on a background thread, one run per request.
-- Each call of the returned clocked action does two things:
--
-- * It requests a run of @action@, unless a request is already pending or
--   running.
-- * It then returns @Right r@ with a completed result, or @Left t@ with
--   the latest value from @'clockTick' clock 1@, whichever is available
--   first. A result is preferred when both are available.
--
-- A result that completes after its caller already received @Left t@ is
-- kept and handed to a later call.
--
-- The finaliser passed to 'Clocked' cancels both background threads.
-- Exceptions from those threads are re-thrown, via 'link', in the thread
-- that called 'clockWithIO'.
clockWithIO :: Clock IO -> IO a -> IO (Clocked IO a)
clockWithIO clock action = do
  -- qi: holds () while a run is requested or in progress.
  qi           <- newTBQueueIO 1
  -- qo: a completed result not yet handed to a caller.
  qo           <- newTBQueueIO 1
  -- qt: the latest tick; an unread older tick is replaced.
  qt           <- newTBQueueIO 1

  actionThread <- async $ forever $ do
    -- Wait for a request, then put the marker straight back so that
    -- callers can see a run is in progress and do not request another.
    atomically $ do
      readTBQueue qi
      writeTBQueue' qi ()
    r <- action
    -- Publish the result and clear the marker in one transaction.
    -- qo may still hold an earlier result that arrived after its caller
    -- got a tick. The next caller requests a new run (through qi) before
    -- it drains qo, in a separate transaction, so a fast action can reach
    -- this point while qo is still full. Block (STM retry) until qo is
    -- drained instead of asserting that it is empty; the assertion used to
    -- crash the thread in that race.
    atomically $ do
      writeTBQueue qo r
      readTBQueue qi
  link actionThread

  tickThread <- async $ forever $ do
    t <- clockTick clock 1
    atomically $ do
      _ <- tryReadTBQueue qt
      writeTBQueue' qt t
  link tickThread

  link2 actionThread tickThread
  let fin     = cancel actionThread >> cancel tickThread
      action' = do
        -- Request a run only if none is pending or running.
        atomically $ tryPeekTBQueue qi >>= \case
          Nothing -> writeTBQueue qi ()
          Just () -> pure ()
        atomically $
          (Right <$> readTBQueue qo) `orElse` (Left <$> readTBQueue qt)
  pure (Clocked action' fin)
-- | Run @act@ against a timer of @d@ ticks on @c@ using 'race'.
--
-- The result is @Left@ with the value of @'clockTick' c d@ if the timer
-- finishes first, otherwise @Right@ with the action's result. 'race'
-- cancels whichever side loses.
clockTimerIO :: Clock IO -> TickDelta -> IO a -> IO (Either Tick a)
clockTimerIO c d act = race (clockTick c d) act
-- | An 'IO' action that never produces a value. It sleeps forever in
-- 'maxBound'-microsecond 'threadDelay' chunks. Presumably meant as an input
-- source that never yields anything.
voidInput :: IO Void
voidInput = loop
  where
    loop = threadDelay maxBound >> loop