{-# LANGUAGE Arrows                #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}
module FRP.Rhine.Clock.Realtime.Audio
  ( AudioClock (..)
  , AudioRate (..)
  , PureAudioClock (..)
  , pureAudioClockF
  )
  where
import GHC.Float       (double2Float)
import GHC.TypeLits    (Nat, natVal, KnownNat)
import Data.Time.Clock
import Control.Monad.IO.Class
import Control.Monad.Trans.MSF.Except
import FRP.Rhine
-- | The sample rates available to 'AudioClock' and 'PureAudioClock'.
--   With @DataKinds@ these constructors are used at the type level
--   to index the clocks by their rate.
data AudioRate = Hz44100 | Hz48000 | Hz96000
-- | The number of samples per second denoted by an 'AudioRate'.
rateToIntegral :: Integral a => AudioRate -> a
rateToIntegral rate = case rate of
  Hz44100 -> 44100
  Hz48000 -> 48000
  Hz96000 -> 96000
-- | A clock that ticks once per audio sample at sample rate @rate@,
--   processing the samples in buffers of @bufferSize@ samples each.
--   Both parameters are type-level only; the value carries no data.
data AudioClock (rate :: AudioRate) (bufferSize :: Nat)
  = AudioClock
-- | Type-level 'AudioRate's whose value an 'AudioClock' can recover.
class AudioClockRate (rate :: AudioRate) where
  -- | Reflect the type-level sample rate to a value.
  theRate :: AudioClock rate bufferSize -> AudioRate

  -- | The sample rate in samples per second, as an integral number.
  theRateIntegral :: Integral a => AudioClock rate bufferSize -> a
  theRateIntegral clock = rateToIntegral (theRate clock)

  -- | The sample rate in samples per second, as any numeric type.
  theRateNum :: Num a => AudioClock rate bufferSize -> a
  theRateNum clock = fromInteger (theRateIntegral clock)

instance AudioClockRate Hz44100 where
  theRate = const Hz44100
instance AudioClockRate Hz48000 where
  theRate = const Hz48000
instance AudioClockRate Hz96000 where
  theRate = const Hz96000
-- | The number of samples per buffer of an 'AudioClock',
--   reflected from its type-level @bufferSize@.
theBufferSize
  :: (KnownNat bufferSize, Integral a)
  => AudioClock rate bufferSize -> a
theBufferSize clock = fromInteger (natVal clock)
-- | Runs an 'AudioClock' in any 'MonadIO'.
--
--   Sample timestamps are not read from the system clock. Each one is the
--   start time of the current buffer plus its 0-based index times the sample
--   period. After the last sample of a buffer, the system time is read once:
--   if it is already past the start of the next buffer, the first sample of
--   that next buffer is tagged with @'Just' lateness@ (as computed by
--   'diffTime'; presumably seconds, TODO confirm). All other samples are
--   tagged 'Nothing'.
instance (MonadIO m, KnownNat bufferSize, AudioClockRate rate)
      => Clock m (AudioClock rate bufferSize) where
  type TimeDomainOf (AudioClock rate bufferSize) = UTCTime
  type Tag          (AudioClock rate bufferSize) = Maybe Double
  startClock audioClock = do
    let
      -- Duration of one sample, rounded to whole picoseconds.
      -- 'picosecondsToDiffTime' is used because it is a sufficiently
      -- precise conversion.
      step       = picosecondsToDiffTime
                     $ round (10 ^ (12 :: Integer) / theRateNum audioClock :: Double)
      -- A buffer must hold at least one sample. Otherwise the buffer loop
      -- below would restart without ever emitting a tick.
      bufferSize = max 1 (theBufferSize audioClock)
      runningClock :: MonadIO m => UTCTime -> Maybe Double -> MSF m () (UTCTime, Maybe Double)
      runningClock initialTime maybeWasLate = safely $ do
        bufferFullTime <- try $ proc () -> do
          -- dunai's 'count' starts at 1. Shift it so the first sample of
          -- the buffer has index 0 and is stamped with 'initialTime'.
          n <- (count >>> arr (subtract 1)) -< ()
          let nextTime = (realToFrac step * fromIntegral (n :: Int)) `addUTCTime` initialTime
          -- Indices 0 .. bufferSize - 1 belong to this buffer. At index
          -- bufferSize, leave the buffer and pass on the start time of the
          -- next one, which that buffer emits as its index 0.
          _ <- throwOn' -< (n >= bufferSize, nextTime)
          -- Only the first sample of a buffer carries the lateness
          -- measured at the end of the previous buffer.
          returnA       -< (nextTime, if n == 0 then maybeWasLate else Nothing)
        currentTime <- once_ $ liftIO getCurrentTime
        let
          lateDiff = currentTime `diffTime` bufferFullTime
          late     = if lateDiff > 0 then Just lateDiff else Nothing
        safe $ runningClock bufferFullTime late
    initialTime <- liftIO getCurrentTime
    return
      ( runningClock initialTime Nothing
      , initialTime
      )
-- | A clock that ticks once per audio sample at sample rate @rate@.
--   Its timestamps are computed without consulting the system time.
--   The rate is type-level only; the value carries no data.
data PureAudioClock (rate :: AudioRate)
  = PureAudioClock
-- | Type-level 'AudioRate's whose value a 'PureAudioClock' can recover.
class PureAudioClockRate (rate :: AudioRate) where
  -- | Reflect the type-level sample rate to a value.
  thePureRate :: PureAudioClock rate -> AudioRate

  -- | The sample rate in samples per second, as an integral number.
  thePureRateIntegral :: Integral a => PureAudioClock rate -> a
  thePureRateIntegral = rateToIntegral . thePureRate

  -- | The sample rate in samples per second, as any numeric type.
  thePureRateNum :: Num a => PureAudioClock rate -> a
  thePureRateNum = fromInteger . thePureRateIntegral

-- This class is not exported, so its instances must be defined here.
-- Without them no 'PureAudioClock' satisfies the constraint of its
-- 'Clock' instance, and the clock cannot be used at all.
instance PureAudioClockRate Hz44100 where
  thePureRate _ = Hz44100
instance PureAudioClockRate Hz48000 where
  thePureRate _ = Hz48000
instance PureAudioClockRate Hz96000 where
  thePureRate _ = Hz96000
-- | Runs a 'PureAudioClock' in any 'Monad', without side effects.
--   The clock starts at time @0@. On every tick it feeds the sample period
--   @1 / rate@ into 'sumS' to produce the tick's time. Its tag is always @()@.
instance (Monad m, PureAudioClockRate rate) => Clock m (PureAudioClock rate) where
  type TimeDomainOf (PureAudioClock rate) = Double
  type Tag          (PureAudioClock rate) = ()
  startClock audioClock = return (sampleTimes, 0)
    where
      samplePeriod = 1 / thePureRateNum audioClock
      sampleTimes  = arr (const samplePeriod) >>> sumS &&& arr (const ())
-- | A 'PureAudioClock' whose time domain is rescaled to 'Float'.
type PureAudioClockF (rate :: AudioRate)
  = RescaledClock (PureAudioClock rate) Float
-- | A 'PureAudioClock' whose 'Double' timestamps are converted to 'Float'
--   using 'double2Float'.
pureAudioClockF :: PureAudioClockF rate
pureAudioClockF = RescaledClock
  { rescale       = double2Float
  , unscaledClock = PureAudioClock
  }