{-# LANGUAGE ForeignFunctionInterface #-}
module Sound.ALSA.PCM
    (SampleFmt(..),
     SampleFreq,
     Time,
     SoundFmt(..),
     SoundSource(..),
     SoundSink(..),
     SoundBufferTime(..),
     Pcm,
     withSoundSource,
     withSoundSourceRunning,
     withSoundSink,
     withSoundSinkRunning,
     soundFmtMIME,
     audioBytesPerSample,
     audioBytesPerFrame,
     soundSourceBytesPerFrame,
     soundSinkBytesPerFrame,
     copySound,
     alsaSoundSource,
     alsaSoundSink,
     alsaSoundSourceTime,
     alsaSoundSinkTime,
     fileSoundSource,
     fileSoundSink,
    ) where

import Sound.ALSA.PCM.Core
import qualified Sound.ALSA.Exception as AlsaExc

import qualified Sound.Frame as Frame
import qualified Sound.Frame.Stereo as Stereo
import qualified Sound.Frame.MuLaw  as MuLaw

import Data.Word (Word8, Word16, Word32, )
import Data.Int (Int8, Int16, Int32, )

import Control.Concurrent (myThreadId, )
import Control.Exception (bracket, bracket_, onException, )
import Control.Monad (liftM, when, )
import Foreign.Marshal.Array (advancePtr, allocaArray, )
import Foreign.C (CSize, CInt, )
import Foreign (Storable, Ptr, minusPtr, )
import qualified System.IO as IO
import System.IO
   (IOMode(ReadMode, WriteMode), Handle, openBinaryFile, hClose, )

--
-- * Generic sound API
--

-- | Frame types that can be transferred to and from ALSA.
--
-- 'sampleFmtToPcmFormat' names the ALSA format of a single channel of the
-- frame; its argument only selects the instance and is not evaluated by
-- the instances defined in this module.
class (Storable y, Frame.C y) => SampleFmt y where
  sampleFmtToPcmFormat :: y -> PcmFormat

-- | Sampling rate in Hz (passed to ALSA as the hardware rate).
type SampleFreq = Int

-- | Format of a stream of frames of type @y@.  The frame type fixes the
-- sample encoding and the channel count; only the rate is stored here.
data SoundFmt y = SoundFmt
  { sampleFreq :: SampleFreq
  }
  deriving (Show)

-- | A duration in microseconds (e.g. @500000@ is half a second).
type Time = Int

-- | Requested ALSA buffer and period durations, in microseconds.
data SoundBufferTime = SoundBufferTime
  { bufferTime, periodTime :: Time
  }
  deriving (Show)


-- | A producer of frames of type @y@, used through a handle of type
-- @handle@ obtained from 'soundSourceOpen'.
--
-- Counts are in elements of @y@ (frames), not bytes.  Multi-channel data
-- is interleaved.
data SoundSource y handle = SoundSource
  { soundSourceFmt   :: SoundFmt y
  , soundSourceOpen  :: IO handle
  , soundSourceClose :: handle -> IO ()
  , soundSourceStart :: handle -> IO ()
  , soundSourceStop  :: handle -> IO ()
  , soundSourceRead  :: handle -> Ptr y -> Int -> IO Int
    -- ^ Fill the buffer with up to the given number of frames and return
    -- how many were delivered; 'copySound' stops on a result that is not
    -- positive.
  }

-- | A consumer of frames of type @y@, used through a handle of type
-- @handle@ obtained from 'soundSinkOpen'.  Counts are in frames.
data SoundSink y handle = SoundSink
  { soundSinkFmt   :: SoundFmt y
  , soundSinkOpen  :: IO handle
  , soundSinkClose :: handle -> IO ()
  , soundSinkWrite :: handle -> Ptr y -> Int -> IO ()
    -- ^ Consume the given number of frames from the buffer.
  , soundSinkStart :: handle -> IO ()
  , soundSinkStop  :: handle -> IO ()
  }

--
-- * Defaults and do-nothing sources/sinks
--

-- | Buffer/period times used by 'alsaSoundSource' and 'alsaSoundSink':
-- half a second of buffering, split into periods of a tenth of a second.
defaultBufferTime :: SoundBufferTime
defaultBufferTime = SoundBufferTime { bufferTime = halfSecond, periodTime = tenthSecond }
  where
    halfSecond  = 500000 -- 0.5 s in microseconds
    tenthSecond = 100000 -- 0.1 s in microseconds

-- | A source with format @fmt@ whose operations do nothing: start, stop and
-- close are no-ops and every read reports 0 frames.
--
-- The handle produced by open is a placeholder that raises a descriptive
-- error if it is ever evaluated.  Real sources are built from this record
-- with record-update syntax (see 'alsaSoundSourceTime', 'fileSoundSource').
nullSoundSource :: SoundFmt y -> SoundSource y h
nullSoundSource fmt =
    SoundSource {
        soundSourceFmt   = fmt,
        soundSourceOpen  = return noHandle,
        soundSourceClose = \_ -> return (),
        soundSourceStart = \_ -> return (),
        soundSourceStop  = \_ -> return (),
        soundSourceRead  = \_ _ _ -> return 0
    }
  where
    -- A named error (rather than a bare 'undefined') makes accidental use
    -- of the placeholder handle easy to trace.
    noHandle = error "Sound.ALSA.PCM.nullSoundSource: placeholder handle was evaluated"

-- | A sink with format @fmt@ whose operations do nothing: start, stop and
-- close are no-ops and writes discard their input.
--
-- The handle produced by open is a placeholder that raises a descriptive
-- error if it is ever evaluated.  Real sinks are built from this record
-- with record-update syntax (see 'alsaSoundSinkTime', 'fileSoundSink').
nullSoundSink :: SoundFmt y -> SoundSink y h
nullSoundSink fmt =
    SoundSink {
        soundSinkFmt   = fmt,
        soundSinkOpen  = return noHandle,
        soundSinkClose = \_ -> return (),
        soundSinkStart = \_ -> return (),
        soundSinkStop  = \_ -> return (),
        soundSinkWrite = \_ _ _ -> return ()
    }
  where
    -- A named error (rather than a bare 'undefined') makes accidental use
    -- of the placeholder handle easy to trace.
    noHandle = error "Sound.ALSA.PCM.nullSoundSink: placeholder handle was evaluated"


-- | Open the source, run the action with its handle, and close the handle
-- afterwards, also when the action throws.
withSoundSource :: SoundSource y h -> (h -> IO a) -> IO a
withSoundSource src act =
    bracket (soundSourceOpen src) (soundSourceClose src) act

-- | Start the source before the action and stop it afterwards, also when
-- the action throws.
withSoundSourceRunning :: SoundSource y h -> h -> IO a -> IO a
withSoundSourceRunning src h act = bracket_ start stop act
  where
    start = soundSourceStart src h
    stop  = soundSourceStop  src h

-- | Open the sink, run the action with its handle, and close the handle
-- afterwards, also when the action throws.
withSoundSink :: SoundSink y h -> (h -> IO a) -> IO a
withSoundSink snk act =
    bracket (soundSinkOpen snk) (soundSinkClose snk) act

-- | Start the sink before the action and stop it afterwards, also when
-- the action throws.
withSoundSinkRunning :: SoundSink y h -> h -> IO a -> IO a
withSoundSinkRunning snk h act = bracket_ start stop act
  where
    start = soundSinkStart snk h
    stop  = soundSinkStop  snk h


-- The scalar instances ignore their argument: only the type selects the
-- ALSA format.

instance SampleFmt Word8 where
  sampleFmtToPcmFormat = const PcmFormatU8

instance SampleFmt Int8 where
  sampleFmtToPcmFormat = const PcmFormatS8

instance SampleFmt Word16 where
  sampleFmtToPcmFormat = const PcmFormatU16

instance SampleFmt Int16 where
  sampleFmtToPcmFormat = const PcmFormatS16

instance SampleFmt Word32 where
  sampleFmtToPcmFormat = const PcmFormatU32

instance SampleFmt Int32 where
  sampleFmtToPcmFormat = const PcmFormatS32

instance SampleFmt Float where
  sampleFmtToPcmFormat = const PcmFormatFloat

instance SampleFmt Double where
  sampleFmtToPcmFormat = const PcmFormatFloat64

instance SampleFmt MuLaw.T where
  sampleFmtToPcmFormat = const PcmFormatMuLaw

-- | A stereo frame uses the per-channel format of its element type.
instance SampleFmt a => SampleFmt (Stereo.T a) where
  sampleFmtToPcmFormat = sampleFmtToPcmFormat . Stereo.left

-- | Apply a function that only looks at the /type/ of a frame to the frame
-- type of a 'SoundFmt'.  The function receives a bottom value, so it must
-- not evaluate its argument.
withSampleFmt :: (y -> a) -> (SoundFmt y -> a)
withSampleFmt f = const (f undefined)


-- | MIME type describing a stream of the given format, e.g.
-- @audio/basic;rate=8000;channels=2@.  The @channels@ parameter is omitted
-- for mono streams.
--
-- NOTE(review): the media type is always @audio/basic@, whatever the
-- actual sample type is.  A per-format mapping (e.g. @audio/L16@ for
-- 16-bit linear samples) was sketched but is disabled.
soundFmtMIME :: SampleFmt y => SoundFmt y -> String
soundFmtMIME fmt =
    concat ["audio/basic", rateParam, channelsParam]
  where
    rateParam = ";rate=" ++ show (sampleFreq fmt)
    chans = numChannels fmt
    channelsParam
      | chans == 1 = ""
      | otherwise  = ";channels=" ++ show chans

-- | Number of interleaved channels in one frame of the format.
numChannels :: SampleFmt y => SoundFmt y -> Int
numChannels fmt = withSampleFmt Frame.numberOfChannels fmt

-- | Size in bytes of a single channel sample of the format.
audioBytesPerSample :: SampleFmt y => SoundFmt y -> Int
audioBytesPerSample fmt = withSampleFmt Frame.sizeOfElement fmt

{- |
Bytes per interleaved frame: bytes per sample times number of channels.

Due to alignment constraints a frame might occupy more than this size
in an array in memory.
-}
audioBytesPerFrame :: SampleFmt y => SoundFmt y -> Int
audioBytesPerFrame fmt = audioBytesPerSample fmt * numChannels fmt

-- | 'audioBytesPerFrame' of a source's format.
soundSourceBytesPerFrame :: SampleFmt y => SoundSource y h -> Int
soundSourceBytesPerFrame src = audioBytesPerFrame (soundSourceFmt src)

-- | 'audioBytesPerFrame' of a sink's format.
soundSinkBytesPerFrame :: SampleFmt y => SoundSink y h -> Int
soundSinkBytesPerFrame snk = audioBytesPerFrame (soundSinkFmt snk)

-- | Copy frames from a source to a sink through a buffer of @bufSize@
-- frames, until a read from the source returns a non-positive count.
--
-- Both ends are opened before and closed after the copy.  Their start and
-- stop actions are not invoked.
copySound :: SampleFmt y =>
             SoundSource y h1
          -> SoundSink y h2
          -> Int -- ^ Buffer size (in sample frames) to use
          -> IO ()
copySound source sink bufSize =
    allocaArray bufSize $ \buf ->
      withSoundSource source $ \from ->
        withSoundSink sink $ \to ->
          let pump = do
                got <- soundSourceRead source from buf bufSize
                if got > 0
                  then soundSinkWrite sink to buf got >> pump
                  else return ()
          in pump

--
-- * Alsa stuff
--


-- | Print a diagnostic line to stderr, prefixed with the current thread id.
debug :: String -> IO ()
debug msg = do
    tid <- myThreadId
    IO.hPutStrLn IO.stderr (show tid ++ ": " ++ msg)

-- | Open ALSA device @dev@ for the given stream direction and configure it
-- for frames of type @y@ at the format's rate, requesting the buffer and
-- period times from @time@.  The negotiated values are printed via 'debug'.
--
-- For playback streams, one period of zeroed frames is written after the
-- device is prepared (presumably to prime the buffer; TODO confirm).
--
-- If any configuration step throws, the freshly opened handle is closed
-- before the exception propagates.  Otherwise it would leak, because
-- callers such as 'withSoundSource' only close handles that were returned.
alsaOpen :: SampleFmt y =>
            String -- ^ device, e.g @"default"@
         -> SoundFmt y
         -> SoundBufferTime
         -> PcmStream
         -> IO Pcm
alsaOpen dev fmt time stream = AlsaExc.rethrow $
    do debug "alsaOpen"
       h <- pcm_open dev stream 0
       configure h `onException` pcm_close h
       return h
  where
    -- Everything after pcm_open; any failure here triggers pcm_close above.
    configure h =
       do (buffer_time,buffer_size,period_time,period_size) <-
              setHwParams h (withSampleFmt sampleFmtToPcmFormat fmt)
                            (numChannels fmt)
                            (sampleFreq fmt)
                            (bufferTime time)
                            (periodTime time)
          setSwParams h buffer_size period_size
          pcm_prepare h
          debug $ "buffer_time = " ++ show buffer_time
          debug $ "buffer_size = " ++ show buffer_size
          debug $ "period_time = " ++ show period_time
          debug $ "period_size = " ++ show period_size
          when (stream == PcmStreamPlayback) $
            callocaArray fmt period_size $ \buf ->
              do _ <- pcm_writei h buf period_size
                 return ()


-- | Configure the hardware parameters of @h@ through 'withHwParams'.
--
-- It sets interleaved read/write access, the given sample format, channel
-- count and rate, and the buffer and period times nearest to the requested
-- ones.  It returns the buffer and period times and sizes the device
-- reports back.
setHwParams :: Pcm
            -> PcmFormat
            -> Int -- ^ number of channels
            -> SampleFreq -- ^ sample frequency
            -> Time -- ^ buffer time
            -> Time -- ^ period time
            -> IO (Int,Int,Int,Int)
               -- ^ (buffer_time,buffer_size,period_time,period_size)
setHwParams h format channels rate buffer_time period_time =
    withHwParams h configure
  where
    -- The calls must stay in this order: buffer size is queried after the
    -- buffer time is chosen, period size after the period time.
    configure p =
      do pcm_hw_params_set_access h p PcmAccessRwInterleaved
         pcm_hw_params_set_format h p format
         pcm_hw_params_set_channels h p channels
         pcm_hw_params_set_rate h p rate EQ
         (bufTime, _) <- pcm_hw_params_set_buffer_time_near h p buffer_time EQ
         bufSize <- pcm_hw_params_get_buffer_size p
         (perTime, _) <- pcm_hw_params_set_period_time_near h p period_time EQ
         (perSize, _) <- pcm_hw_params_get_period_size p
         return (bufTime, bufSize, perTime, perSize)

-- | Configure the software parameters of @h@ through 'withSwParams'.
--
-- It sets a start threshold of 0, an avail_min of one period, and a
-- transfer alignment of 1.  The buffer size argument is currently unused.
setSwParams :: Pcm
            -> Int -- ^ buffer size
            -> Int -- ^ period size
            -> IO ()
setSwParams h _buffer_size period_size =
    withSwParams h $ \swp ->
      do -- An earlier variant used a start threshold of whole periods:
         --   pcm_sw_params_set_start_threshold h swp
         --     ((buffer_size `div` period_size) * period_size)
         pcm_sw_params_set_start_threshold h swp 0
         pcm_sw_params_set_avail_min h swp period_size
         pcm_sw_params_set_xfer_align h swp 1
         -- Padding the buffer with silence when needed is disabled:
         --   pcm_sw_params_set_silence_size h swp period_size
         --   pcm_sw_params_set_silence_threshold h swp period_size

-- | Run @f@ on a freshly allocated hardware-parameter object for @h@, then
-- install the configured parameters on the device with 'pcm_hw_params'.
--
-- The object is initialised with 'pcm_hw_params_any' before @f@ runs.
-- 'bracket' guarantees that 'pcm_hw_params_free' runs even when @f@ or
-- 'pcm_hw_params' throws.  Without it, the object would leak on errors.
withHwParams :: Pcm -> (PcmHwParams -> IO a) -> IO a
withHwParams h f =
    bracket pcm_hw_params_malloc pcm_hw_params_free $ \p ->
      do pcm_hw_params_any h p
         x <- f p
         pcm_hw_params h p
         return x

-- | Run @f@ on a freshly allocated software-parameter object for @h@, then
-- install the configured parameters on the device with 'pcm_sw_params'.
--
-- The object is filled with the device's current settings
-- ('pcm_sw_params_current') before @f@ runs.  'bracket' guarantees that
-- 'pcm_sw_params_free' runs even when @f@ or 'pcm_sw_params' throws.
withSwParams :: Pcm -> (PcmSwParams -> IO a) -> IO a
withSwParams h f =
    bracket pcm_sw_params_malloc pcm_sw_params_free $ \p ->
      do pcm_sw_params_current h p
         x <- f p
         pcm_sw_params h p
         return x

-- | Drain and then close the handle.
--
-- If 'pcm_drain' throws, the handle is still closed before the exception
-- propagates.  Previously a failing drain skipped 'pcm_close' and leaked
-- the handle.
alsaClose :: Pcm -> IO ()
alsaClose pcm = AlsaExc.rethrow $
    do debug "alsaClose"
       pcm_drain pcm `onException` pcm_close pcm
       pcm_close pcm

-- | Prepare the device again and start the stream.
alsaStart :: Pcm -> IO ()
alsaStart pcm =
    AlsaExc.rethrow (debug "alsaStart" >> pcm_prepare pcm >> pcm_start pcm)


-- | Stop the stream with 'pcm_drain'.
--
-- NOTE(review): the same call is used for capture (source) and playback
-- (sink) handles.  An older FIXME asked whether that is appropriate.
alsaStop :: Pcm -> IO ()
alsaStop pcm = AlsaExc.rethrow $ do
    debug "alsaStop"
    pcm_drain pcm

-- | Read up to @n@ frames into @buf@ and return how many were read.
--
-- Short reads are continued until @n@ frames have arrived.  On an over-run
-- the device is prepared again and the whole read is retried.
--
-- If a read makes no progress (0 frames, as 'pcm_readi' may report e.g. for
-- a draining capture stream), the partial count is returned.  Previously
-- the function retried endlessly in that case.
alsaRead ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO Int
alsaRead h buf n = AlsaExc.rethrow $
     do n' <- pcm_readi h buf n `AlsaExc.catchXRun` handleOverRun
        -- Only continue when the last read actually transferred frames.
        if n' > 0 && n' < n
          then do n'' <- alsaRead h (advancePtr buf n') (n - n')
                  return (n' + n'')
          else return n'
  where
    handleOverRun =
      do debug "snd_pcm_readi reported buffer over-run"
         pcm_prepare h
         alsaRead h buf n

-- | Write @n@ frames from @buf@ with 'alsaWrite_', discarding the frame
-- count it reports.
alsaWrite ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO ()
alsaWrite h buf n = AlsaExc.rethrow (alsaWrite_ h buf n >> return ())

-- | Write @n@ frames from @buf@ and return the number of frames written.
--
-- Partial writes are continued with the remaining frames.  On an
-- under-run the device is prepared again and the whole write is retried.
--
-- NOTE(review): if 'pcm_writei' ever reports 0 frames, this loop retries
-- without progress.  Confirm that cannot happen for the blocking handles
-- opened by 'alsaOpen'.
alsaWrite_ ::
   SampleFmt y =>
   Pcm -> Ptr y -> Int -> IO Int
alsaWrite_ h buf n =
     do written <- pcm_writei h buf n `AlsaExc.catchXRun` recover
        if written == n
          then return written
          else do rest <- alsaWrite_ h (advancePtr buf written) (n - written)
                  return (written + rest)
  where
    recover =
      do debug "snd_pcm_writei reported buffer under-run"
         pcm_prepare h
         alsaWrite_ h buf n


-- | ALSA capture from device @dev@, using 'defaultBufferTime'.
alsaSoundSource ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundSource y Pcm
alsaSoundSource dev fmt = alsaSoundSourceTime dev fmt defaultBufferTime

-- | ALSA playback to device @dev@, using 'defaultBufferTime'.
alsaSoundSink ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundSink y Pcm
alsaSoundSink dev fmt = alsaSoundSinkTime dev fmt defaultBufferTime

-- | ALSA capture from device @dev@ with explicit buffer/period times.
alsaSoundSourceTime ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundBufferTime -> SoundSource y Pcm
alsaSoundSourceTime dev fmt time =
    (nullSoundSource fmt)
      { soundSourceOpen  = alsaOpen dev fmt time PcmStreamCapture
      , soundSourceClose = alsaClose
      , soundSourceStart = alsaStart
      , soundSourceStop  = alsaStop
      , soundSourceRead  = alsaRead
      }

-- | ALSA playback to device @dev@ with explicit buffer/period times.
alsaSoundSinkTime ::
   SampleFmt y =>
   String -> SoundFmt y -> SoundBufferTime -> SoundSink y Pcm
alsaSoundSinkTime dev fmt time =
    (nullSoundSink fmt)
      { soundSinkOpen  = alsaOpen dev fmt time PcmStreamPlayback
      , soundSinkClose = alsaClose
      , soundSinkStart = alsaStart
      , soundSinkStop  = alsaStop
      , soundSinkWrite = alsaWrite
      }

--
-- * File stuff
--

{- |
Read up to @n@ frames of raw data from a binary handle into @buf@ and
return the number of complete frames read.

This expects the file to contain the pad bytes that are needed in memory
to satisfy alignment constraints.
This is only a problem for sample sizes like 24 bit.
-}
fileRead ::
   SampleFmt y =>
   Handle -> Ptr y -> Int -> IO Int
fileRead h buf n =
   liftM toFrames (IO.hGetBuf h buf byteCount)
  where
    byteCount = arraySize buf n
    -- A trailing partial frame is dropped by the integer division.
    toFrames bytes = bytes `div` arraySize buf 1

{- |
Write @n@ frames of raw data from @buf@ to a binary handle.
Same restrictions as for 'fileRead'.
-}
fileWrite ::
   SampleFmt y =>
   Handle -> Ptr y -> Int -> IO ()
fileWrite h buf = IO.hPutBuf h buf . arraySize buf

-- | Source reading raw frames (no header) from @file@ with 'fileRead'.
-- Start and stop are no-ops.
fileSoundSource ::
   SampleFmt y =>
   FilePath -> SoundFmt y -> SoundSource y Handle
fileSoundSource file fmt =
    (nullSoundSource fmt)
      { soundSourceOpen  = openBinaryFile file ReadMode
      , soundSourceClose = hClose
      , soundSourceRead  = fileRead
      }

-- | Sink writing raw frames (no header) to @file@ with 'fileWrite'.
-- The file is opened in 'WriteMode'.  Start and stop are no-ops.
fileSoundSink ::
   SampleFmt y =>
   FilePath -> SoundFmt y -> SoundSink y Handle
fileSoundSink file fmt =
    (nullSoundSink fmt)
      { soundSinkOpen  = openBinaryFile file WriteMode
      , soundSinkClose = hClose
      , soundSinkWrite = fileWrite
      }

--
-- * Marshalling utilities
--

-- | Like 'allocaArray', but the @n@-element buffer is zero-filled before
-- it is passed to the action.  The 'SoundFmt' argument only fixes the
-- element type.
callocaArray :: Storable y => SoundFmt y -> Int -> (Ptr y -> IO b) -> IO b
callocaArray _fmt n act =
   allocaArray n $ \buf -> do
      clearBytes buf (arraySize buf n)
      act buf

-- | Set @n@ bytes starting at @p@ to zero.
clearBytes :: Ptr a -> Int -> IO ()
clearBytes p n = do
   _ <- memset p 0 (fromIntegral n)
   return ()

-- | Number of bytes spanned by @count@ consecutive elements starting at
-- @ptr@: the size of a @count@-element array, as laid out by 'advancePtr'.
arraySize :: Storable y => Ptr y -> Int -> Int
arraySize ptr count = minusPtr (advancePtr ptr count) ptr
{-# INLINE arraySize #-}

-- | C @memset@ (header @string.h@): fill the given number of bytes at the
-- pointer with the 'CInt' value, and return the pointer.
-- Imported as an @unsafe@ call; used here only to zero buffers.
foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO (Ptr a)