{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Script where
import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_, replicateM)
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Numeric
import System.IO.Unsafe
import Torch.Autograd
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import Torch.Internal.Managed.Type.IValue
import qualified Torch.Internal.Managed.Type.Module as LibTorch
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import Torch.Internal.Type (TensorList)
import qualified Torch.Internal.Type as ATen
import Torch.Internal.Unmanaged.Type.C10Dict
import Torch.Internal.Unmanaged.Type.IValue (IValueLike (..))
import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged
import Torch.NN
import Torch.Tensor (Tensor (..), toDevice)
import Torch.TensorOptions
-- | Handle to a TorchScript module, used with this module's mostly pure
-- accessors (e.g. 'forward', 'getParameters', 'getNamedParameters').
-- Wraps a managed pointer to the underlying libtorch module.
newtype ScriptModule = UnsafeScriptModule (ForeignPtr ATen.Module)

-- | Handle to a TorchScript module, used with the 'IO'-based editing
-- functions (e.g. 'registerParameter', 'define', 'setParameters').
-- Same underlying pointer type as 'ScriptModule'.
newtype RawModule = UnsafeRawModule (ForeignPtr ATen.Module)
-- | Shows the module's libtorch text dump, as produced by 'dumpToStr''
-- (all three dump flags set to 'True').
--
-- 'unsafePerformIO' makes the dump available from a pure 'show'.
-- NOTE(review): this assumes dumping has no observable side effects.
instance Show ScriptModule where
  show scriptModule = unsafePerformIO (dumpToStr' scriptModule)
-- | Managed pointer to a libtorch @IValue@, before conversion to 'IValue'.
type RawIValue = ForeignPtr ATen.IValue

-- | Opaque handle to a libtorch blob.
newtype Blob = UnsafeBlob (ForeignPtr (ATen.C10Ptr ATen.Blob))

-- | Opaque handle to a libtorch script object.
newtype Object = UnsafeObject (ForeignPtr (ATen.C10Ptr ATen.IVObject))

-- | Opaque handle to a libtorch future.
newtype Future = UnsafeFuture (ForeignPtr (ATen.C10Ptr ATen.IVFuture))

-- | Opaque handle to a libtorch capsule.
newtype Capsule = UnsafeCapsule (ForeignPtr (ATen.C10Ptr ATen.Capsule))

-- | Handle to a libtorch JIT graph, held through a shared pointer.
newtype Graph = UnsafeGraph (ForeignPtr (ATen.SharedPtr ATen.JitGraph))
-- | Haskell-side description of a TorchScript graph, made of plain lists
-- of values and nodes.
data JitGraph = JitGraph
  { -- | Values the graph consumes.
    graphInputs :: [JitValue],
    -- | Values the graph produces.
    graphOutputs :: [JitValue],
    -- | Nodes of the graph.
    -- NOTE(review): ordering is not established here; confirm with the
    -- code that builds this record.
    graphNodes :: [JitNode]
  }
  deriving (Eq, Show)
-- | One node of a 'JitGraph'.
data JitNode = JitNode
  { -- | Values the node consumes.
    nodeInputs :: [JitValue],
    -- | Values the node produces.
    nodeOutputs :: [JitValue],
    -- | The node's kind, as a string.
    -- NOTE(review): presumably an operator name such as @aten::add@; confirm.
    nodeKind :: String
  }
  deriving (Eq, Show)
-- | One value (edge) of a 'JitGraph'.
data JitValue = JitValue
  { -- | Numeric identifier of the value.
    -- NOTE(review): presumably unique within one graph; confirm.
    valueId :: Int,
    -- | The value's type, as a string.
    valueType :: String
  }
  deriving (Eq, Show)
-- | 'Blob' is opaque, so it always shows as the fixed string @Blob@.
instance Show Blob where
  show = const "Blob"
-- | 'Future' is opaque, so it always shows as the fixed string @Future@.
instance Show Future where
  show = const "Future"
-- | 'Object' is opaque, so it always shows as the fixed string @Object@.
instance Show Object where
  show = const "Object"
-- | 'Capsule' is opaque, so it always shows as the fixed string @Capsule@.
instance Show Capsule where
  show = const "Capsule"
-- | Haskell-side mirror of the libtorch @IValue@ variants, used as the
-- argument and result type of 'forward' and 'runMethod'.
--
-- Constructors without a payload ('IVBlob', 'IVFuture', 'IVDevice',
-- 'IVObject', 'IVUninitialized', 'IVCapsule') carry no data on the Haskell
-- side.
data IValue
  = -- | The empty value.
    IVNone
  | -- | A single tensor.
    IVTensor Tensor
  | -- | A floating-point scalar.
    IVDouble Double
  | -- | A 64-bit integer scalar.
    IVInt Int64
  | -- | A boolean scalar.
    IVBool Bool
  | -- | A tuple of values.
    IVTuple [IValue]
  | -- | A list of 64-bit integers.
    IVIntList [Int64]
  | -- | A list of floating-point numbers.
    IVDoubleList [Double]
  | -- | A list of booleans.
    IVBoolList [Bool]
  | -- | A string.
    IVString String
  | -- | A list of tensors.
    IVTensorList [Tensor]
  | -- | A blob (no payload).
    IVBlob
  | -- | A heterogeneous list of values.
    IVGenericList [IValue]
  | -- | A dictionary, as an association list of key/value pairs.
    IVGenericDict [(IValue, IValue)]
  | -- | A future (no payload).
    IVFuture
  | -- | A device (no payload).
    IVDevice
  | -- | A script object (no payload).
    IVObject
  | -- | An uninitialized value.
    IVUninitialized
  | -- | A capsule (no payload).
    IVCapsule
  deriving (Show)
-- | Converts between 'ScriptModule' and the pointer it wraps. Both
-- directions only unwrap or rewrap the newtype.
instance Castable ScriptModule (ForeignPtr ATen.Module) where
  cast (UnsafeScriptModule ptr) k = k ptr
  uncast ptr k = k (UnsafeScriptModule ptr)
-- | Converts between 'RawModule' and the pointer it wraps. Both directions
-- only unwrap or rewrap the newtype.
instance Castable RawModule (ForeignPtr ATen.Module) where
  cast (UnsafeRawModule ptr) k = k ptr
  uncast ptr k = k (UnsafeRawModule ptr)
-- | Converts between 'Graph' and the shared-pointer handle it wraps. Both
-- directions only unwrap or rewrap the newtype.
instance Castable Graph (ForeignPtr (ATen.SharedPtr ATen.JitGraph)) where
  cast (UnsafeGraph ptr) k = k ptr
  uncast ptr k = k (UnsafeGraph ptr)
-- | Creates a fresh 'RawModule' by passing the given string to libtorch's
-- @newModule@.
-- NOTE(review): presumably the string is the module's (class) name; confirm.
newModule :: String -> IO RawModule
newModule name = cast1 LibTorch.newModule name
-- | Saves a 'ScriptModule' to the given file path using libtorch's @save@.
saveScript :: ScriptModule -> FilePath -> IO ()
saveScript scriptModule path = cast2 LibTorch.save scriptModule path
-- | Saves a 'RawModule' to the given file path using libtorch's @save@.
saveScript' :: RawModule -> FilePath -> IO ()
saveScript' rawModule path = cast2 LibTorch.save rawModule path
-- | Controls how parameters are treated by 'loadScript' and
-- 'updateParameters'.
data LoadMode
  = -- | Install parameters as-is.
    WithoutRequiredGrad
  | -- | Wrap every parameter with 'makeIndependent' before installing it.
    WithRequiredGrad
  deriving (Eq, Show)
-- | Loads a TorchScript module from a file.
--
-- * 'WithoutRequiredGrad': the module is returned as loaded.
-- * 'WithRequiredGrad': every parameter of the loaded module is passed
--   through 'makeIndependent' and written back with 'setParameters'
--   before the module is returned.
loadScript :: LoadMode -> FilePath -> IO ScriptModule
loadScript mode path = case mode of
  WithoutRequiredGrad -> cast1 LibTorch.load path
  WithRequiredGrad -> do
    UnsafeRawModule ptr <- cast1 LibTorch.load path
    let raw = UnsafeRawModule ptr
    independents <- traverse makeIndependent =<< getParametersIO raw
    setParameters raw (toDependent <$> independents)
    pure (UnsafeScriptModule ptr)
-- | Loads a TorchScript module from a file as an editable 'RawModule',
-- leaving its parameters untouched.
loadScript' :: FilePath -> IO RawModule
loadScript' path = cast1 LibTorch.load path
-- | Runs the module's @forward@ method through libtorch.
--
-- 'forwardStoch' converts the 'IValue' arguments to raw libtorch values,
-- calls @forward@, and converts the result back. 'forward' is the same
-- call wrapped in 'unsafePerformIO'.
instance HasForward ScriptModule [IValue] IValue where
  forward scriptModule args = unsafePerformIO (forwardStoch scriptModule args)
  forwardStoch scriptModule args = cast2 callForward scriptModule args
    where
      -- Inner call on raw IValues; the outer cast2 handles the
      -- [IValue] <-> [RawIValue] and IValue <-> RawIValue conversions.
      callForward :: ScriptModule -> [RawIValue] -> IO RawIValue
      callForward = cast2 LibTorch.forward
-- | Registers a tensor on the module under the given name.
--
-- The 'Bool' is forwarded unchanged (as a 'CBool') to libtorch's
-- @registerParameter@.
-- NOTE(review): presumably libtorch's @is_buffer@ flag; confirm.
registerParameter :: RawModule -> String -> Tensor -> Bool -> IO ()
registerParameter rawModule name tensor flag =
  cast4 LibTorch.registerParameter rawModule name tensor flag
-- | Registers a submodule: passes the parent module, the name and the child
-- module to libtorch's @registerModule@.
registerModule :: RawModule -> String -> RawModule -> IO ()
registerModule parent name child =
  cast3 LibTorch.registerModule parent name child
-- | Returns the module's parameter tensors.
--
-- This is a pure wrapper (via 'unsafePerformIO') around libtorch's
-- @getParameters@. 'getParametersIO' is the 'IO' variant for 'RawModule'.
getParameters ::
  ScriptModule ->
  [Tensor]
getParameters scriptModule =
  unsafePerformIO (cast1 LibTorch.getParameters scriptModule)
-- | Reads the parameter tensors of a 'RawModule' using libtorch's
-- @getParameters@.
getParametersIO ::
  RawModule ->
  IO [Tensor]
getParametersIO rawModule = cast1 LibTorch.getParameters rawModule
-- | Writes the given tensors into the module's parameters using libtorch's
-- @setParameters@.
-- NOTE(review): the list is expected to match the module's parameters
-- (see 'getParametersIO'); length and order are not checked here.
setParameters :: RawModule -> [Tensor] -> IO ()
setParameters rawModule tensors = cast2 LibTorch.setParameters rawModule tensors
-- | Returns a copy of the module with its parameters replaced by @inputs@.
--
-- The argument module is never mutated: it is cloned with libtorch's
-- @clone@, and the new parameters are written into the clone.
--
-- * 'WithoutRequiredGrad': the tensors in @inputs@ are installed as-is.
-- * 'WithRequiredGrad': each tensor is first passed through
--   'makeIndependent', and the resulting dependent tensors are installed.
--
-- Previously the 'WithoutRequiredGrad' branch only cloned the module and
-- ignored @inputs@, so no parameters were ever updated in that mode.
--
-- NOTE(review): @inputs@ is expected to line up with the module's parameters
-- (as returned by 'getParameters'); length and order are not checked here.
updateParameters :: LoadMode -> ScriptModule -> [Tensor] -> ScriptModule
updateParameters mode module' inputs = unsafePerformIO $ do
  r <- cast1 LibTorch.clone module'
  params <- case mode of
    WithoutRequiredGrad -> return inputs
    WithRequiredGrad -> map toDependent <$> forM inputs makeIndependent
  setParameters' r params
  return r
  where
    setParameters' :: ScriptModule -> [Tensor] -> IO ()
    setParameters' = cast2 LibTorch.setParameters
-- | Lists the module's parameters as (name, tensor) pairs, in the order
-- libtorch's @getNamedParameters@ returns them.
--
-- Runs under 'unsafePerformIO' so it can be used as a pure function.
getNamedParameters ::
  ScriptModule ->
  [(String, Tensor)]
getNamedParameters (UnsafeScriptModule ptr) =
  unsafePerformIO (LibTorch.getNamedParameters ptr >>= traverse decodeEntry)
  where
    -- Convert one raw pair: the name first, then the tensor.
    decodeEntry (rawName, rawTensor) = do
      name <- uncast rawName return
      tensor <- uncast rawTensor return
      return (name, tensor)
-- | Lists the module's buffers as (name, tensor) pairs, in the order
-- libtorch's @getNamedBuffers@ returns them.
--
-- Runs under 'unsafePerformIO' so it can be used as a pure function.
getNamedBuffers ::
  ScriptModule ->
  [(String, Tensor)]
getNamedBuffers (UnsafeScriptModule ptr) =
  unsafePerformIO (LibTorch.getNamedBuffers ptr >>= traverse decodeEntry)
  where
    -- Convert one raw pair: the name first, then the tensor.
    decodeEntry (rawName, rawTensor) = do
      name <- uncast rawName return
      tensor <- uncast rawTensor return
      return (name, tensor)
-- | Lists the module's attributes as (name, value) pairs, converting each
-- raw libtorch value to 'IValue'. Results follow the order libtorch's
-- @getNamedAttributes@ returns.
--
-- Runs under 'unsafePerformIO' so it can be used as a pure function.
getNamedAttributes ::
  ScriptModule ->
  [(String, IValue)]
getNamedAttributes (UnsafeScriptModule ptr) =
  unsafePerformIO (LibTorch.getNamedAttributes ptr >>= traverse decodeEntry)
  where
    -- Convert one raw pair: the name first, then the value.
    decodeEntry (rawName, rawValue) = do
      name <- uncast rawName return
      value <- uncast rawValue return
      return (name, value)
-- | Lists the submodules reported by libtorch's @getNamedModules@ as
-- (name, module) pairs, in the order libtorch returns them.
--
-- Runs under 'unsafePerformIO' so it can be used as a pure function.
getNamedModules ::
  ScriptModule ->
  [(String, ScriptModule)]
getNamedModules (UnsafeScriptModule ptr) =
  unsafePerformIO (LibTorch.getNamedModules ptr >>= traverse decodeEntry)
  where
    -- Convert one raw pair: the name first, then the module.
    decodeEntry (rawName, rawModule) = do
      name <- uncast rawName return
      sub <- uncast rawModule return
      return (name, sub)
-- | Lists the children reported by libtorch's @getNamedChildren@ as
-- (name, module) pairs, in the order libtorch returns them.
--
-- Runs under 'unsafePerformIO' so it can be used as a pure function.
getNamedChildren ::
  ScriptModule ->
  [(String, ScriptModule)]
getNamedChildren (UnsafeScriptModule ptr) =
  unsafePerformIO (LibTorch.getNamedChildren ptr >>= traverse decodeEntry)
  where
    -- Convert one raw pair: the name first, then the module.
    decodeEntry (rawName, rawModule) = do
      name <- uncast rawName return
      child <- uncast rawModule return
      return (name, child)
-- | Converts a 'RawModule' to a 'ScriptModule' by cloning it first (via
-- 'cloneRawModule'), so the result does not share the argument's pointer.
toScriptModule :: RawModule -> IO ScriptModule
toScriptModule rawModule =
  (\(UnsafeRawModule ptr) -> UnsafeScriptModule ptr) <$> cloneRawModule rawModule
-- | Converts a 'ScriptModule' to a 'RawModule' by cloning it first (via
-- libtorch's @clone@), so the result does not share the argument's pointer.
toRawModule :: ScriptModule -> IO RawModule
toRawModule scriptModule =
  (\(UnsafeScriptModule ptr) -> UnsafeRawModule ptr) <$> cloneScript scriptModule
  where
    cloneScript :: ScriptModule -> IO ScriptModule
    cloneScript = cast1 LibTorch.clone
-- | Returns a copy of the module produced by libtorch's @clone@.
cloneRawModule :: RawModule -> IO RawModule
cloneRawModule rawModule = cast1 LibTorch.clone rawModule
-- | Whether a module runs in evaluation or training mode (see
-- 'setRuntimeMode').
data RuntimeMode
  = Eval
  | Train
  deriving (Eq, Show)
-- | Switches the module between evaluation and training mode by calling
-- libtorch's @train@ with 'True' for 'Train' and 'False' for 'Eval'.
setRuntimeMode :: RawModule -> RuntimeMode -> IO ()
setRuntimeMode rawModule mode = cast2 LibTorch.train rawModule (isTraining mode)
  where
    isTraining :: RuntimeMode -> Bool
    isTraining = \case
      Train -> True
      Eval -> False
-- | Passes a source string to libtorch's @define@ on the given module.
-- NOTE(review): presumably TorchScript source defining methods; confirm.
define :: RawModule -> String -> IO ()
define rawModule source = cast2 LibTorch.define rawModule source
-- | Renders the module as text using libtorch's @dumpToStr@. The three
-- flags are forwarded unchanged (as 'CBool').
--
-- NOTE(review): presumably these mirror libtorch's
-- @Module::dump_to_str(print_method_bodies, print_attr_values,
-- print_param_values)@; confirm.
dumpToStr ::
  ScriptModule ->
  Bool ->
  Bool ->
  Bool ->
  IO String
dumpToStr scriptModule flag1 flag2 flag3 =
  cast4 LibTorch.dumpToStr scriptModule flag1 flag2 flag3
-- | 'dumpToStr' with all three flags set to 'True'; this is what the
-- 'Show' instance of 'ScriptModule' prints.
dumpToStr' :: ScriptModule -> IO String
dumpToStr' = \scriptModule -> dumpToStr scriptModule True True True
-- | Calls the named method of the module with the given arguments and
-- returns its result.
--
-- The arguments are converted to raw libtorch values, passed to libtorch's
-- @runMethod@, and the result is converted back. The call runs under
-- 'unsafePerformIO' so it can be used as a pure function.
runMethod ::
  ScriptModule ->
  String ->
  [IValue] ->
  IValue
runMethod scriptModule methodName args =
  unsafePerformIO (cast3 invoke scriptModule methodName args)
  where
    -- Inner call on raw values; the outer cast3 handles the IValue
    -- conversions.
    invoke :: ScriptModule -> String -> [RawIValue] -> IO RawIValue
    invoke = cast3 LibTorch.runMethod
-- | Invoke a named method of a script module with a single argument.
--
-- Like 'runMethod', but the argument is passed as one raw IValue to
-- @LibTorch.runMethod1@ instead of as a list.
--
-- NOTE(review): executed with 'unsafePerformIO'; only sound when the
-- invoked method has no observable side effects — confirm for each use.
runMethod1 ::
  -- | module whose method is invoked
  ScriptModule ->
  -- | method name
  String ->
  -- | the single argument value
  IValue ->
  -- | value returned by the method
  IValue
runMethod1 module' func arg = unsafePerformIO $ cast3 invoke module' func arg
  where
    invoke :: ScriptModule -> String -> RawIValue -> IO RawIValue
    invoke = cast3 LibTorch.runMethod1
-- | A script module's parameters are exactly the tensors reported by
-- 'getParameters'.
--
-- Replacement draws one tensor from the parameter stream per current
-- parameter and installs them with 'WithRequiredGrad'.
instance Parameterized ScriptModule where
  flattenParameters = map IndependentTensor . getParameters
  _replaceParameters module' = do
    let count = length (getParameters module')
    fresh <- replicateM count nextParameter
    pure $ updateParameters WithRequiredGrad module' (map toDependent fresh)
-- | Trace a tensor function on example inputs into a new 'RawModule'.
--
-- The traced computation is registered on the module under the given
-- function name. 'traceWithParameters' relies on this when it later calls
-- @self.forwardWithParameters@.
trace ::
  -- | name of the module to create
  String ->
  -- | name under which the traced function is stored
  String ->
  -- | function to trace
  ([Tensor] -> IO [Tensor]) ->
  -- | example inputs used for tracing
  [Tensor] ->
  IO RawModule
trace moduleName functionName func =
  cast3 (\m f -> LibTorch.trace m f (trans func)) moduleName functionName
  where
    -- Adapt a Haskell tensor function to operate on C++ tensor lists.
    trans :: ([Tensor] -> IO [Tensor]) -> ForeignPtr ATen.TensorList -> IO (ForeignPtr ATen.TensorList)
    trans fn rawInputs =
      uncast rawInputs $ \tensors ->
        fn tensors >>= \outputs -> cast outputs pure
-- | Trace a parameterized function into a 'RawModule' whose @forward@
-- method takes only the data inputs.
--
-- Steps:
--
-- * The function is traced as method @forwardWithParameters@ over the
--   flattened parameters followed by the example inputs.
-- * Each parameter tensor is registered on the module as @p0@, @p1@, ...,
--   with the final 'Bool' of 'registerParameter' set to 'False'.
--   NOTE(review): presumably that flag is @is_buffer@ — confirm.
-- * A @forward@ method is defined that passes @self.p0, self.p1, ...@
--   followed by its own arguments @i0, i1, ...@ to
--   @self.forwardWithParameters@.
--
-- Argument lists are joined as whole lists. A function with no parameters
-- and/or no inputs therefore still yields well-formed method source. The
-- old @params ++ ", " ++ args@ concatenation produced
-- @forwardWithParameters(, i0 )@ when there were no parameters.
traceWithParameters ::
  Parameterized f =>
  -- | name of the module to create
  String ->
  -- | function to trace, given the parameters and the data inputs
  (f -> [Tensor] -> IO [Tensor]) ->
  -- | parameter container whose tensors become module parameters
  f ->
  -- | example inputs used for tracing
  [Tensor] ->
  IO RawModule
traceWithParameters moduleName func parameterized_parameters inputs = do
  let parameters = map toDependent (flattenParameters parameterized_parameters)
      fromParams params =
        replaceParameters parameterized_parameters (map IndependentTensor params)
      plen = length parameters
      ilen = length inputs
  r <-
    trace
      moduleName
      "forwardWithParameters"
      ( \parametersAndInputs ->
          -- The traced graph receives parameters first, then the inputs.
          let (paramTensors, inputTensors) = splitAt plen parametersAndInputs
           in func (fromParams paramTensors) inputTensors
      )
      (parameters ++ inputs)
  forM_ (zip [0 :: Int ..] parameters) $ \(i, p) ->
    registerParameter r ("p" ++ show i) p False
  let argNames = ["i" ++ show i | i <- [0 .. ilen - 1]]
      paramRefs = ["self.p" ++ show i | i <- [0 .. plen - 1]]
      -- Join complete lists so empty parameter/input lists never leave a
      -- dangling comma in the generated source.
      forwardParams = intercalate ", " ("self" : argNames)
      callArgs = intercalate ", " (paramRefs ++ argNames)
  define r $
    "def forward(" ++ forwardParams ++ "):\n"
      ++ " return self.forwardWithParameters("
      ++ callArgs
      ++ " )\n"
  return r
-- | Trace a tensor function on example inputs and return the captured
-- 'Graph' directly, instead of wrapping it in a module (compare 'trace').
traceAsGraph ::
  -- | function to trace
  ([Tensor] -> IO [Tensor]) ->
  -- | example inputs used for tracing
  [Tensor] ->
  IO Graph
traceAsGraph func = cast1 tracer
  where
    tracer = LibTorch.traceAsGraph (trans func)
    -- Adapt a Haskell tensor function to operate on C++ tensor lists.
    trans :: ([Tensor] -> IO [Tensor]) -> ForeignPtr ATen.TensorList -> IO (ForeignPtr ATen.TensorList)
    trans fn rawInputs =
      uncast rawInputs $ \tensors ->
        fn tensors >>= \outputs -> cast outputs pure
-- | Render a traced 'Graph' as text using @LibTorch.printGraph@.
printGraph :: Graph -> IO String
printGraph graph = cast1 LibTorch.printGraph graph
-- | Render a traced 'Graph' as text using @LibTorch.printOnnx@.
--
-- NOTE(review): presumably this produces an ONNX-oriented dump of the
-- graph — confirm against the C++ binding.
printOnnx :: Graph -> IO String
printOnnx graph = cast1 LibTorch.printOnnx graph
-- | Snapshot a traced 'Graph' into a plain Haskell 'JitGraph'.
--
-- The graph's inputs, outputs and nodes are read through the unmanaged
-- accessors inside 'withForeignPtr'. Every value and node is converted
-- to ordinary Haskell data (ids, type strings, kind strings) before that
-- scope closes.
graphToJitGraph :: Graph -> IO JitGraph
graphToJitGraph (UnsafeGraph graph) =
  withForeignPtr graph $ \sharedGraph ->
    Unmanaged.withJitGraph sharedGraph $ \rawGraph -> do
      ins <- Unmanaged.graphInputs rawGraph >>= toJitValue
      outs <- Unmanaged.graphOutputs rawGraph >>= toJitValue
      nodes <- Unmanaged.graphNodes rawGraph >>= toJitNode
      pure JitGraph {graphInputs = ins, graphOutputs = outs, graphNodes = nodes}
  where
    -- Read each value's numeric id and its printed type.
    toJitValue ptrs =
      forM ptrs $ \ptr -> do
        vid <- cast1 Unmanaged.valueId ptr
        vty <- cast0 (cast1 Unmanaged.valueType ptr :: IO (ForeignPtr ATen.StdString))
        pure JitValue {valueId = vid, valueType = vty}
    -- Read each node's input/output values and its kind string.
    toJitNode ptrs =
      forM ptrs $ \ptr -> do
        ins <- Unmanaged.nodeInputs ptr >>= toJitValue
        outs <- Unmanaged.nodeOutputs ptr >>= toJitValue
        kind <- cast0 (cast1 Unmanaged.nodeKind ptr :: IO (ForeignPtr ATen.StdString))
        pure JitNode {nodeInputs = ins, nodeOutputs = outs, nodeKind = kind}
-- | Element-wise marshalling between Haskell 'IValue' lists and raw IValue
-- lists.
--
-- Every element is converted with the single-element instance. The whole
-- converted list is collected before the continuation runs.
instance Castable [IValue] [RawIValue] where
  cast values k = traverse (\v -> cast v pure) values >>= k
  uncast raws k = traverse (\rv -> uncast rv pure) raws >>= k
instance Castable IValue RawIValue where
cast :: forall r. IValue -> (RawIValue -> IO r) -> IO r
cast IValue
IVNone RawIValue -> IO r
f = IO RawIValue
newIValue IO RawIValue -> (RawIValue -> IO r) -> IO r
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RawIValue -> IO r
f
cast (IVTensor (Unsafe ForeignPtr Tensor
v)) RawIValue -> IO r
f = ForeignPtr Tensor -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr Tensor
v IO RawIValue -> (RawIValue -> IO r) -> IO r
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RawIValue -> IO r
f
cast (IVDouble Double
v) RawIValue -> IO r
f = Double -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue Double
v IO RawIValue -> (RawIValue -> IO r) -> IO r
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RawIValue -> IO r
f
cast (IVInt Int64
v) RawIValue -> IO r
f = Int64 -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue Int64
v IO RawIValue -> (RawIValue -> IO r) -> IO r
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RawIValue -> IO r
f
cast (IVBool Bool
v) RawIValue -> IO r
f = Bool -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue Bool
v IO RawIValue -> (RawIValue -> IO r) -> IO r
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RawIValue -> IO r
f
cast (IVTuple [IValue]
v) RawIValue -> IO r
f = do
[RawIValue]
rawIValues <- [IValue] -> ([RawIValue] -> IO [RawIValue]) -> IO [RawIValue]
forall r. [IValue] -> ([RawIValue] -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [IValue]
v [RawIValue] -> IO [RawIValue]
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO [RawIValue]
ForeignPtr (C10Ptr IVTuple)
c10tuple <- [RawIValue]
-> (ForeignPtr (C10Ptr IVTuple)
-> IO (ForeignPtr (C10Ptr IVTuple)))
-> IO (ForeignPtr (C10Ptr IVTuple))
forall r.
[RawIValue] -> (ForeignPtr (C10Ptr IVTuple) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [RawIValue]
rawIValues ForeignPtr (C10Ptr IVTuple) -> IO (ForeignPtr (C10Ptr IVTuple))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10Ptr IVTuple) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10Ptr IVTuple)
c10tuple
cast (IVIntList [Int64]
v) RawIValue -> IO r
f = do
ForeignPtr (C10List Int64)
v' <- [Int64]
-> (ForeignPtr (C10List Int64) -> IO (ForeignPtr (C10List Int64)))
-> IO (ForeignPtr (C10List Int64))
forall r. [Int64] -> (ForeignPtr (C10List Int64) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [Int64]
v ForeignPtr (C10List Int64) -> IO (ForeignPtr (C10List Int64))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10List Int64))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10List Int64) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10List Int64)
v'
cast (IVDoubleList [Double]
v) RawIValue -> IO r
f = do
[CDouble]
cdoubles <- [Double] -> (Double -> IO CDouble) -> IO [CDouble]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Double]
v (Double -> (CDouble -> IO CDouble) -> IO CDouble
forall r. Double -> (CDouble -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
`cast` CDouble -> IO CDouble
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return) :: IO [CDouble]
ForeignPtr (C10List CDouble)
c10list <- [CDouble]
-> (ForeignPtr (C10List CDouble)
-> IO (ForeignPtr (C10List CDouble)))
-> IO (ForeignPtr (C10List CDouble))
forall r.
[CDouble] -> (ForeignPtr (C10List CDouble) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [CDouble]
cdoubles ForeignPtr (C10List CDouble) -> IO (ForeignPtr (C10List CDouble))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10List CDouble))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10List CDouble) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10List CDouble)
c10list
cast (IVBoolList [Bool]
v) RawIValue -> IO r
f = do
[CBool]
cbools <- [Bool] -> (Bool -> IO CBool) -> IO [CBool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Bool]
v (Bool -> (CBool -> IO CBool) -> IO CBool
forall r. Bool -> (CBool -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
`cast` CBool -> IO CBool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return) :: IO [CBool]
ForeignPtr (C10List CBool)
c10list <- [CBool]
-> (ForeignPtr (C10List CBool) -> IO (ForeignPtr (C10List CBool)))
-> IO (ForeignPtr (C10List CBool))
forall r. [CBool] -> (ForeignPtr (C10List CBool) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [CBool]
cbools ForeignPtr (C10List CBool) -> IO (ForeignPtr (C10List CBool))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10List CBool))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10List CBool) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10List CBool)
c10list
cast (IVString String
v) RawIValue -> IO r
f = do
ForeignPtr StdString
v' <- String
-> (ForeignPtr StdString -> IO (ForeignPtr StdString))
-> IO (ForeignPtr StdString)
forall r. String -> (ForeignPtr StdString -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast String
v ForeignPtr StdString -> IO (ForeignPtr StdString)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr ATen.StdString)
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr StdString -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr StdString
v'
cast (IVTensorList [Tensor]
v) RawIValue -> IO r
f = do
ForeignPtr (C10List Tensor)
v' <- [Tensor]
-> (ForeignPtr (C10List Tensor)
-> IO (ForeignPtr (C10List Tensor)))
-> IO (ForeignPtr (C10List Tensor))
forall r. [Tensor] -> (ForeignPtr (C10List Tensor) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [Tensor]
v ForeignPtr (C10List Tensor) -> IO (ForeignPtr (C10List Tensor))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10List Tensor) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10List Tensor)
v'
cast (IVGenericList [IValue]
v) RawIValue -> IO r
f = do
[RawIValue]
rawIValues <- [IValue] -> ([RawIValue] -> IO [RawIValue]) -> IO [RawIValue]
forall r. [IValue] -> ([RawIValue] -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [IValue]
v [RawIValue] -> IO [RawIValue]
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO [RawIValue]
ForeignPtr (C10List IValue)
c10list <- [RawIValue]
-> (ForeignPtr (C10List IValue)
-> IO (ForeignPtr (C10List IValue)))
-> IO (ForeignPtr (C10List IValue))
forall r.
[RawIValue] -> (ForeignPtr (C10List IValue) -> IO r) -> IO r
forall a b r. Castable a b => a -> (b -> IO r) -> IO r
cast [RawIValue]
rawIValues ForeignPtr (C10List IValue) -> IO (ForeignPtr (C10List IValue))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return :: IO (ForeignPtr (ATen.C10List ATen.IValue))
RawIValue -> IO r
f (RawIValue -> IO r) -> IO RawIValue -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ForeignPtr (C10List IValue) -> IO RawIValue
forall a b. IValueLike a b => a -> IO b
toIValue ForeignPtr (C10List IValue)
c10list
-- Generic dictionary: keys and values are converted to raw IValues in two
-- separate passes, re-paired positionally with 'zip', packed into an
-- 'ATen.C10Dict' of IValue pairs, and finally wrapped as one 'RawIValue'
-- that is handed to the continuation @f@.
cast (IVGenericDict v) f = do
  let (ks, vs) = unzip v
  rawKeys <- cast ks return :: IO [RawIValue]
  rawValues <- cast vs return :: IO [RawIValue]
  c10dict <-
    cast (zip rawKeys rawValues) return ::
      IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
  f =<< toIValue c10dict
-- Catch-all: any 'IValue' constructor not handled by an earlier clause has no
-- RawIValue encoding here. The continuation is never called; an 'IOError'
-- (via 'userError') naming the offending value is thrown instead.
cast a _ = throwIO $ userError $ "Unsupported data-type:" ++ show a
-- Decode a raw libtorch 'RawIValue' into a Haskell 'IValue' and pass it to
-- the continuation @f@.
--
-- The runtime tag of @obj@ is probed with the @iValue_is*@ predicates in a
-- fixed order, and the first predicate reporting true selects the decoder.
-- The element-typed list probes (tensor/double/int/bool lists) deliberately
-- run before the generic 'iValue_isList' probe.
-- NOTE(review): presumably because a typed list also satisfies isList on the
-- C++ side; confirm before reordering.
--
-- Blob, Future, Device, Object and Capsule values are recognised but carry
-- no payload: only the bare tag constructor is passed to @f@.
--
-- Throws an 'IOError' ("Unsupported IValue") when no predicate matches.
uncast obj f =
  select
    [ (iValue_isNone obj, f IVNone),
      (iValue_isTensor obj, fromIValue obj >>= f . IVTensor . Unsafe),
      (iValue_isDouble obj, fromIValue obj >>= f . IVDouble),
      (iValue_isInt obj, fromIValue obj >>= f . IVInt),
      (iValue_isBool obj, fromIValue obj >>= f . IVBool),
      ( iValue_isString obj,
        do
          v <- fromIValue obj :: IO (ForeignPtr ATen.StdString)
          str <- uncast v return :: IO String
          f (IVString str)
      ),
      ( iValue_isTensorList obj,
        do
          v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
          ts <- uncast v' return :: IO [Tensor]
          f (IVTensorList ts)
      ),
      ( iValue_isDoubleList obj,
        do
          v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CDouble))
          cdoubles <- uncast v' return :: IO [CDouble]
          doubles <- forM cdoubles (`uncast` return) :: IO [Double]
          f (IVDoubleList doubles)
      ),
      ( iValue_isIntList obj,
        do
          v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List Int64))
          ts <- uncast v' return :: IO [Int64]
          f (IVIntList ts)
      ),
      ( iValue_isBoolList obj,
        do
          v' <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CBool))
          cbools <- uncast v' return :: IO [CBool]
          bools <- forM cbools (`uncast` return) :: IO [Bool]
          f (IVBoolList bools)
      ),
      ( iValue_isTuple obj,
        do
          c10tuple <- fromIValue obj :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
          rawIValues <- uncast c10tuple return :: IO [RawIValue]
          ts <- uncast rawIValues return :: IO [IValue]
          f (IVTuple ts)
      ),
      ( iValue_isList obj,
        do
          c10list <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.IValue))
          rawIValues <- uncast c10list return :: IO [RawIValue]
          ts <- uncast rawIValues return :: IO [IValue]
          f (IVGenericList ts)
      ),
      ( iValue_isGenericDict obj,
        do
          c10dict <-
            fromIValue obj ::
              IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
          rawPairs <- uncast c10dict return :: IO [(RawIValue, RawIValue)]
          -- Decode each key and value independently, keeping their pairing.
          ts <- forM rawPairs $ \(k, v) -> do
            k' <- uncast k return
            v' <- uncast v return
            return (k', v')
          f (IVGenericDict ts)
      ),
      (iValue_isBlob obj, f IVBlob),
      (iValue_isFuture obj, f IVFuture),
      (iValue_isDevice obj, f IVDevice),
      (iValue_isObject obj, f IVObject),
      (iValue_isCapsule obj, f IVCapsule)
    ]
  where
    -- Run the predicates left to right and execute the body paired with the
    -- first one that reports true. Any nonzero 'CBool' counts as true (C
    -- truthiness); the previous version only accepted exactly 1.
    select [] = throwIO $ userError "Unsupported IValue"
    select ((cond, body) : rest) =
      cond >>= \case
        0 -> select rest
        _ -> body