{-# LANGUAGE CPP #-}
module Silero.Model (
SileroModel (..),
detectSpeech,
windowLength,
resetModel,
sampleRate,
withModel,
) where
import Data.Int (Int64)
import Data.Vector.Storable (Vector)
import qualified Data.Vector.Storable as Vector
import Foreign.Storable (Storable (..))
import GHC.Generics (Generic)
import GHC.IO (unsafeDupablePerformIO, unsafePerformIO)
import Paths_silero_vad (getDataFileName)
import UnliftIO (MonadIO (liftIO), MonadUnliftIO, bracket)
#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)
import System.Posix (RTLDFlags (RTLD_NOW), dlsym, dlopen )
import Foreign (FunPtr, Ptr, castPtr)
import Foreign.C (CString, withCString)
#else
import System.Win32 (getProcAddress, loadLibrary)
import Foreign (FunPtr, Ptr, castPtr, castPtrToFunPtr)
import Foreign.C (CWString, withCWString)
#endif
-- Raw C bindings. Every import below names the header model.h and the C
-- symbol it binds to. The model handle is passed around as an untyped Ptr ().

-- | Binding to the C function @get_window_length@; used to define windowLength.
foreign import ccall "model.h get_window_length" c_get_window_length :: IO Int64
-- | Binding to the C function @get_sample_rate@; used to define sampleRate.
foreign import ccall "model.h get_sample_rate" c_get_sample_rate :: IO Int64
-- | Binding to the C function @release_model@; withModel calls it as the
-- release step for the handle returned by @load_model@.
foreign import ccall "model.h release_model" c_release_model :: Ptr () -> IO ()
-- | Binding to the C function @reset_model@; takes the model handle.
foreign import ccall "model.h reset_model" c_reset_model :: Ptr () -> IO ()
-- | Binding to the C function @detect_speech@; takes the model handle and a
-- pointer to Float samples and returns a Float.
-- NOTE(review): the C side presumably reads windowLength samples from the
-- pointer, since detectSpeech only calls it with vectors of that length - confirm.
foreign import ccall "model.h detect_speech" c_detect_speech :: Ptr () -> Ptr Float -> IO Float
#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)
-- | Binding to the C function @load_model@. In this module it is called with
-- the OrtGetApiBase function pointer and the model file path.
-- On Linux and macOS the path is passed as a narrow C string.
foreign import ccall "model.h load_model" c_load_model :: FunPtr () -> CString -> IO (Ptr ())
#else
-- | Binding to the C function @load_model@. In this module it is called with
-- the OrtGetApiBase function pointer and the model file path.
-- On every other platform the path is passed as a wide C string.
foreign import ccall "model.h load_model" c_load_model :: FunPtr () -> CWString -> IO (Ptr ())
#endif
-- | Window length reported by the C function @get_window_length@, converted
-- from Int64 to Int. detectSpeech only runs the model on sample vectors of
-- exactly this length.
--
-- The value is read through unsafeDupablePerformIO.
-- NOTE(review): this assumes get_window_length is a side-effect-free getter
-- that always returns the same value - confirm against the C implementation.
windowLength :: Int
windowLength = fromIntegral (unsafeDupablePerformIO c_get_window_length)
-- | Sample rate reported by the C function @get_sample_rate@, converted from
-- Int64 to Int.
--
-- The value is read through unsafeDupablePerformIO.
-- NOTE(review): this assumes get_sample_rate is a side-effect-free getter
-- that always returns the same value - confirm against the C implementation.
sampleRate :: Int
sampleRate = fromIntegral (unsafeDupablePerformIO c_get_sample_rate)
-- | Handle to a loaded model: a thin wrapper around the untyped pointer
-- returned by the C function @load_model@.
newtype SileroModel = SileroModel
  { api :: Ptr ()
  -- ^ Raw pointer handed to the C functions (reset, detect, release).
  }
  deriving (Generic)
-- | A SileroModel is stored in memory exactly like the raw pointer it wraps:
-- size and alignment are those of @Ptr ()@, and peek/poke read and write the
-- wrapped pointer through a cast of the destination address.
instance Storable SileroModel where
  -- The argument is never inspected; undefined only selects the Ptr () instance.
  sizeOf _ = sizeOf (undefined :: Ptr ())
  alignment _ = alignment (undefined :: Ptr ())
  peek ptr = SileroModel <$> peek (castPtr ptr)
  poke ptr (SileroModel apiPtr) = poke (castPtr ptr) apiPtr
-- | Location of the bundled ONNX Runtime shared library for the host
-- platform, relative to the package data directory. The path is resolved
-- with getDataFileName before the library is loaded.
-- Any host that is neither Linux nor macOS falls through to the Windows path.
libraryPath :: FilePath
#if defined(linux_HOST_OS)
libraryPath = "lib/onnxruntime/linux-x64/libonnxruntime.so"
#elif defined(darwin_HOST_OS) && defined(aarch64_HOST_ARCH)
libraryPath = "lib/onnxruntime/mac-arm64/libonnxruntime.dylib"
#elif defined(darwin_HOST_OS) && !defined(aarch64_HOST_ARCH)
libraryPath = "lib/onnxruntime/mac-x64/libonnxruntime.dylib"
#else
libraryPath = "lib/onnxruntime/windows-x64/onnxruntime.dll"
#endif
-- | Function pointer to the @OrtGetApiBase@ symbol of the bundled ONNX
-- Runtime shared library. It is passed to the C function @load_model@.
--
-- The library is located with getDataFileName, opened, and the symbol is
-- looked up, all inside unsafePerformIO. The NOINLINE pragma keeps GHC from
-- duplicating that call at use sites, so the library is meant to be loaded
-- once and the pointer shared. The library handle is never closed.
-- Failures of the loader calls surface as exceptions when this value is
-- first forced.
{-# NOINLINE onnxruntime #-}
onnxruntime :: FunPtr ()
#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)
onnxruntime =
  unsafePerformIO $ do
    path <- getDataFileName libraryPath
    -- RTLD_NOW: resolve all symbols of the library at load time.
    library <- dlopen path [RTLD_NOW]
    dlsym library "OrtGetApiBase"
#else
onnxruntime =
  unsafePerformIO $ do
    path <- getDataFileName libraryPath
    library <- loadLibrary path
    -- getProcAddress yields a plain address; convert it to a function pointer.
    castPtrToFunPtr <$> getProcAddress library "OrtGetApiBase"
#endif
-- | Absolute path of the bundled model file @lib/silero_vad.onnx@, resolved
-- against the package data directory with getDataFileName.
getModelPath :: IO String
getModelPath = getDataFileName "lib/silero_vad.onnx"
#if defined (linux_HOST_OS) || defined (darwin_HOST_OS)
-- | Run an action with the model file path marshalled as a C string.
-- The string is only valid for the duration of the action (withCString).
withModelPath :: (CString -> IO a) -> IO a
withModelPath runModelPath = do
  modelPath <- getModelPath
  withCString modelPath runModelPath
#else
-- | Run an action with the model file path marshalled as a wide C string.
-- The string is only valid for the duration of the action (withCWString).
withModelPath :: (CWString -> IO a) -> IO a
withModelPath runModelPath = do
  modelPath <- getModelPath
  withCWString modelPath runModelPath
#endif
-- | Load the bundled model, run the given action with it, and release it.
--
-- The model is loaded by calling the C function @load_model@ with the
-- @OrtGetApiBase@ pointer and the model file path. bracket runs
-- @release_model@ on the handle whether the action returns or throws.
-- The handle must not be used after the action has finished.
--
-- NOTE(review): the pointer returned by load_model is not checked for null
-- here - confirm how the C side reports a failed load.
withModel :: (MonadUnliftIO m) => (SileroModel -> m a) -> m a
withModel runModel =
  bracket
    (SileroModel <$> liftIO (withModelPath (c_load_model onnxruntime)))
    (liftIO . c_release_model . api)
    runModel
-- | Call the C function @reset_model@ on the model handle.
-- NOTE(review): presumably this clears the internal state the model carries
-- between windows - confirm against the C implementation.
resetModel :: (MonadIO m) => SileroModel -> m ()
resetModel = liftIO . c_reset_model . api
-- | Run the model on one window of samples and return the Float produced by
-- the C function @detect_speech@.
--
-- The vector must contain exactly windowLength samples. For any other
-- length the C function is not called and 0.0 is returned.
-- NOTE(review): the result is presumably a speech probability - confirm
-- against the C implementation.
detectSpeech :: (MonadIO m) => SileroModel -> Vector Float -> m Float
detectSpeech (SileroModel modelPtr) samples
  | Vector.length samples /= windowLength = return 0.0
  | otherwise =
      -- unsafeWith exposes the vector buffer without copying; the pointer is
      -- only valid inside the callback.
      liftIO (Vector.unsafeWith samples (c_detect_speech modelPtr))