{-# LANGUAGE AllowAmbiguousTypes #-}

module LambdaSound.SaveAndLoad.RawSamples
  ( saveWav,
    saveRaw,
    saveRawCompressed,
    loadWav,
    loadRaw,
    loadRawCompressed,
  )
where

import Codec.Audio.Wave
import Codec.Compression.GZip (compress, decompress)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.Functor ((<&>))
import Data.Int (Int16, Int32, Int64)
import Data.List.NonEmpty
import Data.Massiv.Array qualified as M
import Data.Semigroup (Max (..))
import Data.Vector.Storable.ByteString (byteStringToVector, vectorToByteString)
import Data.Word (Word8)
import LambdaSound.Sound (Hz, Pulse)

-- | Write the sound samples to a mono wave file of 32-bit IEEE floats.
-- The sampling frequency is rounded to a whole number for the wave header.
saveWav :: FilePath -> Hz -> M.Vector M.S Pulse -> IO ()
saveWav filepath sampleRate floats =
  writeWaveFile filepath header writeSamples
  where
    -- Number of samples in the vector
    sampleCount = M.unSz (M.size floats)

    header =
      Wave
        { waveFileFormat = WaveVanilla,
          waveSampleRate = round sampleRate,
          waveSampleFormat = SampleFormatIeeeFloat32Bit,
          waveChannelMask = speakerMono,
          waveDataOffset = 0,
          -- Every sample occupies 4 bytes in the 32-bit float format
          waveDataSize = fromIntegral sampleCount * 4,
          waveSamplesTotal = fromIntegral sampleCount,
          waveOtherChunks = []
        }

    -- Dumps the memory of the sample vector to the handle of the wave file
    writeSamples handle =
      B.hPut handle (vectorToByteString (M.toStorableVector floats))

-- | Load a wave file to get the sampling frequency and the sound samples of
-- each channel.
--
-- The samples are converted to 'Pulse' and divided by the largest absolute
-- sample value found in the file. Unsigned 8-bit samples are additionally
-- mapped with @2 * x - 1@. If every sample is zero, the samples are
-- converted without scaling instead of dividing by zero.
--
-- A file which reports fewer than two channels is returned as a single channel.
-- Calls 'error' for sample formats other than 32/64-bit floats and
-- 8/16/32/64-bit PCM integers.
loadWav :: FilePath -> IO (Hz, NonEmpty (M.Vector M.S Pulse))
loadWav filePath = do
  wave <- readWaveFile filePath
  file <- B.readFile filePath
  -- Keep only the bytes of the data chunk: skip everything in front of it and
  -- cut off whatever follows it, so that trailing chunks are not decoded as samples.
  let sampleData =
        B.take (fromIntegral $ waveDataSize wave) $
          B.drop (fromIntegral $ waveDataOffset wave) file
  pure
    ( fromIntegral $ waveSampleRate wave,
      splitInChannels wave $ readSource wave sampleData
    )
  where
    -- Separates the interleaved samples into one vector per channel.
    splitInChannels :: Wave -> M.Vector M.D Pulse -> NonEmpty (M.Vector M.S Pulse)
    splitInChannels wave sourceVector
      | channels <= 1 = M.compute sourceVector :| []
      | otherwise = channelAt 0 :| ([1 .. pred channels] <&> channelAt)
      where
        channels :: Int
        channels = fromIntegral $ waveChannels wave

        -- Picks every @channels@-th sample, starting at the given offset
        channelAt :: Int -> M.Vector M.S Pulse
        channelAt channelOffset =
          M.compute $
            M.downsample (M.Stride channels) $
              M.drop (M.Sz1 channelOffset) sourceVector

    -- Decodes the bytes of the data chunk according to the sample format of the header.
    readSource :: Wave -> B.ByteString -> M.Vector M.D Pulse
    readSource wave sampleData =
      case waveSampleFormat wave of
        SampleFormatIeeeFloat32Bit -> mapAndLoad @Float
        SampleFormatIeeeFloat64Bit -> mapAndLoad @Double
        -- 'Word8' samples are never negative, so they end up in [0, 1] after
        -- scaling and have to be stretched to [-1, 1]
        SampleFormatPcmInt 8 -> M.map ((+ (-1)) . (* 2)) $ mapAndLoad @Word8
        SampleFormatPcmInt 16 -> mapAndLoad @Int16
        SampleFormatPcmInt 32 -> mapAndLoad @Int32
        SampleFormatPcmInt 64 -> mapAndLoad @Int64
        unsupported ->
          error $ "The sample format \"" <> show unsupported <> "\" is not supported"
      where
        -- Reinterprets the bytes as samples of type @a@ and scales them by the
        -- largest absolute sample value.
        mapAndLoad :: forall a. (Real a, Num a, M.Storable a) => M.Vector M.D Pulse
        mapAndLoad =
          let rawArray =
                M.fromStorableVector @a M.Seq $
                  byteStringToVector sampleData
              -- NOTE(review): 'abs' of the smallest value of a signed integer
              -- type overflows, so such a sample does not count towards the peak.
              Max peak = M.foldSemi (Max . abs) (Max 0) rawArray
              toPulse :: a -> Pulse
              toPulse
                -- A peak of zero would turn every sample into NaN
                | peak == 0 = realToFrac
                | otherwise = (/ realToFrac peak) . realToFrac
           in M.map toPulse rawArray
        {-# INLINE mapAndLoad #-}

-- | Write the sound samples to a file as raw floats,
-- which are the bytes of the underlying storable vector
saveRaw :: FilePath -> M.Vector M.S Pulse -> IO ()
saveRaw filePath =
  B.writeFile filePath . vectorToByteString . M.toStorableVector

-- | Write the sound samples to a file as raw floats which are compressed with gzip
saveRawCompressed :: FilePath -> M.Vector M.S Pulse -> IO ()
saveRawCompressed filePath =
  BL.writeFile filePath
    . compress
    . BL.fromStrict
    . vectorToByteString
    . M.toStorableVector

-- | Load the gzip compressed raw sound samples.
--
-- The file is read strictly, so the file handle is closed and read errors are
-- raised before this action returns. Decompression happens when the
-- resulting vector is first evaluated.
loadRawCompressed :: FilePath -> IO (M.Vector M.S Pulse)
loadRawCompressed filePath = do
  compressed <- B.readFile filePath
  pure $
    M.fromStorableVector M.Seq $
      byteStringToVector $
        BL.toStrict $
          decompress $
            BL.fromStrict compressed

-- | Load the raw sound samples by reinterpreting the bytes of the file as samples
loadRaw :: FilePath -> IO (M.Vector M.S Pulse)
loadRaw filePath =
  liftIO (B.readFile filePath)
    <&> M.fromStorableVector M.Seq . byteStringToVector