{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Menoh
(
Dims
, DType (..)
, HasDType (..)
, Error (..)
, ModelData (..)
, makeModelDataFromONNX
, optimizeModelData
, Model (..)
, makeModel
, makeModelWithConfig
, run
, getDType
, getDims
, unsafeGetBuffer
, withBuffer
, writeBufferFromVector
, writeBufferFromStorableVector
, readBufferToVector
, readBufferToStorableVector
, version
, bindingVersion
, VariableProfileTable (..)
, makeVariableProfileTable
, vptGetDType
, vptGetDims
, VariableProfileTableBuilder (..)
, makeVariableProfileTableBuilder
, addInputProfileDims2
, addInputProfileDims4
, addOutputProfile
, buildVariableProfileTable
, ModelBuilder (..)
, makeModelBuilder
, attachExternalBuffer
, buildModel
, buildModelWithConfig
) where
import Control.Concurrent
import Control.Monad
import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp)
import Control.Monad.IO.Class
import Control.Exception
import qualified Data.Aeson as J
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Proxy
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Data.Vector.Generic as VG
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Version
import Foreign
import Foreign.C
import qualified Menoh.Base as Base
import qualified Paths_menoh
#include "MachDeps.h"
-- | Errors raised by this binding.
--
-- 'runMenoh' maps each non-success Menoh C API error code to one of these
-- constructors and stores the library's last error message in it. Codes with
-- no dedicated constructor become 'ErrorUnknownError'. A few constructors
-- ('ErrorInvalidDType', 'ErrorDimensionMismatch') are also thrown directly by
-- Haskell-side argument checks in this module.
data Error
  = ErrorStdError String
  | ErrorUnknownError String
  | ErrorInvalidFilename String
  | ErrorONNXParseError String
  | ErrorInvalidDType String
  | ErrorInvalidAttributeType String
  | ErrorUnsupportedOperatorAttribute String
  | ErrorDimensionMismatch String
  | ErrorVariableNotFound String
  | ErrorIndexOutOfRange String
  | ErrorJSONParseError String
  | ErrorInvalidBackendName String
  | ErrorUnsupportedOperator String
  | ErrorFailedToConfigureOperator String
  | ErrorBackendError String
  | ErrorSameNamedVariableAlreadyExist String
  deriving (Eq, Ord, Show, Read)

-- | 'Error' values are raised with 'throwIO' and can be caught like any other
-- exception.
instance Exception Error
-- | Execute a raw Menoh C API call and turn a failing error code into a
-- thrown 'Error'.
--
-- On failure the message from @menoh_get_last_error_message@ is wrapped in
-- the constructor that matches the returned code. Codes missing from the
-- table become 'ErrorUnknownError', with the numeric code appended to the
-- message.
--
-- The call and the message lookup both run inside 'runInBoundThread''.
-- NOTE(review): presumably this keeps them on one OS thread because the C
-- library keeps its last-error state per thread. Confirm against the Menoh docs.
runMenoh :: IO Base.MenohErrorCode -> IO ()
runMenoh action = runInBoundThread' $ do
  code <- action
  unless (code == Base.menohErrorCodeSuccess) $ do
    msg <- peekCString =<< Base.menoh_get_last_error_message
    let codeAsInt = fromIntegral code :: Int
        fallback = ErrorUnknownError $ msg ++ "(error code: " ++ show codeAsInt ++ ")"
    throwIO $ maybe fallback ($ msg) (IntMap.lookup codeAsInt errorConstructors)
  where
    errorConstructors :: IntMap (String -> Error)
    errorConstructors = IntMap.fromList
      [ (fromIntegral errCode, mkError)
      | (errCode, mkError) <-
          [ (Base.menohErrorCodeStdError, ErrorStdError)
          , (Base.menohErrorCodeUnknownError, ErrorUnknownError)
          , (Base.menohErrorCodeInvalidFilename, ErrorInvalidFilename)
          , (Base.menohErrorCodeOnnxParseError, ErrorONNXParseError)
          , (Base.menohErrorCodeInvalidDtype, ErrorInvalidDType)
          , (Base.menohErrorCodeInvalidAttributeType, ErrorInvalidAttributeType)
          , (Base.menohErrorCodeUnsupportedOperatorAttribute, ErrorUnsupportedOperatorAttribute)
          , (Base.menohErrorCodeDimensionMismatch, ErrorDimensionMismatch)
          , (Base.menohErrorCodeVariableNotFound, ErrorVariableNotFound)
          , (Base.menohErrorCodeIndexOutOfRange, ErrorIndexOutOfRange)
          , (Base.menohErrorCodeJsonParseError, ErrorJSONParseError)
          , (Base.menohErrorCodeInvalidBackendName, ErrorInvalidBackendName)
          , (Base.menohErrorCodeUnsupportedOperator, ErrorUnsupportedOperator)
          , (Base.menohErrorCodeFailedToConfigureOperator, ErrorFailedToConfigureOperator)
          , (Base.menohErrorCodeBackendError, ErrorBackendError)
          , (Base.menohErrorCodeSameNamedVariableAlreadyExist, ErrorSameNamedVariableAlreadyExist)
          ]
      ]
-- | Like 'runInBoundThread', but runs the action directly when the RTS has
-- no bound-thread support (non-threaded RTS). In that case
-- 'runInBoundThread' itself would raise an error.
runInBoundThread' :: IO a -> IO a
runInBoundThread'
  | rtsSupportsBoundThreads = runInBoundThread
  | otherwise = id
-- | Element type of a Menoh variable.
--
-- Only the float dtype has a dedicated constructor. Any other raw
-- @menoh_dtype@ code is kept as-is in 'DTypeUnknown' (see the 'Enum'
-- instance for the conversion to and from raw codes).
data DType
  = DTypeFloat
  | DTypeUnknown !Base.MenohDType
  deriving (Eq, Ord, Show, Read)
-- | Conversion between 'DType' and the raw integer dtype code used by the C
-- API. The float code maps to 'DTypeFloat'. Every other code is wrapped in
-- 'DTypeUnknown'.
instance Enum DType where
  toEnum code
    | code /= floatCode = DTypeUnknown (fromIntegral code)
    | otherwise = DTypeFloat
    where
      floatCode = fromIntegral Base.menohDtypeFloat
  fromEnum dtype = case dtype of
    DTypeFloat -> fromIntegral Base.menohDtypeFloat
    DTypeUnknown raw -> fromIntegral raw
-- | Haskell element types that can live in a Menoh buffer, paired with the
-- 'DType' they correspond to. The buffer read/write helpers use this to
-- check that the Haskell element type matches a variable's dtype.
class Storable a => HasDType a where
  -- | The 'DType' for @a@. The 'Proxy' argument only selects the instance
  -- and is never inspected.
  dtypeOf :: Proxy a -> DType

instance HasDType CFloat where
  dtypeOf _ = DTypeFloat

-- Haskell's 'Float' gets an instance only when its size (from MachDeps.h)
-- equals the size of C's float.
#if SIZEOF_HSFLOAT == SIZEOF_FLOAT
instance HasDType Float where
  dtypeOf _ = DTypeFloat
#endif
-- | Shape of a variable: one entry per dimension, in the order reported by
-- Menoh.
type Dims = [Int]
-- | Handle to Menoh model data. 'makeModelDataFromONNX' attaches
-- @menoh_delete_model_data@ as the 'ForeignPtr' finalizer.
newtype ModelData = ModelData (ForeignPtr Base.MenohModelData)
-- | Load model data from the ONNX file at the given path.
--
-- Throws an 'Error' if the C call reports a failure. The returned handle is
-- freed by its finalizer.
makeModelDataFromONNX :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNX fpath = liftIO $
  withCString fpath $ \cpath ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_data_from_onnx cpath out)
      raw <- peek out
      ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr raw
-- | Call @menoh_model_data_optimize@ on the model data using the given
-- profile table. Nothing is returned, so any effect is on the 'ModelData'
-- handle itself. Throws an 'Error' on failure.
optimizeModelData :: MonadIO m => ModelData -> VariableProfileTable -> m ()
optimizeModelData (ModelData md) (VariableProfileTable table) =
  liftIO $
    withForeignPtr md $ \mdPtr ->
      withForeignPtr table $ \tablePtr ->
        runMenoh (Base.menoh_model_data_optimize mdPtr tablePtr)
-- | Builder to which input profiles ('addInputProfileDims2',
-- 'addInputProfileDims4') and output profiles ('addOutputProfile') are
-- added before producing a 'VariableProfileTable' with
-- 'buildVariableProfileTable'.
newtype VariableProfileTableBuilder
  = VariableProfileTableBuilder (ForeignPtr Base.MenohVariableProfileTableBuilder)
-- | Create an empty 'VariableProfileTableBuilder'. The C object is freed by
-- the attached finalizer. Throws an 'Error' on failure.
makeVariableProfileTableBuilder :: MonadIO m => m VariableProfileTableBuilder
makeVariableProfileTableBuilder = liftIO $
  alloca $ \out -> do
    runMenoh (Base.menoh_make_variable_profile_table_builder out)
    rawPtr <- peek out
    VariableProfileTableBuilder
      <$> newForeignPtr Base.menoh_delete_variable_profile_table_builder_funptr rawPtr
-- | Register an input variable whose shape has rank 2 or 4, dispatching to
-- 'addInputProfileDims2' or 'addInputProfileDims4'.
--
-- Throws 'ErrorDimensionMismatch' for any other rank.
addInputProfileDims :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> Dims -> m ()
addInputProfileDims vpt name dtype [num, size] =
  addInputProfileDims2 vpt name dtype (num, size)
addInputProfileDims vpt name dtype [num, channel, height, width] =
  addInputProfileDims4 vpt name dtype (num, channel, height, width)
addInputProfileDims _ _ _ dims =
  liftIO . throwIO . ErrorDimensionMismatch $
    "Menoh.addInputProfileDims: cannot handle dims of length " ++ show (length dims)
-- | Register a rank-2 input variable @(num, size)@ with the given dtype.
-- Throws an 'Error' on failure.
addInputProfileDims2
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int)
  -> m ()
addInputProfileDims2 (VariableProfileTableBuilder builder) name dtype (num, size) =
  liftIO $
    withForeignPtr builder $ \builderPtr ->
      withCString name $ \cname ->
        runMenoh $
          Base.menoh_variable_profile_table_builder_add_input_profile_dims_2
            builderPtr cname rawDType (fromIntegral num) (fromIntegral size)
  where
    rawDType = fromIntegral (fromEnum dtype)
-- | Register a rank-4 input variable @(num, channel, height, width)@ with the
-- given dtype. Throws an 'Error' on failure.
addInputProfileDims4
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int, Int, Int)
  -> m ()
addInputProfileDims4 (VariableProfileTableBuilder builder) name dtype (num, channel, height, width) =
  liftIO $
    withForeignPtr builder $ \builderPtr ->
      withCString name $ \cname ->
        runMenoh $
          Base.menoh_variable_profile_table_builder_add_input_profile_dims_4
            builderPtr cname rawDType
            (fromIntegral num) (fromIntegral channel) (fromIntegral height) (fromIntegral width)
  where
    rawDType = fromIntegral (fromEnum dtype)
-- | Register an output variable with the given dtype. Throws an 'Error' on
-- failure.
addOutputProfile :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> m ()
addOutputProfile (VariableProfileTableBuilder builder) name dtype =
  liftIO . withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_output_profile
          builderPtr cname (fromIntegral (fromEnum dtype))
-- | Turn a populated builder into a 'VariableProfileTable' for the given
-- model data. The result is freed by its finalizer. Throws an 'Error' on
-- failure.
buildVariableProfileTable
  :: MonadIO m
  => VariableProfileTableBuilder
  -> ModelData
  -> m VariableProfileTable
buildVariableProfileTable (VariableProfileTableBuilder builder) (ModelData md) =
  liftIO $
    withForeignPtr builder $ \builderPtr ->
      withForeignPtr md $ \mdPtr ->
        alloca $ \out -> do
          runMenoh (Base.menoh_build_variable_profile_table builderPtr mdPtr out)
          rawPtr <- peek out
          VariableProfileTable
            <$> newForeignPtr Base.menoh_delete_variable_profile_table_funptr rawPtr
-- | Table of variable profiles produced by 'buildVariableProfileTable'.
-- Query it with 'vptGetDType' and 'vptGetDims'.
newtype VariableProfileTable
  = VariableProfileTable (ForeignPtr Base.MenohVariableProfileTable)
-- | Convenience wrapper that creates a builder, registers every input
-- (name, dtype, dims of rank 2 or 4) and every required output
-- (name, dtype), and then builds the 'VariableProfileTable'.
makeVariableProfileTable
  :: MonadIO m
  => [(String, DType, Dims)]
  -> [(String, DType)]
  -> ModelData
  -> m VariableProfileTable
makeVariableProfileTable inputs outputs model_data = liftIO $ do
  builder <- makeVariableProfileTableBuilder
  mapM_ (\(name, dtype, dims) -> addInputProfileDims builder name dtype dims) inputs
  mapM_ (uncurry (addOutputProfile builder)) outputs
  buildVariableProfileTable builder model_data
-- | Look up the 'DType' recorded for the named variable in a
-- 'VariableProfileTable'. Throws an 'Error' if Menoh reports a failure.
vptGetDType :: MonadIO m => VariableProfileTable -> String -> m DType
vptGetDType (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- Use the dtype getter. The dims-size getter returns the variable's rank,
    -- which would be misread as a dtype code.
    runMenoh $ Base.menoh_variable_profile_table_get_dtype vpt' name' ret
    (toEnum . fromIntegral) <$> peek ret
-- | Look up the dims of the named variable in a 'VariableProfileTable'.
-- The rank is queried first, then each dimension by index. Throws an
-- 'Error' on failure.
vptGetDims :: MonadIO m => VariableProfileTable -> String -> m Dims
vptGetDims (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vptPtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        -- Run a C query that writes into @out@, then read the result back.
        let query act = runMenoh act >> peek out
        rank <- query (Base.menoh_variable_profile_table_get_dims_size vptPtr cname out)
        forM [0 .. rank - 1] $ \axis ->
          fromIntegral
            <$> query (Base.menoh_variable_profile_table_get_dims_at vptPtr cname (fromIntegral axis) out)
-- | Builder created from a 'VariableProfileTable' by 'makeModelBuilder'.
-- External buffers can be attached to it ('attachExternalBuffer') before a
-- 'Model' is built from it ('buildModel', 'buildModelWithConfig').
newtype ModelBuilder = ModelBuilder (ForeignPtr Base.MenohModelBuilder)
-- | Create a 'ModelBuilder' from a 'VariableProfileTable'. The C object is
-- freed by its finalizer. Throws an 'Error' on failure.
makeModelBuilder :: MonadIO m => VariableProfileTable -> m ModelBuilder
makeModelBuilder (VariableProfileTable table) =
  liftIO $
    withForeignPtr table $ \tablePtr ->
      alloca $ \out -> do
        runMenoh (Base.menoh_make_model_builder tablePtr out)
        rawPtr <- peek out
        ModelBuilder <$> newForeignPtr Base.menoh_delete_model_builder_funptr rawPtr
-- | Attach a caller-supplied buffer to the named variable of a
-- 'ModelBuilder'. Throws an 'Error' on failure.
--
-- NOTE(review): the pointer is passed to C unchanged and nothing here keeps
-- the memory alive. Presumably the caller must keep it valid for as long as
-- the resulting model uses it. Confirm against the Menoh docs.
attachExternalBuffer :: MonadIO m => ModelBuilder -> String -> Ptr a -> m ()
attachExternalBuffer (ModelBuilder builder) name buf =
  liftIO . withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh (Base.menoh_model_builder_attach_external_buffer builderPtr cname buf)
-- | Build a 'Model' for the named backend, passing an empty string as the
-- backend configuration.
buildModel
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String
  -> m Model
buildModel builder m backend =
  liftIO . withCString "" $ buildModelWithConfigString builder m backend
-- | Build a 'Model' for the named backend. The backend configuration is
-- serialized to JSON and passed to Menoh as a C string.
buildModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => ModelBuilder
  -> ModelData
  -> String
  -> a
  -> m Model
buildModelWithConfig builder m backend backend_config =
  liftIO $ BS.useAsCString encodedConfig (buildModelWithConfigString builder m backend)
  where
    encodedConfig = BL.toStrict (J.encode backend_config)
-- | Shared implementation behind 'buildModel' and 'buildModelWithConfig'.
-- The backend configuration arrives as a ready-made C string. The built
-- model is freed by its finalizer. Throws an 'Error' on failure.
buildModelWithConfigString
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String
  -> CString
  -> m Model
buildModelWithConfigString (ModelBuilder builder) (ModelData md) backend backend_config =
  liftIO $
    withForeignPtr builder $ \builderPtr ->
      withForeignPtr md $ \mdPtr ->
        withCString backend $ \cbackend ->
          alloca $ \out -> do
            runMenoh (Base.menoh_build_model builderPtr mdPtr cbackend backend_config out)
            Model <$> (newForeignPtr Base.menoh_delete_model_funptr =<< peek out)
-- | A built model. Execute it with 'run' and access its variable buffers
-- through 'withBuffer' or the read/write helpers.
newtype Model = Model (ForeignPtr Base.MenohModel)
-- | Run the model by calling @menoh_model_run@. Throws an 'Error' on
-- failure.
run :: MonadIO m => Model -> m ()
run (Model model) = liftIO . withForeignPtr model $ runMenoh . Base.menoh_model_run
-- | Return the 'DType' of the named model variable. Throws an 'Error' on
-- failure.
getDType :: MonadIO m => Model -> String -> m DType
getDType (Model m) name = liftIO $
  withForeignPtr m $ \mPtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dtype mPtr cname out)
        toEnum . fromIntegral <$> peek out
-- | Return the dims of the named model variable. The rank is queried first,
-- then each dimension by index. Throws an 'Error' on failure.
getDims :: MonadIO m => Model -> String -> m Dims
getDims (Model m) name = liftIO $
  withForeignPtr m $ \mPtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        -- Run a C query that writes into @out@, then read the result back.
        let query act = runMenoh act >> peek out
        rank <- query (Base.menoh_model_get_variable_dims_size mPtr cname out)
        forM [0 .. rank - 1] $ \axis ->
          fromIntegral
            <$> query (Base.menoh_model_get_variable_dims_at mPtr cname (fromIntegral axis) out)
-- | Return the raw buffer pointer of the named model variable.
--
-- Unsafe: the pointer escapes the 'withForeignPtr' scope, so nothing keeps
-- the 'Model' alive while the pointer is in use. Prefer 'withBuffer'.
unsafeGetBuffer :: MonadIO m => Model -> String -> m (Ptr a)
unsafeGetBuffer (Model m) name = liftIO $
  withForeignPtr m $ \mPtr ->
    withCString name $ \cname ->
      alloca $ \out ->
        runMenoh (Base.menoh_model_get_variable_buffer_handle mPtr cname out) >> peek out
-- | Run a callback with the buffer pointer of the named model variable.
-- The 'Model''s 'ForeignPtr' is kept alive for the duration of the
-- callback. Throws an 'Error' if the buffer lookup fails.
withBuffer :: forall m r a. (MonadIO m, MonadBaseControl IO m) => Model -> String -> (Ptr a -> m r) -> m r
withBuffer (Model m) name f =
  liftBaseOp (withForeignPtr m) $ \modelPtr ->
    withName $ \cname ->
      liftBaseOp alloca $ \out ->
        liftIO (fetch modelPtr cname out) >>= f
  where
    -- The explicit signature fixes the result type of this middle layer.
    -- GHC cannot infer it through the non-injective StM type family.
    withName :: (CString -> m r) -> m r
    withName = liftBaseOp (withCString name)
    -- Ask Menoh for the buffer handle and read it back out of @out@.
    fetch modelPtr cname out = do
      runMenoh (Base.menoh_model_get_variable_buffer_handle modelPtr cname out)
      peek out
-- | Throw 'ErrorInvalidDType' when the two dtypes differ. The message is
-- prefixed with the given name.
checkDType :: String -> DType -> DType -> IO ()
checkDType name dtype1 dtype2 =
  when (dtype1 /= dtype2) $
    throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
-- | Check the dtypes first ('ErrorInvalidDType' on mismatch), then the
-- element counts ('ErrorDimensionMismatch' on mismatch). Messages are
-- prefixed with the given name.
checkDTypeAndSize :: String -> (DType, Int) -> (DType, Int) -> IO ()
checkDTypeAndSize name (dtype1, n1) (dtype2, n2) = do
  checkDType name dtype1 dtype2
  when (n1 /= n2) $
    throwIO $ ErrorDimensionMismatch $ name ++ ": dimension mismatch"
-- | Copy a generic vector element by element into the named variable's
-- buffer.
--
-- Throws 'ErrorInvalidDType' if @a@'s dtype differs from the variable's, and
-- 'ErrorDimensionMismatch' if the vector length differs from the product of
-- the variable's dims.
writeBufferFromVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> v a -> m ()
writeBufferFromVector model name vec = liftIO $ withBuffer model name $ \dst -> do
  varDType <- getDType model name
  varDims <- getDims model name
  checkDTypeAndSize "Menoh.writeBufferFromVector"
    (varDType, product varDims)
    (dtypeOf (Proxy :: Proxy a), VG.length vec)
  -- After the size check, every index of vec is a valid element offset.
  VG.imapM_ (pokeElemOff dst) vec
-- | Copy a storable vector into the named variable's buffer with a single
-- 'copyArray'.
--
-- Throws 'ErrorInvalidDType' if @a@'s dtype differs from the variable's, and
-- 'ErrorDimensionMismatch' if the vector length differs from the product of
-- the variable's dims.
writeBufferFromStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> VS.Vector a -> m ()
writeBufferFromStorableVector model name vec = liftIO $ withBuffer model name $ \dst -> do
  varDType <- getDType model name
  elemCount <- product <$> getDims model name
  checkDTypeAndSize "Menoh.writeBufferFromStorableVector"
    (varDType, elemCount)
    (dtypeOf (Proxy :: Proxy a), VS.length vec)
  VS.unsafeWith vec $ \src -> copyArray dst src elemCount
-- | Read the named variable's buffer into a fresh generic vector whose
-- length is the product of the variable's dims. Throws 'ErrorInvalidDType'
-- if @a@'s dtype differs from the variable's.
readBufferToVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> m (v a)
readBufferToVector model name = liftIO $ withBuffer model name $ \src -> do
  varDType <- getDType model name
  varDims <- getDims model name
  checkDType "Menoh.readBufferToVector" varDType (dtypeOf (Proxy :: Proxy a))
  VG.generateM (product varDims) (peekElemOff src)
-- | Copy the named variable's buffer into a fresh storable vector whose
-- length is the product of the variable's dims. Throws 'ErrorInvalidDType'
-- if @a@'s dtype differs from the variable's.
readBufferToStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> m (VS.Vector a)
readBufferToStorableVector model name = liftIO $ withBuffer model name $ \src -> do
  varDType <- getDType model name
  elemCount <- product <$> getDims model name
  checkDType "Menoh.readBufferToStorableVector" varDType (dtypeOf (Proxy :: Proxy a))
  mvec <- VSM.new elemCount
  VSM.unsafeWith mvec $ \dst -> copyArray dst src elemCount
  VS.unsafeFreeze mvec
-- | Convenience wrapper: create a 'ModelBuilder' from the profile table and
-- build a 'Model' for the named backend with an empty configuration.
makeModel
  :: MonadIO m
  => VariableProfileTable
  -> ModelData
  -> String
  -> m Model
makeModel vpt model_data backend_name =
  liftIO $ makeModelBuilder vpt >>= \builder -> buildModel builder model_data backend_name
-- | Convenience wrapper: create a 'ModelBuilder' from the profile table and
-- build a 'Model' for the named backend with a JSON-encoded backend
-- configuration.
makeModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => VariableProfileTable
  -> ModelData
  -> String
  -> a
  -> m Model
makeModelWithConfig vpt model_data backend_name backend_config =
  liftIO $
    makeModelBuilder vpt >>= \builder ->
      buildModelWithConfig builder model_data backend_name backend_config
-- | Version of the Menoh C library, built from its major, minor and patch
-- components.
version :: Version
version = makeVersion components
  where
    components =
      [ Base.menoh_major_version
      , Base.menoh_minor_version
      , Base.menoh_patch_version
      ]
-- | Version of this Haskell binding package, taken from the Cabal-generated
-- @Paths_menoh@ module.
bindingVersion :: Version
bindingVersion = Paths_menoh.version