{-# LANGUAGE CPP #-}

{-|
  Encode audio as lazy bytestrings in the WAVE (RIFF) format, and decode lazy
  bytestrings in WAVE format into a 'WaveFile' value.
-}
module Sound.Codecs.WaveFile (
        -- * Constructors
        WaveFile (WaveFile),
        WaveChunk ( .. ),
        -- * Encode\/decode functions
        getWaveFile,
        toWaveFile,
        isWaveFile
        ) where

import Data.Int
import Data.Word

#if defined(__GLASGOW_HASKELL__)
import GHC.Float ( double2Int )
#endif

import Data.Bits (shiftR, shiftL, (.&.), (.|.))
import Data.Binary (Binary, Get, Put, get, put, decode)
import qualified Data.Binary.Put as BP
import qualified Data.Binary.Get as BG

import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as C

import qualified Control.Monad as CM
import qualified Control.Monad.State as St
import qualified Control.Monad.Error as Err

import Data.List (foldl', unfoldr)
import Sound.Base

-- Four-byte chunk tags used when reading and writing wave files.

-- | Tag that opens every RIFF file.
riffBS :: L.ByteString
riffBS = C.pack "RIFF"

-- | Form type that follows the RIFF header of a wave file.
waveBS :: L.ByteString
waveBS = C.pack "WAVE"

-- | Tag of the chunk holding the sample data.
dataBS :: L.ByteString
dataBS = C.pack "data"

-- | Tag of the format chunk.  The trailing space is part of the tag.
fmtBS :: L.ByteString
fmtBS = C.pack "fmt "

-- | Tag of a metadata (LIST) chunk.
metaBS :: L.ByteString
metaBS = C.pack "LIST"
-- end of chunk tags

-- | One sub-chunk of a RIFF\/WAVE file.
--
-- 'chunkLength' holds the body length as declared in (or to be written to)
-- the chunk header.  On disk that field is an unsigned 32-bit value.  It is
-- kept here as an 'Integer' and narrowed with 'fromIntegral' when the chunk
-- is serialised.  'WaveFormat' has no 'chunkLength' field, so use
-- 'getChunkLength' for a total accessor.
data WaveChunk
        = WaveFormat { format :: SndFileInfo }
          -- ^ Format of the audio data (only PCM is read or written).
        | WaveData { waveData :: L.ByteString, chunkLength :: Integer }
          -- ^ Raw sample bytes of a @data@ chunk.
        | WaveMeta { metaData :: L.ByteString, chunkLength :: Integer }
          -- ^ Unparsed body of a @LIST@ (metadata) chunk.
        | UnknownWaveChunk { chunkType :: L.ByteString
                           , unparsedData :: L.ByteString
                           , chunkLength :: Integer
                           }
          -- ^ Any other chunk: its four-byte tag and its unparsed body.
        deriving (Show, Eq)

-- | RIFF serialisation of a single chunk: a four-byte tag, the body length as
-- a little-endian 'Word32', then the body.
--
-- Only PCM (format type 1) @fmt @ chunks can be read; any other format type
-- makes 'get' fail.  Bodies of @data@, @LIST@ and unrecognised chunks are kept
-- as raw bytes together with their declared length.
instance Binary WaveChunk where
    get = do
        tag <- BG.getLazyByteString 4
        case C.unpack tag of
            "data" -> fmap (uncurry WaveData) sizedBody
            "fmt " -> do
                fmtLen <- BG.getWord32le
                fType <- BG.getWord16le
                if fType == 1
                    then do
                        chns <- BG.getWord16le
                        sR <- BG.getWord32le
                        BG.skip 4 -- byte rate: derivable from the other fields
                        BG.skip 2 -- block alignment: likewise derivable
                        bDepth <- BG.getWord16le
                        -- A format chunk may be longer than the 16 bytes read
                        -- above (it can carry extension bytes).  Consume the
                        -- rest so that 'get' stops where the chunk ends.
                        CM.when (fmtLen > 16) $ BG.skip (fromIntegral fmtLen - 16)
                        return $ WaveFormat (SndFileInfo (fromIntegral chns) (fromIntegral sR) (fromIntegral bDepth))
                    else
                        fail ("Can't read non-PCM Wave chunk.  Type = " ++ show fType)
            "LIST" -> fmap (uncurry WaveMeta) sizedBody
            _ -> fmap (uncurry (UnknownWaveChunk tag)) sizedBody -- unknown chunk type
      where
        -- Length-prefixed body shared by every non-format chunk: returns the
        -- body bytes and the declared length.
        sizedBody :: Get (L.ByteString, Integer)
        sizedBody = do
            len <- BG.getWord32le
            body <- BG.getLazyByteString (fromIntegral len)
            return (body, fromIntegral len)

    put (WaveData dBs chunkLen) = do
        BP.putLazyByteString dataBS
        BP.putWord32le $ fromIntegral chunkLen
        BP.putLazyByteString dBs
    put (WaveFormat (SndFileInfo numChn sR bDepth)) = do
        BP.putLazyByteString fmtBS
        BP.putWord32le 16 -- body length of a plain PCM format chunk
        BP.putWord16le 1 -- format type (1 for PCM)
        BP.putWord16le $ fromIntegral numChn
        BP.putWord32le $ fromIntegral sR
        BP.putWord32le $ fromIntegral dataRate
        BP.putWord16le $ fromIntegral align
        BP.putWord16le $ fromIntegral bDepth
      where
        -- Each sample is stored in the smallest whole number of bytes that
        -- holds it, so round the byte width up (e.g. 12-bit -> 2 bytes).
        numBytes = (bDepth + 7) `div` 8
        align = fromIntegral numBytes * numChn -- bytes per frame
        dataRate = fromIntegral sR * align -- bytes per second
    put (WaveMeta dBs chunkLen) = do
        BP.putLazyByteString metaBS
        BP.putWord32le $ fromIntegral chunkLen
        BP.putLazyByteString dBs
    put (UnknownWaveChunk ct dBs chunkLen) = do
        BP.putLazyByteString ct
        BP.putWord32le $ fromIntegral chunkLen
        BP.putLazyByteString dBs

-- | Body length of a chunk as written in its header.  Format chunks are
-- always serialised with a 16-byte PCM body, so they report 16.  Every other
-- constructor reports its stored 'chunkLength'.
getChunkLength :: WaveChunk -> Integer
getChunkLength c = case c of
        WaveFormat _ -> 16
        _ -> chunkLength c

-- | A whole wave file: the chunks that follow the @RIFF@\/@WAVE@ header, in
-- file order.
newtype WaveFile = WaveFile [WaveChunk]
        deriving (Show, Eq)

-- | Whole-file serialisation: the @RIFF@ tag, a little-endian 'Word32' giving
-- the number of bytes that follow it, the @WAVE@ tag, then the chunks.
instance Binary WaveFile where
        get = do
                riffTag <- BG.getLazyByteString 4
                declaredLen <- BG.getWord32le
                waveTag <- BG.getLazyByteString 4
                if riffTag == riffBS && waveTag == waveBS
                        then do
                                -- the declared length also counts the 4-byte WAVE tag
                                body <- BG.getLazyByteString (fromIntegral declaredLen - 4)
                                -- split into per-chunk slices, each decoded on its own
                                return . WaveFile . map decode $ unroll body
                        else fail "Not a RIFF/Wave file"
        put (WaveFile chunks) = do
                BP.putLazyByteString riffBS
                BP.putWord32le (fromIntegral riffLen)
                BP.putLazyByteString waveBS
                mapM_ put chunks
                where
                        -- 4 bytes of WAVE tag, then an 8-byte header plus body per chunk
                        riffLen = foldl' (\acc c -> acc + 8 + getChunkLength c) 4 chunks

-- | Access to the format and audio of a decoded 'WaveFile'.
instance SndFileCls WaveFile where
        -- The format is taken from the last @fmt @ chunk in the file, which
        -- is the one 'processChunk' would leave in its state.  Searching for
        -- it directly means a file without any format chunk reports
        -- 'NoFormatError' rather than failing inside 'processChunk' when it
        -- meets a data chunk.
        getSfInfo (WaveFile cs) = case reverse [f | WaveFormat f <- cs] of
                (sf:_) -> return sf
                [] -> Err.throwError NoFormatError
        -- Decode every data chunk with the format in effect at that point in
        -- the file, then join the resulting signals in file order.
        getAudioData (WaveFile cs) = return $ concatASig $ St.evalState (mapM processChunk cs) Nothing
        getSfType _ = WavePCM

-- | State threaded through 'processChunk': the most recently seen format, if
-- any.
type WaveFileReader = St.State (Maybe SndFileInfo)

-- | Process one chunk while walking a file's chunks in order.
--
-- * A 'WaveFormat' chunk stores its format as the current state.
-- * A 'WaveData' chunk is decoded with the most recently stored format.
-- * Every other chunk contributes an empty signal.
--
-- A data chunk met before any format chunk is a fatal 'error'.  The plain
-- 'St.State' monad has no failure channel: 'fail' there would need a
-- 'MonadFail' instance for 'Identity', which does not exist on current base.
processChunk :: WaveChunk -> WaveFileReader AudioSig
processChunk (WaveFormat f) = do
        St.put $ Just f
        return $ makeAudioSignal 0 []
processChunk c@(WaveData _ _) = do
        mf <- St.get
        case mf of
                Just f -> return $ makeAudioSignal (cLenInFrames f c) (makeFrames (numChannels f) . decodeSoundData f $ c)
                Nothing -> error "No format chunk found in Wave file."
processChunk _ = return $ makeAudioSignal 0 []

-- | Number of whole frames in a 'WaveData' chunk.  This is the chunk's byte
-- length divided by the bytes per frame: the bit depth divided by 8 (rounded
-- down), times the channel count.  Non-data chunks contain no frames.
cLenInFrames :: SndFileInfo -> WaveChunk -> FrameCount
cLenInFrames sf chunk = case chunk of
        WaveData _ chunkLen ->
                let bytesPerSample = fromIntegral (bitDepth sf `div` 8)
                    channels = fromIntegral (numChannels sf)
                in chunkLen `div` (bytesPerSample * channels)
        _ -> 0

-- encoders and decoders for the wave data chunk

-- | Decode the samples of a 'WaveData' chunk, in file order, as normalised
-- 'SoundData' values.
--
-- The sample count is the chunk length divided by the byte width of one
-- sample; any remainder is ignored.  Only 8, 16, 24 and 32-bit depths are
-- supported.  Any other depth makes the 'Get' parser fail, and 'BG.runGet'
-- turns that into an error.  Non-data chunks decode to no samples.
decodeSoundData :: SndFileInfo -> WaveChunk -> [SoundData]
decodeSoundData sfInfo (WaveData theDataBS chunkLen) = BG.runGet samples theDataBS
        where
              depth = bitDepth sfInfo
              -- how many whole samples of the given byte width the chunk holds
              sampleCount :: Integer -> Int
              sampleCount width = fromIntegral (chunkLen `div` width)
              samples = case depth of
                      8 -> CM.replicateM (sampleCount 1) (fmap (normalize depth) getSD8)
                      16 -> CM.replicateM (sampleCount 2) (fmap (normalize depth) getSD16)
                      24 -> CM.replicateM (sampleCount 3) (fmap (normalize depth) getSD24)
                      32 -> CM.replicateM (sampleCount 4) (fmap (normalize depth) getSD32)
                      other -> fail ("Can't read " ++ show other ++ "-bit audio.")
decodeSoundData _ _ = [] -- a non-data chunk carries no samples

{-|
  Scale a raw integer sample, as read from a data chunk, into 'SoundData'.

  * 8-bit samples are expected unsigned: @0@ maps to @-1@, @128@ to @0@ and
    @255@ to @127\/128@.
  * Other depths are signed two's-complement values.  Positive values are
    divided by @2^(bd-1) - 1@ and negative values by @2^(bd-1)@, so both
    extremes map exactly to @1@ and @-1@.  'unNormalize' uses the same two
    multipliers.

  The divisors are computed in 'Integer', so 32-bit audio does not overflow
  on platforms where 'Int' is only 32 bits wide.
-}
normalize :: Integral a => BitDepth -> a -> SoundData
normalize 8 a = (fromIntegral a - 128) / 128
normalize _ 0 = 0
normalize bd a
        | a > 0 = fromIntegral a / divPos
        | otherwise = fromIntegral a / divNeg
        where
                half = 1 `shiftL` fromIntegral (bd - 1) :: Integer
                divPos = fromIntegral (half - 1)
                divNeg = fromIntegral half

-- | Read one 8-bit sample.
--
-- 8-bit PCM wave data is unsigned (silence is 128), and 'normalize' removes
-- that offset by subtracting 128.  The byte is therefore returned unchanged.
-- Converting it to 'Int8' would wrap every value >= 128 to a negative number
-- before the offset is removed, producing samples below -1.
getSD8 :: Get Word8
getSD8 = BG.getWord8
 
-- | Read one signed 16-bit little-endian sample.
getSD16 :: Get (Int16)
getSD16 = do
        raw <- BG.getWord16le
        return (fromIntegral raw)

-- | Read one signed 24-bit little-endian sample, sign-extended to 'Int32'.
-- The low two bytes are read first, then the high byte.
getSD24 :: Get (Int32)
getSD24 = do
        lo <- BG.getWord16le
        hi <- BG.getWord8
        let packed :: Int32
            packed = (fromIntegral hi `shiftL` 16) .|. fromIntegral lo
        -- Move bit 23 into the sign position, then shift back arithmetically
        -- so that it is copied across the top byte.
        return $! (packed `shiftL` 8) `shiftR` 8

-- | Read one signed 32-bit little-endian sample.
getSD32 :: Get (Int32)
getSD32 = fmap fromIntegral BG.getWord32le

-- | Encode an 'AudioSig' as a 'WaveData' chunk.
--
-- Samples are converted back to integers with 'unNormalize' and written
-- little-endian, frame by frame.  Only 8, 16, 24 and 32-bit depths are
-- supported; any other depth throws 'InvalidBitDepthError' in 'AudioMonad'.
-- The recorded chunk length is computed from the signal's frame count, not
-- from the number of bytes actually produced.
encodeSoundData :: (Monad m) => SndFileInfo -> AudioSig -> AudioMonad m WaveChunk
encodeSoundData sfInfo asig = do
        writer <- sdPut
        let payload = if null frameList then L.empty else BP.runPut writer
        return (WaveData payload byteLen)
        where
                frameList = audioData asig
                depth = bitDepth sfInfo
                bytesPerSample = fromIntegral depth `div` 8
                byteLen = fromIntegral (lengthInFrames asig) * bytesPerSample * fromIntegral (numChannels sfInfo)
                -- run one of the put* writers over every sample of every frame
                emit :: (Monad n, Num w) => (w -> BP.PutM ()) -> n (BP.PutM ())
                emit putter = return . mapM_ (putter . fromIntegral . unNormalize depth) $ concat frameList
                sdPut :: (Monad m) => AudioMonad  m (BP.PutM ())
                sdPut = case depth of
                        8 -> emit putSD8
                        16 -> emit putSD16
                        24 -> emit putSD24
                        32 -> emit putSD32
                        x -> Err.throwError $ InvalidBitDepthError x [8,16,24,32]

-- | Convert a normalised sample into the integer range of the given bit
-- depth, ready to be written to a data chunk.  The input is first clipped to
-- @[-1, 1]@, and results are rounded to the nearest integer with 'fastRound'.
--
-- * 8-bit output is unsigned: @-1@ maps to 0 and @0@ to 128.  The result is
--   clamped to 255, so full-scale positive input cannot reach 256 and wrap to
--   0 when it is narrowed to a 'Word8'.
-- * Other depths are signed two's-complement.  Positive values scale by
--   @2^(bd-1) - 1@ and negative values by @2^(bd-1)@, the same divisors
--   'normalize' uses.
unNormalize :: BitDepth -> SoundData -> Int
unNormalize 8 a = min 255 (fastRound (128 * (1 + clip a)))
unNormalize bd a
        | a >= 0 = fastRound (posMult * clip a)
        | otherwise = fastRound (negMult * clip a)
        where
                posMult = fromIntegral $ ((1 `shiftL` (fromIntegral bd - 1)) :: Integer) - 1
                negMult = fromIntegral (1 `shiftL` (fromIntegral bd - 1) :: Integer)

-- | Round to the nearest 'Int', with halves rounded away from zero.
--
-- Both branches apply the same rule: add (or subtract) one half, then
-- truncate toward zero.  The Prelude 'round' rounds halves to even, so a
-- fallback built on it would disagree with the GHC build on inputs such as
-- 2.5.
#if defined(__GLASGOW_HASKELL__)
fastRound :: SoundData -> Int
fastRound x
        | x >= 0 = GHC.Float.double2Int (x + 0.5)
        | otherwise = GHC.Float.double2Int (x - 0.5)
#else
-- No compiler-specific primitive here.  'truncate' performs the same
-- toward-zero truncation as 'double2Int'.
fastRound :: SoundData -> Int
fastRound x
        | x >= 0 = truncate (x + 0.5)
        | otherwise = truncate (x - 0.5)
#endif

-- | Write one 8-bit sample byte as-is.
putSD8 :: Word8 -> Put
putSD8 sample = BP.putWord8 sample

-- | Write one 16-bit sample, little-endian.
putSD16 :: Word16 -> Put
putSD16 sample = BP.putWord16le sample

-- | Write the low 24 bits of a 32-bit value as three little-endian bytes.
-- There is no 'Word24' type, so the top byte of the argument is ignored.
putSD24 :: Word32 -> Put
putSD24 val = mapM_ (BP.putWord8 . byteAt) [0, 8, 16]
              where
                -- the byte that starts at the given bit offset
                byteAt offset = fromIntegral (val `shiftR` offset)

-- | Write one 32-bit sample, little-endian.
putSD32 :: Word32 -> Put
putSD32 sample = BP.putWord32le sample

-- | Limit a sample to the closed range @[-1, 1]@ before it is scaled for
-- output.  Values above 1 become 1 and values below -1 become -1.
clip :: SoundData -> SoundData
clip x = max (-1) (min 1 x)
{-# INLINE clip #-}
--end encoders and decoders.


-- | Split the body of a RIFF file (everything after the @WAVE@ tag) into one
-- lazy 'L.ByteString' per chunk.  Each slice holds the chunk's 8-byte header
-- and its declared body.  Used by the 'Binary' instance of 'WaveFile'.
--
-- An odd-length chunk body is followed by one pad byte that is not counted
-- in the chunk's length field; that byte is skipped and belongs to neither
-- chunk.  Splitting stops when fewer than 8 bytes (a full header) remain, so
-- trailing padding does not cause a decoding error.
unroll :: L.ByteString -> [L.ByteString]
unroll = unfoldr unroll'

-- | One step of 'unroll'.  Splits the first chunk off the input, or returns
-- 'Nothing' when no complete chunk header is left.  A chunk whose declared
-- length exceeds the remaining input takes everything that is left.
unroll' :: L.ByteString -> Maybe (L.ByteString, L.ByteString)
unroll' bs
        | L.length (L.take 8 bs) < 8 = Nothing
        | otherwise = Just (L.take chunkLen bs, L.drop (chunkLen + padLen) bs)
        where
              -- the body length is the little-endian Word32 after the 4-byte tag
              bodyLen = fromIntegral (BG.runGet (BG.skip 4 >> BG.getWord32le) bs) :: Int64
              chunkLen = 8 + bodyLen
              -- RIFF pads odd-length bodies to an even size
              padLen = bodyLen .&. 1
--end unroll

-- | Build a 'WaveFile' from any 'SndFileCls' value.  The result is a format
-- chunk built from its 'SndFileInfo', followed by a single data chunk that
-- holds all of its audio.  Any failure from 'getSfInfo', 'getAudioData' or
-- 'encodeSoundData' propagates in 'AudioMonad'.
toWaveFile :: (SndFileCls a, Monad m) => a -> AudioMonad m WaveFile
toWaveFile sfc = do
        info <- getSfInfo sfc
        sig <- getAudioData sfc
        dataChunk <- encodeSoundData info sig
        return (WaveFile [WaveFormat info, dataChunk])

-- | Decode a complete wave file, header included.
--
-- Only the 12-byte RIFF\/WAVE header is checked up front; if it does not
-- match, 'UnknownFileTypeError' is thrown.  Otherwise the bytes go straight
-- to 'decode', so problems in the chunk data surface as errors raised by
-- 'decode' rather than through 'AudioMonad'.
getWaveFile :: (Monad m) => L.ByteString -> AudioMonad m (WaveFile)
getWaveFile bs =
        if isWaveFile (L.take 12 bs)
                then return (decode bs)
                else Err.throwError UnknownFileTypeError

-- | Check the RIFF header: bytes 0-3 must be @RIFF@ and bytes 8-11 must be
-- @WAVE@.  The length field in bytes 4-7 is not inspected.
isWaveFile      :: L.ByteString -> Bool
isWaveFile bs = L.take 4 bs == riffBS && L.take 4 (L.drop 8 bs) == waveBS