{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Script where
import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_, replicateM)
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Numeric
import System.IO.Unsafe
import Torch.Autograd
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import Torch.Internal.Managed.Type.IValue
import qualified Torch.Internal.Managed.Type.Module as LibTorch
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import Torch.Internal.Type (TensorList)
import qualified Torch.Internal.Type as ATen
import Torch.Internal.Unmanaged.Type.C10Dict
import Torch.Internal.Unmanaged.Type.IValue (IValueLike (..))
import qualified Torch.Internal.Unmanaged.Type.Module as Unmanaged
import Torch.NN
import Torch.Tensor (Tensor (..), toDevice)
import Torch.TensorOptions
-- | TorchScript module handle used by the pure-looking API of this module
-- (e.g. 'getParameters', 'runMethod', 'updateParameters').
--
-- Wraps a foreign pointer to the underlying 'ATen.Module'.
newtype ScriptModule
  = UnsafeScriptModule (ForeignPtr ATen.Module)

-- | TorchScript module handle used by the IO API of this module
-- (e.g. 'registerParameter', 'setParameters', 'setRuntimeMode').
--
-- Wraps the same kind of foreign pointer as 'ScriptModule'.
newtype RawModule
  = UnsafeRawModule (ForeignPtr ATen.Module)
-- | Renders a module through 'dumpToStr'' (all three dump flags set to True).
--
-- 'show' is pure, so the FFI dump is run with 'unsafePerformIO'.
instance Show ScriptModule where
  show = unsafePerformIO . dumpToStr'
-- | Foreign pointer to a C++-side 'ATen.IValue'.
type RawIValue = ForeignPtr ATen.IValue

-- | Opaque handle to a C10 blob.
newtype Blob
  = UnsafeBlob (ForeignPtr (ATen.C10Ptr ATen.Blob))

-- | Opaque handle to a TorchScript object.
newtype Object
  = UnsafeObject (ForeignPtr (ATen.C10Ptr ATen.IVObject))

-- | Opaque handle to a TorchScript future.
newtype Future
  = UnsafeFuture (ForeignPtr (ATen.C10Ptr ATen.IVFuture))

-- | Opaque handle to a TorchScript capsule.
newtype Capsule
  = UnsafeCapsule (ForeignPtr (ATen.C10Ptr ATen.Capsule))

-- | Handle to a JIT graph, held through a shared pointer.
newtype Graph
  = UnsafeGraph (ForeignPtr (ATen.SharedPtr ATen.JitGraph))
-- | Plain Haskell description of a JIT graph: its input values, output
-- values and nodes.
data JitGraph = JitGraph
  { -- | Graph inputs.
    graphInputs :: [JitValue],
    -- | Graph outputs.
    graphOutputs :: [JitValue],
    -- | Graph nodes (presumably in graph order — confirm against the
    -- code that builds this value).
    graphNodes :: [JitNode]
  }
  deriving (Eq, Show)
-- | Plain Haskell description of one JIT graph node.
data JitNode = JitNode
  { -- | Values consumed by the node.
    nodeInputs :: [JitValue],
    -- | Values produced by the node.
    nodeOutputs :: [JitValue],
    -- | Node kind as a string (presumably the operator name such as
    -- "aten::add" — confirm).
    nodeKind :: String
  }
  deriving (Eq, Show)
-- | Plain Haskell description of one JIT graph value.
data JitValue = JitValue
  { -- | Numeric identifier of the value.
    valueId :: Int,
    -- | Type of the value, rendered as a string.
    valueType :: String
  }
  deriving (Eq, Show)
-- The opaque handles below have no meaningful textual form, so each one
-- renders as a constant tag naming its type.

instance Show Blob where
  show = const "Blob"

instance Show Future where
  show = const "Future"

instance Show Object where
  show = const "Object"

instance Show Capsule where
  show = const "Capsule"
-- | Haskell-side representation of a TorchScript value.
--
-- The constructors without a field ('IVBlob', 'IVFuture', 'IVDevice',
-- 'IVObject', 'IVUninitialized', 'IVCapsule') only record which kind of
-- value it is; no payload is carried in this type.
data IValue
  = -- | no value
    IVNone
  | -- | a tensor
    IVTensor Tensor
  | -- | a double
    IVDouble Double
  | -- | a 64-bit integer
    IVInt Int64
  | -- | a boolean
    IVBool Bool
  | -- | a tuple of values
    IVTuple [IValue]
  | -- | a list of integers
    IVIntList [Int64]
  | -- | a list of doubles
    IVDoubleList [Double]
  | -- | a list of booleans
    IVBoolList [Bool]
  | -- | a string
    IVString String
  | -- | a list of tensors
    IVTensorList [Tensor]
  | -- | a blob (no payload)
    IVBlob
  | -- | a list of arbitrary values
    IVGenericList [IValue]
  | -- | a dictionary as key/value pairs
    IVGenericDict [(IValue, IValue)]
  | -- | a future (no payload)
    IVFuture
  | -- | a device (no payload)
    IVDevice
  | -- | an object (no payload)
    IVObject
  | -- | an uninitialized value
    IVUninitialized
  | -- | a capsule (no payload)
    IVCapsule
  deriving (Show)
-- | A 'ScriptModule' is marshalled by unwrapping (or re-wrapping) the
-- foreign module pointer it holds.
instance Castable ScriptModule (ForeignPtr ATen.Module) where
  cast scriptModule k = case scriptModule of
    UnsafeScriptModule ptr -> k ptr
  uncast ptr k = k (UnsafeScriptModule ptr)
-- | A 'RawModule' is marshalled by unwrapping (or re-wrapping) the foreign
-- module pointer it holds.
instance Castable RawModule (ForeignPtr ATen.Module) where
  cast rawModule k = case rawModule of
    UnsafeRawModule ptr -> k ptr
  uncast ptr k = k (UnsafeRawModule ptr)
-- | A 'Graph' is marshalled by unwrapping (or re-wrapping) the foreign
-- shared-pointer it holds.
instance Castable Graph (ForeignPtr (ATen.SharedPtr ATen.JitGraph)) where
  cast graph k = case graph of
    UnsafeGraph ptr -> k ptr
  uncast ptr k = k (UnsafeGraph ptr)
-- | Create a new 'RawModule' via 'LibTorch.newModule' from the given string
-- (presumably the module's name — confirm against the C++ binding).
newModule :: String -> IO RawModule
newModule name = cast1 LibTorch.newModule name
-- | Write a 'ScriptModule' to the given file path via 'LibTorch.save'.
saveScript :: ScriptModule -> FilePath -> IO ()
saveScript scriptModule path = cast2 LibTorch.save scriptModule path
-- | Write a 'RawModule' to the given file path via 'LibTorch.save'.
saveScript' :: RawModule -> FilePath -> IO ()
saveScript' rawModule path = cast2 LibTorch.save rawModule path
-- | Selects how parameters are treated by 'loadScript' and
-- 'updateParameters'.
data LoadMode
  = -- | install parameters as they are
    WithoutRequiredGrad
  | -- | pass each parameter through 'makeIndependent' first
    WithRequiredGrad
  deriving (Eq, Show)
-- | Load a TorchScript module from a file via 'LibTorch.load'.
--
-- * 'WithoutRequiredGrad': the loaded module is returned unchanged.
-- * 'WithRequiredGrad': every parameter is passed through 'makeIndependent'
--   and the results are written back with 'setParameters' before the
--   module is returned.
loadScript :: LoadMode -> FilePath -> IO ScriptModule
loadScript mode file = case mode of
  WithoutRequiredGrad -> cast1 LibTorch.load file
  WithRequiredGrad -> do
    -- Load as a RawModule so the IO parameter API can be used on it.
    raw@(UnsafeRawModule ptr) <- cast1 LibTorch.load file
    independents <- getParametersIO raw >>= traverse makeIndependent
    setParameters raw (map toDependent independents)
    pure (UnsafeScriptModule ptr)
-- | Load a TorchScript module from a file via 'LibTorch.load' as a
-- 'RawModule', without touching its parameters.
loadScript' :: FilePath -> IO RawModule
loadScript' path = cast1 LibTorch.load path
-- | Applying a 'ScriptModule' to a list of 'IValue' arguments goes through
-- 'LibTorch.forward'.
--
-- 'forward' is the pure variant: it runs 'forwardStoch' under
-- 'unsafePerformIO'.
instance HasForward ScriptModule [IValue] IValue where
  forward module' inputs = unsafePerformIO (forwardStoch module' inputs)
  forwardStoch module' inputs = cast2 callForward module' inputs
    where
      -- Marshals the module pointer and the raw argument list for the
      -- C++ call.
      callForward :: ScriptModule -> [RawIValue] -> IO RawIValue
      callForward = cast2 LibTorch.forward
-- | Register a tensor on the module under the given name via
-- 'LibTorch.registerParameter'.
--
-- NOTE(review): the 'Bool' is forwarded unchanged to the binding; it is
-- presumably @is_buffer@ of @torch::jit::Module::register_parameter@.
-- Confirm against the C++ side.
registerParameter :: RawModule -> String -> Tensor -> Bool -> IO ()
registerParameter rawModule name tensor flag =
  cast4 LibTorch.registerParameter rawModule name tensor flag
-- | Register a child module on the parent under the given name via
-- 'LibTorch.registerModule'.
registerModule :: RawModule -> String -> RawModule -> IO ()
registerModule parent name child =
  cast3 LibTorch.registerModule parent name child
-- | Parameters of a 'ScriptModule', read via 'LibTorch.getParameters'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
getParameters ::
  -- | module
  ScriptModule ->
  -- | parameters
  [Tensor]
getParameters scriptModule =
  unsafePerformIO (cast1 LibTorch.getParameters scriptModule)
-- | Parameters of a 'RawModule', read via 'LibTorch.getParameters'.
getParametersIO ::
  -- | module
  RawModule ->
  -- | parameters
  IO [Tensor]
getParametersIO rawModule = cast1 LibTorch.getParameters rawModule
-- | Write the given tensors as the module's parameters via
-- 'LibTorch.setParameters'.
--
-- NOTE(review): presumably the list must match the order of
-- 'getParametersIO' — confirm against the C++ binding.
setParameters :: RawModule -> [Tensor] -> IO ()
setParameters rawModule params = cast2 LibTorch.setParameters rawModule params
-- | Return a copy of the module with its parameters replaced by the given
-- tensors.
--
-- The input module is first cloned with 'LibTorch.clone', and the new
-- parameters are written into the clone. Which parameters get installed
-- depends on the mode:
--
-- * 'WithoutRequiredGrad': the tensors are installed as given.
-- * 'WithRequiredGrad': each tensor is passed through 'makeIndependent'
--   first, and the resulting tensors are installed.
--
-- Previously the 'WithoutRequiredGrad' branch returned the bare clone and
-- silently ignored the supplied tensors.
--
-- NOTE(review): the tensors are presumably expected in the order of
-- 'getParameters' (that is how the 'Parameterized' instance builds them).
updateParameters :: LoadMode -> ScriptModule -> [Tensor] -> ScriptModule
updateParameters mode module' inputs = unsafePerformIO $ do
  -- Work on a clone so the module passed in is not mutated in place.
  r <- cast1 LibTorch.clone module'
  params <- case mode of
    WithoutRequiredGrad -> pure inputs
    WithRequiredGrad -> map toDependent <$> forM inputs makeIndependent
  setParameters' r params
  return r
  where
    setParameters' :: ScriptModule -> [Tensor] -> IO ()
    setParameters' = cast2 LibTorch.setParameters
-- | Named parameters of a 'ScriptModule', as (name, tensor) pairs in the
-- order returned by 'LibTorch.getNamedParameters'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
getNamedParameters ::
  -- | module
  ScriptModule ->
  -- | (name, parameter) pairs
  [(String, Tensor)]
getNamedParameters (UnsafeScriptModule m) =
  unsafePerformIO $
    LibTorch.getNamedParameters m
      >>= traverse
        ( \(rawName, rawValue) -> do
            name <- uncast rawName pure
            value <- uncast rawValue pure
            pure (name, value)
        )
-- | Named buffers of a 'ScriptModule', as (name, tensor) pairs in the order
-- returned by 'LibTorch.getNamedBuffers'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
getNamedBuffers ::
  -- | module
  ScriptModule ->
  -- | (name, buffer) pairs
  [(String, Tensor)]
getNamedBuffers (UnsafeScriptModule m) =
  unsafePerformIO $
    LibTorch.getNamedBuffers m
      >>= traverse
        ( \(rawName, rawValue) -> do
            name <- uncast rawName pure
            value <- uncast rawValue pure
            pure (name, value)
        )
-- | Named attributes of a 'ScriptModule', as (name, value) pairs in the
-- order returned by 'LibTorch.getNamedAttributes'.
--
-- Each raw value is converted to an 'IValue'. Runs the FFI call under
-- 'unsafePerformIO' to offer a pure interface.
getNamedAttributes ::
  -- | module
  ScriptModule ->
  -- | (name, attribute) pairs
  [(String, IValue)]
getNamedAttributes (UnsafeScriptModule m) =
  unsafePerformIO $
    LibTorch.getNamedAttributes m
      >>= traverse
        ( \(rawName, rawValue) -> do
            name <- uncast rawName pure
            value <- uncast rawValue pure
            pure (name, value)
        )
-- | Named modules of a 'ScriptModule', as (name, module) pairs in the order
-- returned by 'LibTorch.getNamedModules'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
getNamedModules ::
  -- | module
  ScriptModule ->
  -- | (name, module) pairs
  [(String, ScriptModule)]
getNamedModules (UnsafeScriptModule m) =
  unsafePerformIO $
    LibTorch.getNamedModules m
      >>= traverse
        ( \(rawName, rawModule) -> do
            name <- uncast rawName pure
            subModule <- uncast rawModule pure
            pure (name, subModule)
        )
-- | Named children of a 'ScriptModule', as (name, module) pairs in the
-- order returned by 'LibTorch.getNamedChildren'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
getNamedChildren ::
  -- | module
  ScriptModule ->
  -- | (name, child) pairs
  [(String, ScriptModule)]
getNamedChildren (UnsafeScriptModule m) =
  unsafePerformIO $
    LibTorch.getNamedChildren m
      >>= traverse
        ( \(rawName, rawModule) -> do
            name <- uncast rawName pure
            child <- uncast rawModule pure
            pure (name, child)
        )
-- | Produce a 'ScriptModule' from a clone of the given 'RawModule'
-- (cloned with 'cloneRawModule').
toScriptModule :: RawModule -> IO ScriptModule
toScriptModule rawModule = rewrap <$> cloneRawModule rawModule
  where
    rewrap (UnsafeRawModule ptr) = UnsafeScriptModule ptr
-- | Produce a 'RawModule' from a clone of the given 'ScriptModule'
-- (cloned with 'LibTorch.clone').
toRawModule :: ScriptModule -> IO RawModule
toRawModule scriptModule = rewrap <$> cloneScript scriptModule
  where
    cloneScript :: ScriptModule -> IO ScriptModule
    cloneScript = cast1 LibTorch.clone
    rewrap (UnsafeScriptModule ptr) = UnsafeRawModule ptr
-- | Clone a 'RawModule' via 'LibTorch.clone'.
cloneRawModule :: RawModule -> IO RawModule
cloneRawModule rawModule = cast1 LibTorch.clone rawModule
data RuntimeMode = Eval | Train deriving (Int -> RuntimeMode -> ShowS
[RuntimeMode] -> ShowS
RuntimeMode -> String
(Int -> RuntimeMode -> ShowS)
-> (RuntimeMode -> String)
-> ([RuntimeMode] -> ShowS)
-> Show RuntimeMode
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
$cshowsPrec :: Int -> RuntimeMode -> ShowS
showsPrec :: Int -> RuntimeMode -> ShowS
$cshow :: RuntimeMode -> String
show :: RuntimeMode -> String
$cshowList :: [RuntimeMode] -> ShowS
showList :: [RuntimeMode] -> ShowS
Show, RuntimeMode -> RuntimeMode -> Bool
(RuntimeMode -> RuntimeMode -> Bool)
-> (RuntimeMode -> RuntimeMode -> Bool) -> Eq RuntimeMode
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
$c== :: RuntimeMode -> RuntimeMode -> Bool
== :: RuntimeMode -> RuntimeMode -> Bool
$c/= :: RuntimeMode -> RuntimeMode -> Bool
/= :: RuntimeMode -> RuntimeMode -> Bool
Eq)
-- | Put the module into training or evaluation mode by calling
-- 'LibTorch.train' with True for 'Train' and False for 'Eval'.
setRuntimeMode :: RawModule -> RuntimeMode -> IO ()
setRuntimeMode rmod mode = cast2 LibTorch.train rmod isTraining
  where
    isTraining = case mode of
      Train -> True
      Eval -> False
-- | Pass a source string to 'LibTorch.define' for the given module
-- (presumably TorchScript source defining methods — confirm).
define :: RawModule -> String -> IO ()
define rawModule src = cast2 LibTorch.define rawModule src
-- | Textual dump of a module produced by 'LibTorch.dumpToStr'.
--
-- NOTE(review): the three flags are forwarded unchanged. They presumably
-- map to @print_method_bodies@, @print_attr_values@ and
-- @print_param_values@ of @torch::jit::Module::dump_to_str@ — confirm.
dumpToStr ::
  -- | module
  ScriptModule ->
  -- | first dump flag
  Bool ->
  -- | second dump flag
  Bool ->
  -- | third dump flag
  Bool ->
  -- | dump text
  IO String
dumpToStr scriptModule flag1 flag2 flag3 =
  cast4 LibTorch.dumpToStr scriptModule flag1 flag2 flag3
-- | 'dumpToStr' with all three flags set to True.
dumpToStr' :: ScriptModule -> IO String
dumpToStr' obj = dumpToStr obj everything everything everything
  where
    everything = True
-- | Call a named method of a 'ScriptModule' with a list of arguments via
-- 'LibTorch.runMethod'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
runMethod ::
  -- | module
  ScriptModule ->
  -- | method name
  String ->
  -- | arguments
  [IValue] ->
  -- | result
  IValue
runMethod module' func args =
  unsafePerformIO (cast3 callRaw module' func args)
  where
    callRaw :: ScriptModule -> String -> [RawIValue] -> IO RawIValue
    callRaw = cast3 LibTorch.runMethod
-- | Call a named method of a 'ScriptModule' with a single argument via
-- 'LibTorch.runMethod1'.
--
-- Runs the FFI call under 'unsafePerformIO' to offer a pure interface.
runMethod1 ::
  -- | module
  ScriptModule ->
  -- | method name
  String ->
  -- | argument
  IValue ->
  -- | result
  IValue
runMethod1 module' func arg =
  unsafePerformIO (cast3 callRaw module' func arg)
  where
    callRaw :: ScriptModule -> String -> RawIValue -> IO RawIValue
    callRaw = cast3 LibTorch.runMethod1
-- | Parameters of a 'ScriptModule' are the tensors reported by
-- 'getParameters'.
--
-- Replacement pulls one tensor per existing parameter from the parameter
-- stream and installs them with 'updateParameters' in 'WithRequiredGrad'
-- mode.
instance Parameterized ScriptModule where
  flattenParameters = map IndependentTensor . getParameters
  _replaceParameters module' = do
    fresh <- replicateM (length (getParameters module')) nextParameter
    pure (updateParameters WithRequiredGrad module' (map toDependent fresh))
-- | Trace a tensor function into a new 'RawModule' via 'LibTorch.trace'.
--
-- The Haskell function is wrapped so it can be called from C++: it receives
-- the foreign tensor list, converts it to @[Tensor]@, runs the function,
-- and converts the result back.
trace ::
  -- | module name
  String ->
  -- | function name
  String ->
  -- | function to trace
  ([Tensor] -> IO [Tensor]) ->
  -- | example inputs
  [Tensor] ->
  -- | traced module
  IO RawModule
trace moduleName functionName func inputs =
  cast3 traceRaw moduleName functionName inputs
  where
    traceRaw m f inps = LibTorch.trace m f (trans func) inps
    -- Adapts a Haskell tensor function to operate on foreign tensor lists.
    trans :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    trans g rawInputs =
      uncast rawInputs $ \tensors -> do
        outputs <- g tensors
        cast outputs return
-- | Trace a parameterized model into a 'RawModule' whose @forward@ method
-- reads the model's parameters from module attributes.
--
-- 1. The model's parameters are flattened with 'flattenParameters' and
--    traced, followed by the example @inputs@, as method
--    @forwardWithParameters@. The traced function splits its argument list
--    back into parameters and inputs, rebuilds the model with
--    'replaceParameters', and calls @func@.
-- 2. Each flattened parameter is registered on the module as @p0@, @p1@, …
--    via 'registerParameter'.
-- 3. A TorchScript @forward(self, i0, i1, …)@ method is defined that calls
--    @self.forwardWithParameters(self.p0, …, i0, …)@.
traceWithParameters ::
  Parameterized f =>
  String ->
  (f -> [Tensor] -> IO [Tensor]) ->
  f ->
  [Tensor] ->
  IO RawModule
traceWithParameters moduleName func parameterized_parameters inputs = do
  let parameters = map toDependent (flattenParameters parameterized_parameters)
      fromParams params =
        replaceParameters parameterized_parameters (map IndependentTensor params)
      plen = length parameters
      ilen = length inputs
  r <-
    trace
      moduleName
      "forwardWithParameters"
      ( \parametersAndInputs ->
          -- Parameters come first, inputs after them (see the last argument).
          let (ps, xs) = splitAt plen parametersAndInputs
           in func (fromParams ps) xs
      )
      (parameters ++ inputs)
  forM_ (zip [0 :: Int ..] parameters) $ \(i, p) ->
    registerParameter r ("p" ++ show i) p False
  let argNames = map (\i -> "i" ++ show i) [0 .. ilen - 1]
      paramNames = map (\i -> "self.p" ++ show i) [0 .. plen - 1]
      -- Join each argument list from the combined names so that an empty
      -- parameter or input list never leaves a stray comma behind.
      -- Previously, a parameterless model produced
      -- "forwardWithParameters(, i0 )", which is not valid TorchScript.
      formals = intercalate ", " ("self" : argNames)
      actuals = intercalate ", " (paramNames ++ argNames)
  define r $
    "def forward(" ++ formals ++ "):\n"
      ++ " return self.forwardWithParameters("
      ++ actuals
      ++ " )\n"
  return r
-- | Trace a tensor function and return the result of
-- 'LibTorch.traceAsGraph' as a 'Graph'.
--
-- @func@ is wrapped so that it consumes and produces LibTorch tensor lists.
-- @inputs@ are the example tensors passed to the tracer.
traceAsGraph ::
  ([Tensor] -> IO [Tensor]) ->
  [Tensor] ->
  IO Graph
traceAsGraph func inputs =
  cast1 (LibTorch.traceAsGraph (marshal func)) inputs
  where
    -- Adapt a Haskell tensor function to LibTorch tensor lists: decode the
    -- incoming list, run the function, and encode its result.
    marshal :: ([Tensor] -> IO [Tensor]) -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
    marshal g rawIn =
      uncast rawIn $ \hsIn -> do
        hsOut <- g hsIn
        cast hsOut return
-- | Render a 'Graph' to a 'String' using 'LibTorch.printGraph'.
printGraph :: Graph -> IO String
printGraph graph = cast1 LibTorch.printGraph graph
-- | Render a 'Graph' to a 'String' using 'LibTorch.printOnnx'.
-- (Judging by the name this is an ONNX-oriented dump; confirm against the
-- LibTorch binding.)
printOnnx :: Graph -> IO String
printOnnx graph = cast1 LibTorch.printOnnx graph
-- | Copy the structure of a 'Graph' into a plain Haskell 'JitGraph' value.
--
-- The shared pointer is kept alive with 'withForeignPtr' and unwrapped with
-- 'Unmanaged.withJitGraph'. Then the graph's inputs, outputs and nodes are
-- read, in that order.
--
-- * Each value is reduced to its id and its type string.
-- * Each node is reduced to its input values, its output values and its
--   kind string.
graphToJitGraph :: Graph -> IO JitGraph
graphToJitGraph (UnsafeGraph graph) =
  withForeignPtr graph $ \sharedGraph ->
    Unmanaged.withJitGraph sharedGraph $ \g -> do
      graphInputs <- Unmanaged.graphInputs g >>= readValues
      graphOutputs <- Unmanaged.graphOutputs g >>= readValues
      graphNodes <- Unmanaged.graphNodes g >>= readNodes
      pure JitGraph {..}
  where
    -- Convert every raw value handle in a container.
    readValues rawValues = forM rawValues readValue

    -- Read one value's id and type name (the type comes back as a StdString).
    readValue rawValue = do
      valueId <- cast1 Unmanaged.valueId rawValue
      valueType <- cast0 (cast1 Unmanaged.valueType rawValue :: IO (ForeignPtr ATen.StdString))
      pure JitValue {..}

    -- Read each node's inputs, outputs and kind name.
    readNodes rawNodes = forM rawNodes $ \rawNode -> do
      nodeInputs <- Unmanaged.nodeInputs rawNode >>= readValues
      nodeOutputs <- Unmanaged.nodeOutputs rawNode >>= readValues
      nodeKind <- cast0 (cast1 Unmanaged.nodeKind rawNode :: IO (ForeignPtr ATen.StdString))
      pure JitNode {..}
-- | Element-wise conversion between lists of Haskell 'IValue's and lists of
-- raw LibTorch IValues.
--
-- Each element goes through the single-value @Castable IValue RawIValue@
-- instance, left to right, and the whole converted list is then passed to
-- the continuation.
instance Castable [IValue] [RawIValue] where
  cast values k = do
    raws <- traverse (\value -> cast value pure) values
    k raws
  uncast raws k = do
    values <- traverse (\raw -> uncast raw pure) raws
    k values
-- | Conversion between the Haskell 'IValue' sum type and raw LibTorch
-- IValues.
--
-- * 'cast' handles the following constructors: 'IVNone', scalars, strings,
--   tensors, tuples, typed lists, generic lists and generic dicts. Nested
--   values go through the list instance. Any other constructor throws an
--   'IOError' whose message is @"Unsupported data-type:"@ followed by the
--   value's 'show'.
--
-- * 'uncast' runs the @iValue_is*@ predicates in a fixed order and decodes
--   with the first one that returns 1. Later predicates are not evaluated.
--   Blob, future, device, object and capsule values come back as bare
--   constructors; their payload is not read. If no predicate matches, an
--   'IOError' ("Unsupported IValue") is thrown.
instance Castable IValue RawIValue where
  cast IVNone k = k =<< newIValue
  cast (IVTensor (Unsafe tensorPtr)) k = k =<< toIValue tensorPtr
  cast (IVDouble d) k = k =<< toIValue d
  cast (IVInt n) k = k =<< toIValue n
  cast (IVBool b) k = k =<< toIValue b
  cast (IVTuple elems) k = do
    raws <- cast elems pure :: IO [RawIValue]
    tuplePtr <- cast raws pure :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
    k =<< toIValue tuplePtr
  cast (IVIntList ints) k = do
    listPtr <- cast ints pure :: IO (ForeignPtr (ATen.C10List Int64))
    k =<< toIValue listPtr
  cast (IVDoubleList doubles) k = do
    -- Convert Double -> CDouble element-wise before building the C10 list.
    cdoubles <- traverse (\d -> cast d pure) doubles :: IO [CDouble]
    listPtr <- cast cdoubles pure :: IO (ForeignPtr (ATen.C10List CDouble))
    k =<< toIValue listPtr
  cast (IVBoolList bools) k = do
    -- Convert Bool -> CBool element-wise before building the C10 list.
    cbools <- traverse (\flag -> cast flag pure) bools :: IO [CBool]
    listPtr <- cast cbools pure :: IO (ForeignPtr (ATen.C10List CBool))
    k =<< toIValue listPtr
  cast (IVString str) k = do
    strPtr <- cast str pure :: IO (ForeignPtr ATen.StdString)
    k =<< toIValue strPtr
  cast (IVTensorList tensors) k = do
    listPtr <- cast tensors pure :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
    k =<< toIValue listPtr
  cast (IVGenericList elems) k = do
    raws <- cast elems pure :: IO [RawIValue]
    listPtr <- cast raws pure :: IO (ForeignPtr (ATen.C10List ATen.IValue))
    k =<< toIValue listPtr
  cast (IVGenericDict entries) k = do
    -- Keys and values are converted separately, then zipped back into pairs.
    rawKeys <- cast (map fst entries) pure :: IO [RawIValue]
    rawValues <- cast (map snd entries) pure :: IO [RawIValue]
    dictPtr <-
      cast (zip rawKeys rawValues) pure ::
        IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
    k =<< toIValue dictPtr
  cast unsupported _ =
    throwIO $ userError $ "Unsupported data-type:" ++ show unsupported

  uncast obj k = firstMatching decoders
    where
      -- Walk the (predicate, decoder) pairs in order. The first predicate
      -- that yields 1 selects its decoder, and the remaining predicates are
      -- never run.
      firstMatching alternatives = foldr tryOne noMatch alternatives
      tryOne (predicate, decoder) fallback = do
        flag <- predicate
        if flag == 1 then decoder else fallback
      noMatch = throwIO $ userError "Unsupported IValue"

      -- Probe order matters. The specific list kinds are tested before the
      -- generic 'iValue_isList' check.
      decoders =
        [ (iValue_isNone obj, k IVNone),
          (iValue_isTensor obj, fromIValue obj >>= k . IVTensor . Unsafe),
          (iValue_isDouble obj, fromIValue obj >>= k . IVDouble),
          (iValue_isInt obj, fromIValue obj >>= k . IVInt),
          (iValue_isBool obj, fromIValue obj >>= k . IVBool),
          (iValue_isString obj, decodeString),
          (iValue_isTensorList obj, decodeTensorList),
          (iValue_isDoubleList obj, decodeDoubleList),
          (iValue_isIntList obj, decodeIntList),
          (iValue_isBoolList obj, decodeBoolList),
          (iValue_isTuple obj, decodeTuple),
          (iValue_isList obj, decodeGenericList),
          (iValue_isGenericDict obj, decodeGenericDict),
          (iValue_isBlob obj, k IVBlob),
          (iValue_isFuture obj, k IVFuture),
          (iValue_isDevice obj, k IVDevice),
          (iValue_isObject obj, k IVObject),
          (iValue_isCapsule obj, k IVCapsule)
        ]

      decodeString = do
        strPtr <- fromIValue obj :: IO (ForeignPtr ATen.StdString)
        str <- uncast strPtr pure :: IO String
        k (IVString str)

      decodeTensorList = do
        listPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.Tensor))
        tensors <- uncast listPtr pure :: IO [Tensor]
        k (IVTensorList tensors)

      decodeDoubleList = do
        listPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CDouble))
        cdoubles <- uncast listPtr pure :: IO [CDouble]
        doubles <- traverse (\c -> uncast c pure) cdoubles :: IO [Double]
        k (IVDoubleList doubles)

      decodeIntList = do
        listPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10List Int64))
        ints <- uncast listPtr pure :: IO [Int64]
        k (IVIntList ints)

      decodeBoolList = do
        listPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10List CBool))
        cbools <- uncast listPtr pure :: IO [CBool]
        bools <- traverse (\c -> uncast c pure) cbools :: IO [Bool]
        k (IVBoolList bools)

      decodeTuple = do
        tuplePtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10Ptr ATen.IVTuple))
        raws <- uncast tuplePtr pure :: IO [RawIValue]
        elems <- uncast raws pure :: IO [IValue]
        k (IVTuple elems)

      decodeGenericList = do
        listPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10List ATen.IValue))
        raws <- uncast listPtr pure :: IO [RawIValue]
        elems <- uncast raws pure :: IO [IValue]
        k (IVGenericList elems)

      decodeGenericDict = do
        dictPtr <- fromIValue obj :: IO (ForeignPtr (ATen.C10Dict '(ATen.IValue, ATen.IValue)))
        rawPairs <- uncast dictPtr pure :: IO [(RawIValue, RawIValue)]
        -- Each key is decoded before its value, pair by pair.
        entries <- forM rawPairs $ \(rawKey, rawValue) ->
          (,) <$> uncast rawKey pure <*> uncast rawValue pure
        k (IVGenericDict entries)