{-# LANGUAGE CPP #-}

module Silero.Model (
  SileroModel (..),
  detectSpeech,
  windowLength,
  resetModel,
  sampleRate,
  withModel,
) where

import Data.Int (Int64)
import Data.Vector.Storable (Vector)
import qualified Data.Vector.Storable as Vector
import Foreign.Storable (Storable (..))
import GHC.Generics (Generic)
import GHC.IO (unsafeDupablePerformIO, unsafePerformIO)
import Paths_silero_vad (getDataFileName)
import UnliftIO (MonadIO (liftIO), MonadUnliftIO, bracket)

#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)

import System.Posix (RTLDFlags (RTLD_NOW), dlsym, dlopen )
import Foreign (FunPtr, Ptr, castPtr)
import Foreign.C (CString, withCString)

#else


import System.Win32 (getProcAddress, loadLibrary)
import Foreign (FunPtr, Ptr, castPtr, castPtrToFunPtr)
import Foreign.C (CWString, withCWString)

#endif

-- Raw FFI bindings to the C functions declared in @model.h@.
-- The model itself is represented on the Haskell side as an opaque @Ptr ()@.

-- | C @get_window_length@: yields an 'Int64'; exposed to Haskell as 'windowLength'.
foreign import ccall "model.h get_window_length" c_get_window_length :: IO Int64

-- | C @get_sample_rate@: yields an 'Int64'; exposed to Haskell as 'sampleRate'.
foreign import ccall "model.h get_sample_rate" c_get_sample_rate :: IO Int64

-- | C @release_model@: takes the opaque model pointer.
-- Used as the release action of the bracket in 'withModel'.
foreign import ccall "model.h release_model" c_release_model :: Ptr () -> IO ()

-- | C @reset_model@: takes the opaque model pointer. Wrapped by 'resetModel'.
foreign import ccall "model.h reset_model" c_reset_model :: Ptr () -> IO ()

-- | C @detect_speech@: takes the opaque model pointer and a pointer to a
-- buffer of 'Float' samples, and yields a 'Float'. Wrapped by 'detectSpeech'.
foreign import ccall "model.h detect_speech" c_detect_speech :: Ptr () -> Ptr Float -> IO Float

#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)

-- | C @load_model@ (Linux and macOS): takes the @OrtGetApiBase@ function
-- pointer and the model path as a narrow C string; yields the opaque model
-- pointer.
foreign import ccall "model.h load_model" c_load_model :: FunPtr () -> CString -> IO (Ptr ())

#else

-- | C @load_model@ (all other hosts, treated as Windows): same as above, but
-- the model path is passed as a wide C string.
foreign import ccall "model.h load_model" c_load_model :: FunPtr () -> CWString -> IO (Ptr ())

#endif

-- | Number of samples the model expects per call, as reported by the C
-- function @get_window_length@ and narrowed from 'Int64' to 'Int'.
--
-- The C getter is run through 'unsafeDupablePerformIO', which is only sound
-- if the getter has no side effects and always returns the same value.
-- NOTE(review): confirm this against the C implementation.
windowLength :: Int
windowLength = fromIntegral (unsafeDupablePerformIO c_get_window_length)

-- | Sample rate the model expects, as reported by the C function
-- @get_sample_rate@ and narrowed from 'Int64' to 'Int'.
-- NOTE(review): presumably expressed in Hz; confirm against the C side.
--
-- The C getter is run through 'unsafeDupablePerformIO', which is only sound
-- if the getter has no side effects and always returns the same value.
-- NOTE(review): confirm this against the C implementation.
sampleRate :: Int
sampleRate = fromIntegral (unsafeDupablePerformIO c_get_sample_rate)

-- |
-- Holds state to be used for voice activity detection.
--
-- __Warning__: This is __NOT__ thread-safe, because the state behind the
-- pointer is mutated internally.
newtype SileroModel = SileroModel
  { api :: Ptr ()
  -- ^ Opaque pointer that is handed to the C functions of @model.h@.
  }
  deriving (Generic)

-- | A 'SileroModel' is marshalled exactly like the raw @Ptr ()@ it wraps:
-- same size, same alignment, and 'peek' and 'poke' go through a pointer cast.
instance Storable SileroModel where
  -- The argument is only a type witness; it is never evaluated, so
  -- 'undefined' is safe here.
  sizeOf _ = sizeOf (undefined :: Ptr ())
  alignment _ = alignment (undefined :: Ptr ())

  -- Read the stored pointer and wrap it.
  peek ptr = SileroModel <$> peek (castPtr ptr)

  -- Unwrap the pointer and store it.
  poke ptr (SileroModel apiPtr) = poke (castPtr ptr) apiPtr

-- | Path of the bundled ONNX Runtime shared library, relative to the package
-- data directory. The value is chosen at compile time from the host OS and,
-- on macOS, the host architecture.
--
-- NOTE(review): every Linux host gets the @linux-x64@ build, and every host
-- that is neither Linux nor macOS gets the @windows-x64@ build; confirm that
-- no other targets need to be supported.
libraryPath :: FilePath
#if defined(linux_HOST_OS)

libraryPath = "lib/onnxruntime/linux-x64/libonnxruntime.so"

#elif defined(darwin_HOST_OS) && defined(aarch64_HOST_ARCH)

libraryPath = "lib/onnxruntime/mac-arm64/libonnxruntime.dylib"

#elif defined(darwin_HOST_OS) && !defined(aarch64_HOST_ARCH)

libraryPath = "lib/onnxruntime/mac-x64/libonnxruntime.dylib"

#else

libraryPath = "lib/onnxruntime/windows-x64/onnxruntime.dll"

#endif

{-# NOINLINE onnxruntime #-}
-- | Address of the @OrtGetApiBase@ symbol inside the bundled ONNX Runtime
-- shared library (see 'libraryPath').
--
-- The library is loaded with 'unsafePerformIO' the first time this value is
-- demanded. The @NOINLINE@ pragma keeps this a single top-level value so the
-- load is not duplicated by inlining. The library handle is never closed in
-- this module.
--
-- Any exception raised while resolving the path, loading the library or
-- looking up the symbol surfaces wherever this value is first forced.
onnxruntime :: FunPtr ()
#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)

onnxruntime = unsafePerformIO $ do
  path <- getDataFileName libraryPath
  -- RTLD_NOW: resolve all symbols at load time rather than lazily.
  library <- dlopen path [RTLD_NOW]
  dlsym library "OrtGetApiBase"

#else

onnxruntime = unsafePerformIO $ do
  path <- getDataFileName libraryPath
  library <- loadLibrary path
  -- getProcAddress yields a plain address; reinterpret it as a function pointer.
  castPtrToFunPtr <$> getProcAddress library "OrtGetApiBase"

#endif

-- | Location of the bundled model file @lib/silero_vad.onnx@, resolved
-- through the package data directory with 'getDataFileName'.
getModelPath :: IO String
getModelPath = getDataFileName "lib/silero_vad.onnx"

#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)

-- | Run an action with the model path marshalled as a C string.
-- The string is temporary storage: it is only valid while the action runs.
withModelPath :: (CString -> IO a) -> IO a
withModelPath runModelPath = do
  modelPath <- getModelPath
  withCString modelPath runModelPath

#else

-- | Run an action with the model path marshalled as a wide C string.
-- The string is temporary storage: it is only valid while the action runs.
withModelPath :: (CWString -> IO a) -> IO a
withModelPath runModelPath = do
  modelPath <- getModelPath
  withCWString modelPath runModelPath

#endif

-- | Load the model, run the given action with it, and release the model
-- afterwards. The release runs even when the action throws, because the
-- whole sequence is wrapped in 'bracket'.
--
-- The model is loaded by the C function @load_model@ from the bundled model
-- file, using the ONNX Runtime entry point held in 'onnxruntime'.
-- NOTE(review): the pointer returned by @load_model@ is passed on unchecked;
-- confirm whether the C side can return a null pointer on failure.
--
-- __Warning: 'SileroModel' holds internal state and is NOT thread safe.__
withModel :: (MonadUnliftIO m) => (SileroModel -> m a) -> m a
withModel runModel =
  bracket
    (SileroModel <$> liftIO (withModelPath (c_load_model onnxruntime)))
    (liftIO . c_release_model . api)
    runModel

-- | Hand the model pointer to the C function @reset_model@.
-- NOTE(review): presumably this clears the state accumulated by earlier
-- 'detectSpeech' calls; confirm against the C implementation.
--
-- __Warning: 'SileroModel' holds internal state and is NOT thread safe.__
resetModel :: (MonadIO m) => SileroModel -> m ()
resetModel (SileroModel modelPtr) = liftIO (c_reset_model modelPtr)

-- |
-- Detect if speech is found within the given audio samples.
--
-- This has the following requirements:
--
-- - Must be 16 kHz sample rate.
-- - Must be mono-channel.
-- - Must be 16-bit audio.
-- - Must contain exactly 'windowLength' samples.
--
-- If the vector does not contain exactly 'windowLength' samples, the model is
-- not called and @0.0@ is returned. Otherwise the result is whatever the C
-- function @detect_speech@ returns for the sample buffer.
-- NOTE(review): the result is presumably a speech probability, and an
-- earlier version of this comment gave the window as 512 samples; confirm
-- both against the C side.
--
-- __Warning: 'SileroModel' holds internal state and is NOT thread safe.__
detectSpeech :: (MonadIO m) => SileroModel -> Vector Float -> m Float
detectSpeech (SileroModel modelPtr) samples
  -- Never hand the C side a buffer of the wrong size.
  | Vector.length samples /= windowLength = pure 0.0
  -- unsafeWith exposes the vector storage without copying; the pointer is
  -- only valid for the duration of the call.
  | otherwise = liftIO (Vector.unsafeWith samples (c_detect_speech modelPtr))