{-# LANGUAGE ForeignFunctionInterface #-}
module Sound.ALSA.PCM
    (SampleFmt(..),
     SampleFreq,
     Time,
     SoundFmt(..),
     SoundSource(..),
     SoundSink(..),
     SoundBufferTime(..),
     Pcm,
     withSoundSource,
     withSoundSourceRunning,
     withSoundSink,
     withSoundSinkRunning,
     soundFmtMIME,
     audioBytesPerSample,
     audioBytesPerFrame,
     soundSourceBytesPerFrame,
     soundSinkBytesPerFrame,
     copySound,
     alsaSoundSource,
     alsaSoundSink,
     alsaSoundSourceTime,
     alsaSoundSinkTime,
     fileSoundSource,
     fileSoundSink,
    ) where

import Sound.ALSA.PCM.Core (Pcm)
import qualified Sound.ALSA.PCM.Core as PCM
import qualified Sound.ALSA.Exception as AlsaExc
import qualified Sound.ALSA.PCM.Debug as Debug

import qualified Sound.Frame as Frame
import qualified Sound.Frame.Stereo as Stereo
import qualified Sound.Frame.MuLaw  as MuLaw

import Data.Word (Word8, Word16, Word32, )
import Data.Int (Int8, Int16, Int32, )

import Control.Exception (bracket, bracket_, bracketOnError, )
import Control.Monad (liftM, when, )
import Foreign.Marshal.Array (advancePtr, allocaArray, )
import Foreign.C (CSize, CInt, )
import Foreign (Storable, Ptr, minusPtr, )
import qualified System.IO as IO
import System.IO
   (IOMode(ReadMode, WriteMode), Handle, openBinaryFile, hClose, )

--
-- * Generic sound API
--

-- | Element types that can be exchanged with a PCM device.
-- Every instance names the 'PCM.Format' describing one element of the type.
class (Storable y, Frame.C y) => SampleFmt y where
   -- | PCM format for elements of type @y@.
   -- The argument only selects the instance; 'withSampleFmt' passes
   -- 'undefined' here, so implementations must not evaluate it.
   sampleFmtToPcmFormat :: y -> PCM.Format

-- | Sample frequency of a stream (presumably in Hz -- the unit is not
-- fixed anywhere in this module; confirm against the PCM core).
type SampleFreq = Int

-- | Description of a sound stream.
-- The type parameter @y@ is a phantom: it does not occur in any field and
-- only selects the 'SampleFmt' instance used for format and channel queries.
data SoundFmt y =
    SoundFmt
        { sampleFreq :: SampleFreq  -- ^ sample frequency of the stream
        }
    deriving (Show)

-- | Duration used for buffer and period lengths.
-- The comments on 'defaultBufferTime' (500000 = 0.5s) indicate microseconds.
type Time = Int

-- | Requested buffer and period durations for an ALSA device.
-- Both values are handed to the hardware parameter setup in 'alsaOpen'.
data SoundBufferTime =
    SoundBufferTime
        { bufferTime :: Time  -- ^ requested total buffer duration
        , periodTime :: Time  -- ^ requested period duration
        }
    deriving (Show)


-- | A readable sound stream, described as a record of operations on a
-- handle type.
--
-- Counts are in samples, not bytes. Multi-channel data is interleaved.
data SoundSource y handle =
    SoundSource
        { soundSourceFmt   :: SoundFmt y
          -- ^ format of the delivered data
        , soundSourceOpen  :: IO handle
          -- ^ acquire the underlying resource
        , soundSourceClose :: handle -> IO ()
          -- ^ release the underlying resource
        , soundSourceStart :: handle -> IO ()
          -- ^ start the stream
        , soundSourceStop  :: handle -> IO ()
          -- ^ stop the stream
        , soundSourceRead  :: handle -> Ptr y -> Int -> IO Int
          -- ^ read up to the given count into the buffer and
          -- return the count actually read
        }

-- | A writable sound stream, described as a record of operations on a
-- handle type. Counts passed to 'soundSinkWrite' are in elements of
-- type @y@, not bytes.
data SoundSink y handle =
    SoundSink
        { soundSinkFmt   :: SoundFmt y
          -- ^ format of the accepted data
        , soundSinkOpen  :: IO handle
          -- ^ acquire the underlying resource
        , soundSinkClose :: handle -> IO ()
          -- ^ release the underlying resource
        , soundSinkWrite :: handle -> Ptr y -> Int -> IO ()
          -- ^ write the given count of elements from the buffer
        , soundSinkStart :: handle -> IO ()
          -- ^ start the stream
        , soundSinkStop  :: handle -> IO ()
          -- ^ stop the stream
        }

--
--
--

-- | Buffer timing used by 'alsaSoundSource' and 'alsaSoundSink'
-- when the caller does not supply one.
defaultBufferTime :: SoundBufferTime
defaultBufferTime = SoundBufferTime { bufferTime = halfSecond, periodTime = tenthSecond }
  where
    halfSecond  = 500000 -- 0.5s
    tenthSecond = 100000 -- 0.1s

-- | Template source whose operations do nothing and whose reads always
-- report zero elements. It is meant to be specialised by record update:
-- 'soundSourceOpen' yields 'undefined', so it has to be replaced before
-- the handle is ever inspected.
nullSoundSource :: SoundFmt y -> SoundSource y h
nullSoundSource fmt =
    SoundSource {
        soundSourceFmt   = fmt,
        soundSourceOpen  = return undefined,
        soundSourceClose = ignoreHandle,
        soundSourceStart = ignoreHandle,
        soundSourceStop  = ignoreHandle,
        soundSourceRead  = \_ _ _ -> return 0
    }
  where
    ignoreHandle = const (return ())

-- | Template sink whose operations do nothing and whose writes discard
-- the data. It is meant to be specialised by record update:
-- 'soundSinkOpen' yields 'undefined', so it has to be replaced before
-- the handle is ever inspected.
nullSoundSink :: SoundFmt y -> SoundSink y h
nullSoundSink fmt =
    SoundSink {
        soundSinkFmt   = fmt,
        soundSinkOpen  = return undefined,
        soundSinkClose = ignoreHandle,
        soundSinkStart = ignoreHandle,
        soundSinkStop  = ignoreHandle,
        soundSinkWrite = \_ _ _ -> return ()
    }
  where
    ignoreHandle = const (return ())


-- | Open a source, run the action on its handle and close the source
-- again, even if the action throws.
withSoundSource :: SoundSource y h -> (h -> IO a) -> IO a
withSoundSource source action =
    bracket acquire release action
  where
    acquire = soundSourceOpen source
    release = soundSourceClose source

-- | Start an already opened source, run the action and stop the source
-- again, even if the action throws.
withSoundSourceRunning :: SoundSource y h -> h -> IO a -> IO a
withSoundSourceRunning src h body =
    bracket_ begin end body
  where
    begin = soundSourceStart src h
    end   = soundSourceStop src h

-- | Open a sink, run the action on its handle and close the sink again,
-- even if the action throws.
withSoundSink :: SoundSink y h -> (h -> IO a) -> IO a
withSoundSink sink action =
    bracket acquire release action
  where
    acquire = soundSinkOpen sink
    release = soundSinkClose sink

-- | Start an already opened sink, run the action and stop the sink again,
-- even if the action throws.
withSoundSinkRunning :: SoundSink y h -> h -> IO a -> IO a
withSoundSinkRunning sink h body =
    bracket_ begin end body
  where
    begin = soundSinkStart sink h
    end   = soundSinkStop sink h


-- Single-channel sample types. Each instance maps the Haskell type to the
-- corresponding PCM format constant and never looks at its argument.

-- | Unsigned 8 bit samples.
instance SampleFmt Word8 where
   sampleFmtToPcmFormat = const PCM.FormatU8

-- | Signed 8 bit samples.
instance SampleFmt Int8 where
   sampleFmtToPcmFormat = const PCM.FormatS8

-- | Unsigned 16 bit samples.
instance SampleFmt Word16 where
   sampleFmtToPcmFormat = const PCM.FormatU16

-- | Signed 16 bit samples.
instance SampleFmt Int16 where
   sampleFmtToPcmFormat = const PCM.FormatS16

-- | Unsigned 32 bit samples.
instance SampleFmt Word32 where
   sampleFmtToPcmFormat = const PCM.FormatU32

-- | Signed 32 bit samples.
instance SampleFmt Int32 where
   sampleFmtToPcmFormat = const PCM.FormatS32

-- | Single precision floating point samples.
instance SampleFmt Float where
   sampleFmtToPcmFormat = const PCM.FormatFloat

-- | Double precision floating point samples.
instance SampleFmt Double where
   sampleFmtToPcmFormat = const PCM.FormatFloat64

-- | Mu-law encoded samples.
instance SampleFmt MuLaw.T where
   sampleFmtToPcmFormat = const PCM.FormatMuLaw

-- | A stereo frame uses the PCM format of its channel type;
-- the format is looked up via the (unevaluated) left channel.
instance SampleFmt a => SampleFmt (Stereo.T a) where
   sampleFmtToPcmFormat = sampleFmtToPcmFormat . Stereo.left

-- | Turn a function that only needs the /type/ of a sample into a function
-- on 'SoundFmt'. The function receives 'undefined' and therefore must not
-- evaluate its argument.
withSampleFmt :: (y -> a) -> (SoundFmt y -> a)
withSampleFmt f = const (f undefined)


-- | Build a MIME type string for a sound format: the media type followed
-- by a @rate@ parameter and, unless there is exactly one channel,
-- a @channels@ parameter.
soundFmtMIME :: SampleFmt y => SoundFmt y -> String
soundFmtMIME fmt =
    concat [mediaType, rateParam, channelParam]
  where
    -- The media type is currently fixed, regardless of the sample type.
    mediaType = "audio/basic"
{-
    Earlier per-format selection, kept for reference:
        t = case sampleFmt fmt of
                SampleFmtLinear16BitSignedLE -> "audio/L16"
                SampleFmtMuLaw8Bit           -> "audio/basic"
-}
    rateParam = ";rate=" ++ show (sampleFreq fmt)
    channels = numChannels fmt
    channelParam =
        if channels == 1
          then ""
          else ";channels=" ++ show channels

-- | Number of channels of the sample type, as reported by
-- 'Frame.numberOfChannels'.
numChannels :: SampleFmt y => SoundFmt y -> Int
numChannels fmt = withSampleFmt Frame.numberOfChannels fmt

-- | Size of a single-channel sample of the sample type, as reported by
-- 'Frame.sizeOfElement' (presumably in bytes, going by this function's name).
audioBytesPerSample :: SampleFmt y => SoundFmt y -> Int
audioBytesPerSample fmt = withSampleFmt Frame.sizeOfElement fmt

{- |
Size of one frame: the number of channels times the size of one sample.

This assumes interleaved data.

Due to alignment constraints
a frame might occupy more than the calculated size
in an array in memory.
-}
audioBytesPerFrame :: SampleFmt y => SoundFmt y -> Int
audioBytesPerFrame fmt = channels * bytesPerSample
  where
    channels       = numChannels fmt
    bytesPerSample = audioBytesPerSample fmt

-- | 'audioBytesPerFrame' applied to the format of a source.
soundSourceBytesPerFrame :: SampleFmt y => SoundSource y h -> Int
soundSourceBytesPerFrame source = audioBytesPerFrame (soundSourceFmt source)

-- | 'audioBytesPerFrame' applied to the format of a sink.
soundSinkBytesPerFrame :: SampleFmt y => SoundSink y h -> Int
soundSinkBytesPerFrame sink = audioBytesPerFrame (soundSinkFmt sink)

-- | Copy everything from a source to a sink.
--
-- Both ends are opened for the duration of the copy and closed afterwards.
-- Data is moved through a temporary buffer until a read reports that
-- no (more) elements were delivered.
copySound :: SampleFmt y =>
             SoundSource y h1
          -> SoundSink y h2
          -> Int -- ^ Buffer size (in sample frames) to use
          -> IO ()
copySound source sink bufSize =
    allocaArray bufSize $ \buf ->
    withSoundSource source $ \from ->
    withSoundSink sink $ \to ->
        pump buf from to
  where
    -- Read one chunk and forward it; stop on the first empty read.
    pump buf from to = do
        count <- soundSourceRead source from buf bufSize
        if count > 0
          then soundSinkWrite sink to buf count >> pump buf from to
          else return ()

--
-- * Alsa stuff
--


-- | Open an ALSA PCM device and configure it for the given format.
--
-- The device is opened, hardware and software parameters are set from the
-- sample type, sample frequency and requested buffer timing, and the
-- device is prepared. For playback streams one period of zero-filled
-- data is written up front.
--
-- If any step after opening fails, the freshly opened handle is closed
-- again before the exception is propagated, so a failed configuration
-- does not leave the device open.
alsaOpen :: SampleFmt y =>
           String -- ^ device, e.g @"default"@
        -> SoundFmt y
        -> SoundBufferTime
        -> PCM.Stream
        -> IO Pcm
alsaOpen dev fmt time stream = AlsaExc.rethrow $
    do Debug.put "alsaOpen"
       -- bracketOnError releases the handle only when the setup throws;
       -- on success the handle is returned to the caller still open.
       bracketOnError (PCM.open dev stream 0) PCM.close $ \h ->
         do Debug.put $ "requested buffer_time = " ++ show (bufferTime time)
            Debug.put $ "requested period_time = " ++ show (periodTime time)
            (buffer_time,buffer_size,period_time,period_size) <-
                setHwParams h (withSampleFmt sampleFmtToPcmFormat fmt)
                              (numChannels fmt)
                              (sampleFreq fmt)
                              (bufferTime time)
                              (periodTime time)
            setSwParams h buffer_size period_size
            PCM.prepare h
            Debug.put $ "buffer_time = " ++ show buffer_time
            Debug.put $ "buffer_size = " ++ show buffer_size
            Debug.put $ "period_time = " ++ show period_time
            Debug.put $ "period_size = " ++ show period_size
            -- Pre-fill a playback stream with one zeroed period.
            when (stream == PCM.StreamPlayback) $
              callocaArray fmt period_size $ \buf ->
                 PCM.writei h buf period_size >> return ()
            return h


-- | Configure the hardware parameters of a PCM handle: interleaved
-- read/write access, sample format, channel count and sample rate are set
-- exactly, buffer and period time are set to the nearest supported values.
-- Returns the timings and sizes that were actually selected.
setHwParams :: Pcm
            -> PCM.Format
            -> Int -- ^ number of channels
            -> SampleFreq -- ^ sample frequency
            -> Time -- ^ buffer time
            -> Time -- ^ period time
            -> IO (Int,Int,Int,Int)
               -- ^ (buffer_time,buffer_size,period_time,period_size)
setHwParams h format channels rate buffer_time period_time =
    withHwParams h $ \params -> do
        PCM.hw_params_set_access h params PCM.AccessRwInterleaved
        PCM.hw_params_set_format h params format
        PCM.hw_params_set_channels h params channels
        PCM.hw_params_set_rate h params rate EQ
        -- Buffer first, then period; each size is queried right after
        -- the corresponding time has been negotiated.
        (gotBufferTime, _) <-
            PCM.hw_params_set_buffer_time_near h params buffer_time EQ
        gotBufferSize <- PCM.hw_params_get_buffer_size params
        (gotPeriodTime, _) <-
            PCM.hw_params_set_period_time_near h params period_time EQ
        (gotPeriodSize, _) <- PCM.hw_params_get_period_size params
        return (gotBufferTime, gotBufferSize, gotPeriodTime, gotPeriodSize)

-- | Configure the software parameters of a PCM handle:
-- start threshold 0, minimum available frames equal to the period size,
-- and a transfer alignment of 1. The buffer size is currently unused.
setSwParams :: Pcm
            -> Int -- ^ buffer size
            -> Int -- ^ period size
            -> IO ()
setSwParams h _buffer_size period_size =
    withSwParams h configure
  where
    configure params = do
        -- Alternative start threshold, kept for reference:
        --   let start_threshold =
        --          (buffer_size `div` period_size) * period_size
        --   PCM.sw_params_set_start_threshold h p start_threshold
        PCM.sw_params_set_start_threshold h params 0
        PCM.sw_params_set_avail_min h params period_size
        PCM.sw_params_set_xfer_align h params 1
        -- pad buffer with silence when needed
        --PCM.sw_params_set_silence_size h p period_size
        --PCM.sw_params_set_silence_threshold h p period_size

-- | Run an action on a temporary hardware parameter set.
-- The set is allocated, initialised with 'PCM.hw_params_any', passed to
-- the action, installed on the handle with 'PCM.hw_params' and freed
-- afterwards (also when an exception is thrown).
withHwParams :: Pcm -> (PCM.HwParams -> IO a) -> IO a
withHwParams h f =
    bracket PCM.hw_params_malloc PCM.hw_params_free $ \params ->
        PCM.hw_params_any h params >>
        f params >>= \result ->
        PCM.hw_params h params >>
        return result

-- | Run an action on a temporary software parameter set.
-- The set is allocated, initialised with 'PCM.sw_params_current', passed
-- to the action, installed on the handle with 'PCM.sw_params' and freed
-- afterwards (also when an exception is thrown).
withSwParams :: Pcm -> (PCM.SwParams -> IO a) -> IO a
withSwParams h f =
    bracket PCM.sw_params_malloc PCM.sw_params_free $ \params ->
        PCM.sw_params_current h params >>
        f params >>= \result ->
        PCM.sw_params h params >>
        return result

-- | Drain the device and then close the handle.
alsaClose :: Pcm -> IO ()
alsaClose pcm =
    AlsaExc.rethrow $
        Debug.put "alsaClose" >> PCM.drain pcm >> PCM.close pcm

-- | Prepare the device and then start it.
alsaStart :: Pcm -> IO ()
alsaStart pcm =
    AlsaExc.rethrow $
        Debug.put "alsaStart" >> PCM.prepare pcm >> PCM.start pcm


-- | Stop the device by draining it.
--
-- FIXME: use PCM.drain for sinks?
-- NOTE(review): the code already calls 'PCM.drain' for sources and sinks
-- alike; confirm whether sources were meant to be stopped differently.
alsaStop :: Pcm -> IO ()
alsaStop pcm =
    AlsaExc.rethrow $
        Debug.put "alsaStop" >> PCM.drain pcm

-- | Read exactly @n@ elements from the device into the buffer,
-- issuing as many 'PCM.readi' calls as necessary, and return the number
-- of elements read.
--
-- On a buffer over-run the device is prepared again and the read is
-- retried from the current position. The handler reports zero elements
-- for the failed attempt, so the retry goes through the regular loop and
-- the running offset is counted only once; the result therefore never
-- exceeds @n@ and callers can safely use it as a buffer length.
alsaRead ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO Int
alsaRead h buf0 n =
   let go buf offset = do
          -- debug $ "Reading " ++ show n ++ " samples..."
          nread <-
             PCM.readi h buf (n-offset)
             `AlsaExc.catchXRun`
             do Debug.put "snd_pcm_readi reported buffer over-run"
                PCM.prepare h
                -- nothing was transferred; the loop below tries again
                return 0
          let newOffset = offset+nread
          -- debug $ "Got " ++ show n' ++ " samples."
          if newOffset < n
            then go (advancePtr buf nread) newOffset
            else return newOffset
   in  AlsaExc.rethrow $ go buf0 0


-- | Write @n@ elements from the buffer to the device,
-- discarding the count reported by 'alsaWrite_'.
alsaWrite ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO ()
alsaWrite h buf n = AlsaExc.rethrow $ do
    _ <- alsaWrite_ h buf n
    return ()

-- | Write exactly @n@ elements from the buffer to the device,
-- issuing as many 'PCM.writei' calls as necessary, and return the number
-- of elements written.
--
-- Each call only asks for the elements that are still outstanding
-- (@n - offset@); asking for @n@ again after a partial write would read
-- past the end of the caller's buffer.
--
-- On a buffer under-run the device is prepared again and the write is
-- retried from the current position. The handler reports zero elements
-- for the failed attempt, so the running offset is counted only once.
alsaWrite_ ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO Int
alsaWrite_ h buf0 n =
   let go buf offset = do
          --debug $ "Writing " ++ show n ++ " samples..."
          nwritten <-
             PCM.writei h buf (n-offset)
             `AlsaExc.catchXRun`
             do Debug.put "snd_pcm_writei reported buffer under-run"
                PCM.prepare h
                -- nothing was transferred; the loop below tries again
                return 0
          let newOffset = offset+nwritten
          --debug $ "Wrote " ++ show n' ++ " samples."
          if newOffset < n
            then go (advancePtr buf nwritten) newOffset
            else return newOffset
   in  AlsaExc.rethrow $ go buf0 0


-- | ALSA capture source for the given device name and format,
-- using 'defaultBufferTime'.
alsaSoundSource ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundSource y Pcm
alsaSoundSource device format =
   alsaSoundSourceTime device format defaultBufferTime

-- | ALSA playback sink for the given device name and format,
-- using 'defaultBufferTime'.
alsaSoundSink ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundSink y Pcm
alsaSoundSink device format =
   alsaSoundSinkTime device format defaultBufferTime

-- | ALSA capture source for the given device name, format and
-- buffer timing. The device is only opened when 'soundSourceOpen' runs.
alsaSoundSourceTime ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundBufferTime -> SoundSource y Pcm
alsaSoundSourceTime dev fmt time =
    SoundSource {
        soundSourceFmt   = fmt,
        soundSourceOpen  = alsaOpen dev fmt time PCM.StreamCapture,
        soundSourceClose = alsaClose,
        soundSourceStart = alsaStart,
        soundSourceStop  = alsaStop,
        soundSourceRead  = alsaRead
    }

-- | ALSA playback sink for the given device name, format and
-- buffer timing. The device is only opened when 'soundSinkOpen' runs.
alsaSoundSinkTime ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundBufferTime -> SoundSink y Pcm
alsaSoundSinkTime dev fmt time =
    SoundSink {
        soundSinkFmt   = fmt,
        soundSinkOpen  = alsaOpen dev fmt time PCM.StreamPlayback,
        soundSinkClose = alsaClose,
        soundSinkWrite = alsaWrite,
        soundSinkStart = alsaStart,
        soundSinkStop  = alsaStop
    }

--
-- * File stuff
--

{- |
Read up to @n@ elements from a file handle into the buffer and
return the number of complete elements that were read.

This expects pad bytes that are needed in memory
in order to satisfy alignment constraints.
This is only a problem for sample sizes like 24 bit.
-}
fileRead ::
   SampleFmt y =>
   Handle -> Ptr y -> Int -> IO Int
fileRead h buf n = do
   bytesRead <- IO.hGetBuf h buf (arraySize buf n)
   -- convert the byte count back to a count of whole elements
   return (bytesRead `div` arraySize buf 1)

{- |
Write @n@ elements from the buffer to a file handle.

Same restrictions as for 'fileRead'.
-}
fileWrite ::
   SampleFmt y =>
   Handle -> Ptr y -> Int -> IO ()
fileWrite h buf n =
   IO.hPutBuf h buf byteCount
  where
   byteCount = arraySize buf n

-- | Source that reads raw sample data from a binary file.
-- Starting and stopping are no-ops inherited from 'nullSoundSource'.
fileSoundSource ::
   SampleFmt y =>
   FilePath -> SoundFmt y -> SoundSource y Handle
fileSoundSource file fmt =
    template {
        soundSourceOpen  = openBinaryFile file ReadMode,
        soundSourceClose = hClose,
        soundSourceRead  = fileRead
    }
  where
    template = nullSoundSource fmt

-- | Sink that writes raw sample data to a binary file.
-- Starting and stopping are no-ops inherited from 'nullSoundSink'.
fileSoundSink ::
   SampleFmt y =>
   FilePath -> SoundFmt y -> SoundSink y Handle
fileSoundSink file fmt =
    template {
        soundSinkOpen  = openBinaryFile file WriteMode,
        soundSinkClose = hClose,
        soundSinkWrite = fileWrite
    }
  where
    template = nullSoundSink fmt

--
-- * Marshalling utilities
--

-- | Like 'allocaArray', but the temporary array is filled with zero bytes
-- before the action runs. The 'SoundFmt' argument only fixes the element
-- type and is otherwise ignored.
callocaArray :: Storable y => SoundFmt y -> Int -> (Ptr y -> IO b) -> IO b
callocaArray _ n f =
   allocaArray n $ \p -> do
      clearBytes p (arraySize p n)
      f p

-- | Set the given number of bytes starting at the pointer to zero.
clearBytes :: Ptr a -> Int -> IO ()
clearBytes p n = do
   _ <- memset p 0 (fromIntegral n)
   return ()

-- | Number of bytes occupied by @n@ array elements of the pointer's
-- element type, computed as the distance the pointer moves when it is
-- advanced by @n@ elements. The pointer itself is never dereferenced.
{-# INLINE arraySize #-}
arraySize :: Storable y => Ptr y -> Int -> Int
arraySize p n = minusPtr (advancePtr p n) p

-- | Binding to C's @memset@ from @string.h@: destination pointer,
-- fill value and byte count; yields a pointer of the same type.
-- Used by 'clearBytes' to zero a buffer.
foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO (Ptr a)