module Sound.Alsa
    (SampleFmt(..),
     SampleFreq,
     Time,
     SoundFmt(..),
     SoundSource(..),
     SoundSink(..),
     SoundBufferTime(..),
     withSoundSource,
     withSoundSourceRunning,
     withSoundSink,
     withSoundSinkRunning,
     soundFmtMIME,
     audioBytesPerSample,
     audioBytesPerFrame,
     soundSourceBytesPerFrame,
     soundSinkBytesPerFrame,
     soundSourceReadBytes,
     soundSinkWriteBytes,
     copySound,
     alsaSoundSource,
     alsaSoundSink,
     alsaSoundSourceTime,
     alsaSoundSinkTime,
     fileSoundSource,
     fileSoundSink,
    ) where

import Sound.Alsa.Core
import Sound.Alsa.Error

import Control.Concurrent
import Control.Exception (bracket, bracket_, bracketOnError)
import Control.Monad (liftM,when)
import Foreign
import Foreign.C
import System.IO

--
-- * Generic sound API
--

-- | Supported encodings of a single audio sample.
data SampleFmt = SampleFmtLinear16BitSignedLE -- ^ 2 bytes per sample; reported as MIME type @audio/L16@
               | SampleFmtMuLaw8Bit -- ^ 1 byte per sample; reported as MIME type @audio/basic@
  deriving (Show)

-- | Sample frequency; emitted verbatim as the MIME @rate@ parameter
-- (presumably in Hz -- confirm).
type SampleFreq = Int

-- | Complete description of the format of an audio stream.
data SoundFmt = SoundFmt
    { sampleFmt   :: SampleFmt   -- ^ encoding of each individual sample
    , sampleFreq  :: SampleFreq  -- ^ sample frequency
    , numChannels :: Int         -- ^ number of channels
    }
  deriving (Show)

-- | A duration. The default buffer time of 500000 is annotated as 0.5s in
-- this module, i.e. values are in microseconds.
type Time = Int

-- | Requested buffering for an ALSA device.
data SoundBufferTime = SoundBufferTime
    { bufferTime :: Time  -- ^ requested length of the whole buffer
    , periodTime :: Time  -- ^ requested length of one period
    }
  deriving (Show)


-- | A producer of audio data, parameterised over the handle type of its
-- backend.
--
-- Counts are in samples, not bytes. Multi-channel data is interleaved.
data SoundSource handle = SoundSource
    { soundSourceFmt   :: SoundFmt
      -- ^ format of the data delivered by this source
    , soundSourceOpen  :: IO handle
      -- ^ acquire the backend handle
    , soundSourceClose :: handle -> IO ()
      -- ^ release the backend handle
    , soundSourceStart :: handle -> IO ()
      -- ^ start the source
    , soundSourceStop  :: handle -> IO ()
      -- ^ stop the source
    , soundSourceRead  :: handle -> Ptr () -> Int -> IO Int
      -- ^ read up to the given count into the buffer; returns the count read
    }

-- | A consumer of audio data, parameterised over the handle type of its
-- backend.
data SoundSink handle = SoundSink
    { soundSinkFmt   :: SoundFmt
      -- ^ format of the data accepted by this sink
    , soundSinkOpen  :: IO handle
      -- ^ acquire the backend handle
    , soundSinkClose :: handle -> IO ()
      -- ^ release the backend handle
    , soundSinkWrite :: handle -> Ptr () -> Int -> IO ()
      -- ^ write the given count of data from the buffer
    , soundSinkStart :: handle -> IO ()
      -- ^ start the sink
    , soundSinkStop  :: handle -> IO ()
      -- ^ stop the sink
    }

--
-- * Defaults, null implementations and generic helpers
--

-- | Buffering used by 'alsaSoundSource' and 'alsaSoundSink':
-- a 0.5s buffer split into 0.1s periods.
defaultBufferTime :: SoundBufferTime
defaultBufferTime = SoundBufferTime { bufferTime = halfSecond, periodTime = tenthSecond }
  where
    halfSecond  = 500000
    tenthSecond = 100000

-- | A source of the given format that never produces data: every action is
-- a no-op and reads always return 0. Used as a base for record updates.
--
-- The handle it \"opens\" is 'undefined'; none of the null actions inspect it.
nullSoundSource :: SoundFmt -> SoundSource h
nullSoundSource fmt = SoundSource
    { soundSourceFmt   = fmt
    , soundSourceOpen  = return undefined
    , soundSourceClose = ignore
    , soundSourceStart = ignore
    , soundSourceStop  = ignore
    , soundSourceRead  = \_ _ _ -> return 0
    }
  where
    ignore _ = return ()

-- | A sink of the given format that discards everything: every action is a
-- no-op. Used as a base for record updates.
--
-- The handle it \"opens\" is 'undefined'; none of the null actions inspect it.
nullSoundSink :: SoundFmt -> SoundSink h
nullSoundSink fmt = SoundSink
    { soundSinkFmt   = fmt
    , soundSinkOpen  = return undefined
    , soundSinkClose = ignore
    , soundSinkStart = ignore
    , soundSinkStop  = ignore
    , soundSinkWrite = \_ _ _ -> return ()
    }
  where
    ignore _ = return ()


-- | Open the source, run the action on its handle, and close the source
-- again, also when the action throws.
withSoundSource :: SoundSource h -> (h -> IO a) -> IO a
withSoundSource source body = bracket open close body
  where
    open  = soundSourceOpen source
    close = soundSourceClose source

-- | Start an already opened source, run the action, and stop the source
-- again, also when the action throws.
withSoundSourceRunning :: SoundSource h -> h -> IO a -> IO a
withSoundSourceRunning src h action = bracket_ start stop action
  where
    start = soundSourceStart src h
    stop  = soundSourceStop src h

-- | Open the sink, run the action on its handle, and close the sink again,
-- also when the action throws.
withSoundSink :: SoundSink h -> (h -> IO a) -> IO a
withSoundSink sink body = bracket open close body
  where
    open  = soundSinkOpen sink
    close = soundSinkClose sink

-- | Start an already opened sink, run the action, and stop the sink again,
-- also when the action throws.
withSoundSinkRunning :: SoundSink h -> h -> IO a -> IO a
withSoundSinkRunning sink h action = bracket_ start stop action
  where
    start = soundSinkStart sink h
    stop  = soundSinkStop sink h

-- | Render a sound format as a MIME type string: the media type for the
-- sample encoding, followed by a @rate@ parameter and, unless the format is
-- single-channel, a @channels@ parameter.
soundFmtMIME :: SoundFmt -> String
soundFmtMIME fmt = concat [mediaType, rateParam, channelsParam]
  where
    channels = numChannels fmt

    mediaType =
        case sampleFmt fmt of
          SampleFmtLinear16BitSignedLE -> "audio/L16"
          SampleFmtMuLaw8Bit           -> "audio/basic"

    rateParam = ";rate=" ++ show (sampleFreq fmt)

    -- the channel count is only spelled out when it is not 1
    channelsParam
      | channels == 1 = ""
      | otherwise     = ";channels=" ++ show channels

-- | Size in bytes of one sample of one channel in the given format.
audioBytesPerSample :: SoundFmt -> Int
audioBytesPerSample = sampleSize . sampleFmt
  where
    sampleSize SampleFmtLinear16BitSignedLE = 2
    sampleSize SampleFmtMuLaw8Bit           = 1

-- | Size in bytes of one frame, i.e. one sample for every channel.
-- Assumes interleaved data.
audioBytesPerFrame :: SoundFmt -> Int
audioBytesPerFrame fmt = channels * sampleBytes
  where
    channels    = numChannels fmt
    sampleBytes = audioBytesPerSample fmt

-- | Size in bytes of one frame of the data delivered by a source.
soundSourceBytesPerFrame :: SoundSource h -> Int
soundSourceBytesPerFrame src = audioBytesPerFrame (soundSourceFmt src)

-- | Size in bytes of one frame of the data accepted by a sink.
soundSinkBytesPerFrame :: SoundSink h -> Int
soundSinkBytesPerFrame sink = audioBytesPerFrame (soundSinkFmt sink)

-- | Byte-oriented wrapper around 'soundSourceRead'.
--
-- The requested byte count is converted to whole frames (rounding down)
-- before reading, and the number of frames read is converted back to bytes.
soundSourceReadBytes :: SoundSource h -> h -> Ptr () -> Int -> IO Int
soundSourceReadBytes src h buf n = do
    frames <- soundSourceRead src h buf (n `div` frameBytes)
    return (frames * frameBytes)
  where
    frameBytes = soundSourceBytesPerFrame src

-- | Byte-oriented wrapper around 'soundSinkWrite'.
--
-- The byte count is converted to whole frames (rounding down), so trailing
-- bytes that do not make up a complete frame are not written.
soundSinkWriteBytes :: SoundSink h -> h -> Ptr () -> Int -> IO ()
soundSinkWriteBytes dst h buf n =
    let frameBytes = soundSinkBytesPerFrame dst
        frames     = n `div` frameBytes
    in soundSinkWrite dst h buf frames

-- | Copy everything a source delivers into a sink.
--
-- Both ends are opened for the duration of the copy. Data is moved through
-- a temporary buffer until a read yields no more bytes.
copySound :: SoundSource h1
          -> SoundSink h2
          -> Int -- ^ Buffer size (in bytes) to use
          -> IO ()
copySound source sink bufSize =
    allocaBytes bufSize $ \buf ->
      withSoundSource source $ \from ->
        withSoundSink sink $ \to ->
          pump buf from to
  where
    -- read one buffer's worth and forward it; stop at the first empty read
    pump buf from to = do
        n <- soundSourceReadBytes source from buf bufSize
        if n > 0
          then soundSinkWriteBytes sink to buf n >> pump buf from to
          else return ()

--
-- * Alsa stuff
--


-- | Write a diagnostic line to stderr, prefixed with the id of the
-- calling thread.
debug :: String -> IO ()
debug s = myThreadId >>= \tid -> hPutStrLn stderr (show tid ++ ": " ++ s)

-- | Open and configure an ALSA PCM device.
--
-- The device is opened, its hardware and software parameters are set from
-- the given format and buffering, and it is prepared. For playback streams
-- one period of zeroed frames is written up front.
--
-- If any step after the open fails, the freshly opened handle is closed
-- again before the exception propagates, so it is not leaked.
alsaOpen :: String -- ^ device, e.g @"default"@
         -> SoundFmt
         -> SoundBufferTime
         -> PcmStream
         -> IO Pcm
alsaOpen dev fmt time stream = rethrowAlsaExceptions $
    do debug "alsaOpen"
       -- bracketOnError: keep the handle on success, close it on failure
       bracketOnError (pcm_open dev stream 0) pcm_close $ \h ->
         do (buffer_time,buffer_size,period_time,period_size) <-
                setHwParams h (sampleFmtToPcmFormat (sampleFmt fmt))
                              (numChannels fmt)
                              (sampleFreq fmt)
                              (bufferTime time)
                              (periodTime time)
            setSwParams h buffer_size period_size
            pcm_prepare h
            debug $ "buffer_time = " ++ show buffer_time
            debug $ "buffer_size = " ++ show buffer_size
            debug $ "period_time = " ++ show period_time
            debug $ "period_size = " ++ show period_size
            -- pre-fill playback devices with one period of zeroed frames
            when (stream == PcmStreamPlayback) $
              callocaBytes (audioBytesPerFrame fmt * period_size) $ \buf ->
                do pcm_writei h buf period_size
                   return ()
            return h

-- | Translate a generic sample format into the corresponding ALSA PCM format.
sampleFmtToPcmFormat :: SampleFmt -> PcmFormat
sampleFmtToPcmFormat sf =
    case sf of
      SampleFmtLinear16BitSignedLE -> PcmFormatS16Le
      SampleFmtMuLaw8Bit           -> PcmFormatMuLaw

-- | Configure the hardware parameters of a PCM: interleaved read/write
-- access, the given format, channel count and rate, and buffer and period
-- times as near as possible to the requested ones. Returns the values that
-- were actually obtained.
setHwParams :: Pcm
            -> PcmFormat
            -> Int -- ^ number of channels
            -> SampleFreq -- ^ sample frequency
            -> Time -- ^ buffer time
            -> Time -- ^ period time
            -> IO (Int,Int,Int,Int)
               -- ^ (buffer_time,buffer_size,period_time,period_size)
setHwParams h format channels rate buffer_time period_time =
    withHwParams h $ \params -> do
        pcm_hw_params_set_access h params PcmAccessRwInterleaved
        pcm_hw_params_set_format h params format
        pcm_hw_params_set_channels h params channels
        pcm_hw_params_set_rate h params rate EQ
        -- buffer first, then period; each size is queried right after
        -- the corresponding time has been negotiated
        (actualBufferTime, _) <-
            pcm_hw_params_set_buffer_time_near h params buffer_time EQ
        bufferSize <- pcm_hw_params_get_buffer_size params
        (actualPeriodTime, _) <-
            pcm_hw_params_set_period_time_near h params period_time EQ
        (periodSize, _) <- pcm_hw_params_get_period_size params
        return (actualBufferTime, bufferSize, actualPeriodTime, periodSize)

-- | Configure the software parameters of a PCM: a start threshold of 0,
-- a minimum-available size of one period, and a transfer alignment of 1.
-- The buffer size argument is currently unused.
setSwParams :: Pcm
            -> Int -- ^ buffer size
            -> Int -- ^ period size
            -> IO ()
setSwParams h _bufferSize periodSize =
    withSwParams h $ \params -> do
        -- disabled alternative for the start threshold:
        --   let start_threshold =
        --          (buffer_size `div` period_size) * period_size
        --   pcm_sw_params_set_start_threshold h p start_threshold
        pcm_sw_params_set_start_threshold h params 0
        pcm_sw_params_set_avail_min h params periodSize
        pcm_sw_params_set_xfer_align h params 1
        -- disabled: pad buffer with silence when needed
        --   pcm_sw_params_set_silence_size h p period_size
        --   pcm_sw_params_set_silence_threshold h p period_size

-- | Run an action on a freshly allocated hardware-parameter structure.
--
-- The structure is initialised with 'pcm_hw_params_any', handed to the
-- action, and then installed on the PCM with 'pcm_hw_params'. It is freed
-- afterwards in every case, including when initialisation, the action or
-- the installation throws.
withHwParams :: Pcm -> (PcmHwParams -> IO a) -> IO a
withHwParams h f =
    bracket pcm_hw_params_malloc pcm_hw_params_free $ \p ->
      do pcm_hw_params_any h p
         x <- f p
         pcm_hw_params h p
         return x

-- | Run an action on a freshly allocated software-parameter structure.
--
-- The structure is initialised with 'pcm_sw_params_current', handed to the
-- action, and then installed on the PCM with 'pcm_sw_params'. It is freed
-- afterwards in every case, including when initialisation, the action or
-- the installation throws.
withSwParams :: Pcm -> (PcmSwParams -> IO a) -> IO a
withSwParams h f =
    bracket pcm_sw_params_malloc pcm_sw_params_free $ \p ->
      do pcm_sw_params_current h p
         x <- f p
         pcm_sw_params h p
         return x

-- | Drain and close a PCM handle.
--
-- The handle is closed even if draining throws, so a failing drain does
-- not leak the device.
alsaClose :: Pcm -> IO ()
alsaClose pcm = rethrowAlsaExceptions $
    do debug "alsaClose"
       -- bracket_ with an empty acquire: drain, then always close
       bracket_ (return ()) (pcm_close pcm) (pcm_drain pcm)

-- | Prepare a PCM and start it.
alsaStart :: Pcm -> IO ()
alsaStart pcm =
    rethrowAlsaExceptions (debug "alsaStart" >> pcm_prepare pcm >> pcm_start pcm)


-- | Stop a PCM by draining it.
--
-- FIXME: use pcm_drain for sinks?
-- NOTE(review): the body already calls 'pcm_drain' for every stream;
-- confirm whether the FIXME above is stale.
alsaStop :: Pcm -> IO ()
alsaStop pcm =
    rethrowAlsaExceptions (debug "alsaStop" >> pcm_drain pcm)

-- | Read the requested number of frames from a PCM into the buffer.
--
-- A short read is followed by further reads into the remainder of the
-- buffer until the full count has been delivered. On an over-run the PCM
-- is re-prepared and the read is restarted from the beginning of the buffer.
alsaRead :: SoundFmt -> Pcm -> Ptr () -> Int -> IO Int
alsaRead fmt h buf n = rethrowAlsaExceptions $ do
    --debug $ "Reading " ++ show n ++ " samples..."
    got <- catchXRun (pcm_readi h buf n) recoverFromOverRun
    --debug $ "Got " ++ show got ++ " samples."
    if got < n
      then do
        -- continue right behind the frames that were already stored
        rest <- alsaRead fmt h (buf `plusPtr` (got * frameBytes)) (n - got)
        return (got + rest)
      else return got
  where
    frameBytes = audioBytesPerFrame fmt

    recoverFromOverRun = do
        debug "snd_pcm_readi reported buffer over-run"
        pcm_prepare h
        alsaRead fmt h buf n

-- | Write the given number of frames from the buffer to a PCM, discarding
-- the frame count reported by 'alsaWrite_'.
alsaWrite :: SoundFmt -> Pcm -> Ptr () -> Int -> IO ()
alsaWrite fmt h buf n =
    rethrowAlsaExceptions (alsaWrite_ fmt h buf n >> return ())

-- | Worker for 'alsaWrite': write the given number of frames and return
-- how many were written.
--
-- When the device accepts a different count than requested, the remainder
-- of the buffer is written by a further call. On an under-run the PCM is
-- re-prepared and the write is restarted from the beginning of the buffer.
alsaWrite_ :: SoundFmt -> Pcm -> Ptr () -> Int -> IO Int
alsaWrite_ fmt h buf n = do
    --debug $ "Writing " ++ show n ++ " samples..."
    written <- catchXRun (pcm_writei h buf n) recoverFromUnderRun
    --debug $ "Wrote " ++ show written ++ " samples."
    if written /= n
      then do
        -- continue right behind the frames that were already consumed
        rest <- alsaWrite_ fmt h (buf `plusPtr` (written * frameBytes)) (n - written)
        return (written + rest)
      else return written
  where
    frameBytes = audioBytesPerFrame fmt

    recoverFromUnderRun = do
        debug "snd_pcm_writei reported buffer under-run"
        pcm_prepare h
        alsaWrite_ fmt h buf n


-- | A capture source on the named ALSA device, using 'defaultBufferTime'
-- for buffering.
alsaSoundSource :: String -> SoundFmt -> SoundSource Pcm
alsaSoundSource dev fmt = alsaSoundSourceTime dev fmt defaultBufferTime

-- | A playback sink on the named ALSA device, using 'defaultBufferTime'
-- for buffering.
alsaSoundSink :: String -> SoundFmt -> SoundSink Pcm
alsaSoundSink dev fmt = alsaSoundSinkTime dev fmt defaultBufferTime

-- | A capture source on the named ALSA device with explicitly chosen
-- buffer and period times.
alsaSoundSourceTime :: String -> SoundFmt -> SoundBufferTime -> SoundSource Pcm
alsaSoundSourceTime dev fmt time = SoundSource
    { soundSourceFmt   = fmt
    , soundSourceOpen  = alsaOpen dev fmt time PcmStreamCapture
    , soundSourceClose = alsaClose
    , soundSourceStart = alsaStart
    , soundSourceStop  = alsaStop
    , soundSourceRead  = alsaRead fmt
    }

-- | A playback sink on the named ALSA device with explicitly chosen
-- buffer and period times.
alsaSoundSinkTime :: String -> SoundFmt -> SoundBufferTime -> SoundSink Pcm
alsaSoundSinkTime dev fmt time = SoundSink
    { soundSinkFmt   = fmt
    , soundSinkOpen  = alsaOpen dev fmt time PcmStreamPlayback
    , soundSinkClose = alsaClose
    , soundSinkWrite = alsaWrite fmt
    , soundSinkStart = alsaStart
    , soundSinkStop  = alsaStop
    }

--
-- * File stuff
--

-- | Read up to the given number of frames from a file handle into the
-- buffer and return the number of complete frames read.
--
-- The count is in frames (one sample per channel), matching
-- 'soundSourceReadBytes' and the ALSA backend, so the byte conversion uses
-- the size of a whole frame. Using the size of a single sample would read
-- too little for multi-channel formats while the byte-oriented wrapper
-- reports a full buffer.
fileRead :: SoundFmt -> Handle -> Ptr () -> Int -> IO Int
fileRead fmt h buf n = liftM (`div` c) $ hGetBuf h buf (n * c)
  where c = audioBytesPerFrame fmt

-- | Write the given number of frames from the buffer to a file handle.
--
-- The count is in frames (one sample per channel), matching
-- 'soundSinkWriteBytes' and the ALSA backend, so the byte conversion uses
-- the size of a whole frame. Using the size of a single sample would
-- truncate multi-channel data.
fileWrite :: SoundFmt -> Handle -> Ptr () -> Int -> IO ()
fileWrite fmt h buf n = hPutBuf h buf (n * c)
  where c = audioBytesPerFrame fmt

-- | A source that reads raw audio data of the given format from a file.
-- The file is opened in binary read mode; starting and stopping are no-ops.
fileSoundSource :: FilePath -> SoundFmt -> SoundSource Handle
fileSoundSource file fmt = base
    { soundSourceOpen  = openBinaryFile file ReadMode
    , soundSourceClose = hClose
    , soundSourceRead  = fileRead fmt
    }
  where
    base = nullSoundSource fmt

-- | A sink that writes raw audio data of the given format to a file.
-- The file is opened in binary write mode; starting and stopping are no-ops.
fileSoundSink :: FilePath -> SoundFmt -> SoundSink Handle
fileSoundSink file fmt = base
    { soundSinkOpen  = openBinaryFile file WriteMode
    , soundSinkClose = hClose
    , soundSinkWrite = fileWrite fmt
    }
  where
    base = nullSoundSink fmt

--
-- * Marshalling utilities
--

-- | Like 'allocaBytes', but the temporary buffer is zeroed before the
-- action runs.
callocaBytes :: Int -> (Ptr a -> IO b) -> IO b
callocaBytes n f =
    allocaBytes n $ \p -> do
        clearBytes p n
        f p

-- | Set the given number of bytes starting at the pointer to zero.
clearBytes :: Ptr a -> Int -> IO ()
clearBytes p n = do
    _ <- memset p 0 (fromIntegral n)
    return ()

-- | Binding to C @memset@: destination pointer, fill value, byte count.
-- The result is a pointer of the same type as the destination.
foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO (Ptr a)