{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Torch.Indef.Types
( module Sig
, DimVal(..)
, Step(..), Stride(..), StorageOffset(..), Size(..), KeepDim(..), fromKeepDim, keep, ignore, SortOrder(..), TopKOrder(..)
, AllocatorContext(..)
, (.:)
, managedState
, managedStorage
, managedTensor
, managedGen
, withLift
, withDynamic
, withStorage
, with2DynamicState
, with3DynamicState
, mkDynamic
, mkStorage
) where
import Foreign hiding (with)
import Foreign.C.Types
import Foreign.Ptr
import GHC.Int (Int64(..), Int32(..))
import Control.Monad.Managed
import Numeric.Dimensions
import qualified Foreign.Marshal.Array as FM
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Reader.Class
import Torch.Types.TH (C'THState)
import GHC.TypeLits
import Torch.Sig.State as Sig
import Torch.Sig.Types as Sig
import Torch.Sig.Types.Global as Sig
import qualified Numeric.Dimensions (KnownDim)
import qualified Torch.Types.TH as TH
import qualified Torch.FFI.TH.Long.Storage as TH
import qualified Torch.Sig.Tensor.Memory as SigTen
import qualified Torch.Sig.Storage.Memory as SigStore
-- | Reduction mode with three alternatives: 'NoReduce', 'ElementwiseMean'
-- and 'Sum'.
--
-- NOTE(review): this type is not in the module's export list; confirm it is
-- still used somewhere in this module.
data Reduction = NoReduce | ElementwiseMean | Sum
-- | Dimension value backed by a 32-bit signed integer ('Int32').
-- Deprecated (see pragma below).
newtype DimVal = DimVal Int32
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)

-- | Index value backed by a 64-bit signed integer ('Int64').
-- Deprecated (see pragma below).
newtype Index = Index Int64
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)

{-# DEPRECATED DimVal, Index "Use dimensions package's Idx instead" #-}
-- | Stride value, wrapping a C @long long@ ('CLLong').
newtype Stride = Stride CLLong
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)
-- | Size value, wrapping a C @long long@ ('CLLong').
newtype Size = Size CLLong
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)
-- | Storage offset, wrapping a C @ptrdiff_t@ ('CPtrdiff'). Note that the
-- data constructor is named 'Offset', not @StorageOffset@.
newtype StorageOffset = Offset CPtrdiff
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)
-- | Storage size, wrapping a C @ptrdiff_t@ ('CPtrdiff').
--
-- NOTE(review): not in the module's export list; confirm it is used.
newtype StorageSize = StorageSize CPtrdiff
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)
-- | Step value, wrapping a C @long@ ('CLong').
newtype Step = Step CLong
  deriving (Eq, Ord, Show, Read, Num, Real, Enum, Integral, Bounded)
-- | Boolean flag wrapper; 'keepIt' unwraps the underlying 'Bool'.
-- Judging by the name, it tells an operation whether to keep a dimension
-- (e.g. after a reduction) — confirm at call sites.
newtype KeepDim = KeepDim { keepIt :: Bool }
  deriving (Eq, Ord, Show, Read, Enum, Bounded)
-- | Convert an optional 'KeepDim' flag to an integral value:
-- @1@ for @Just (KeepDim True)@, and @0@ for @Just (KeepDim False)@ or
-- 'Nothing' (an absent flag is treated as "do not keep").
fromKeepDim :: Integral i => Maybe KeepDim -> i
fromKeepDim mflag = case mflag of
  Just (KeepDim True) -> 1
  _                   -> 0
-- | Shorthands for the two 'KeepDim' values: 'keep' wraps 'True',
-- 'ignore' wraps 'False'.
keep, ignore :: KeepDim
keep = KeepDim True
ignore = KeepDim False
-- | Sort direction. The derived 'Enum' order is 'Ascending' (0) then
-- 'Descending' (1).
data SortOrder
  = Ascending
  | Descending
  deriving (Bounded, Enum, Eq, Ord, Show)
-- | Three-way ordering option. The derived 'Enum' order is 'KAscending' (0),
-- 'KNone' (1), 'KDescending' (2). Judging by the name, it selects the
-- ordering for a top-k operation — confirm at call sites.
data TopKOrder
  = KAscending
  | KNone
  | KDescending
  deriving (Bounded, Enum, Eq, Ord, Show)
-- | Opaque wrapper around an untyped pointer. It derives no instances.
-- See the WARNING pragma: its role is still undecided.
newtype AllocatorContext
  = AllocatorContext (Ptr ())
{-# WARNING AllocatorContext "this should not be used or referenced -- we are still figuring out what to do with this." #-}
-- | Read a C array held by the object behind a 'ForeignPtr' and convert
-- every element with 'c2hsReal'.
--
-- * @updPtrArray@ yields the pointer to the first 'CReal' element.
-- * @toSize@ yields the number of elements to read.
--
-- Both accessors and the 'FM.peekArray' run inside a single
-- 'withForeignPtr'. The data pointer must not be dereferenced after
-- 'withForeignPtr' returns: after that the foreign object is no longer kept
-- alive, and its finalizer could free the memory while it is being read.
ptrArray2hs :: (Ptr a -> IO (Ptr CReal)) -> (Ptr a -> IO Int) -> ForeignPtr a -> IO [HsReal]
ptrArray2hs updPtrArray toSize fp = withForeignPtr fp $ \p -> do
  sz <- toSize p
  creals <- updPtrArray p
  fmap c2hsReal <$> FM.peekArray sz creals
-- | Post-compose a unary function onto a binary one:
-- @(f .: g) x y == f (g x y)@.
--
-- For the function functor 'fmap' is '(.)', so @fmap . fmap@ is the same
-- as @(.) . (.)@. The definition is kept point-free on purpose: under the
-- module's @Strict@ extension, a pointful definition would force @x@ and
-- @y@.
(.:) :: (b -> c) -> (a0 -> a1 -> b) -> a0 -> a1 -> c
(.:) = fmap . fmap
infixl 5 .:
-- | Run an action with the raw pointer held by a generator's 'Sig.rng'
-- foreign pointer. The pointer is valid only for the duration of the
-- action.
withGen :: Sig.Generator -> (Ptr CGenerator -> IO x) -> IO x
withGen gen action = withForeignPtr rngFP action
  where
    rngFP = Sig.rng gen
-- | The raw pointer behind the global 'Sig.torchstate', acquired as a
-- 'Managed' resource. It stays valid until the enclosing 'with' scope ends.
managedState :: Managed (Ptr Sig.CState)
managedState = managed (\k -> withForeignPtr Sig.torchstate k)
-- | The raw tensor pointer held by a dynamic tensor's 'Sig.ctensor',
-- acquired as a 'Managed' resource. It stays valid until the enclosing
-- 'with' scope ends.
managedTensor :: Sig.Dynamic -> Managed (Ptr Sig.CTensor)
managedTensor tensor = managed (\k -> withForeignPtr (Sig.ctensor tensor) k)
-- | The raw storage pointer held by 'Sig.cstorage', acquired as a 'Managed'
-- resource. It stays valid until the enclosing 'with' scope ends.
managedStorage :: Sig.Storage -> Managed (Ptr Sig.CStorage)
managedStorage store = managed (\k -> withForeignPtr (Sig.cstorage store) k)
-- | The raw pointer behind a generator's 'Sig.rng', acquired as a 'Managed'
-- resource. It stays valid until the enclosing 'with' scope ends.
--
-- Delegates to 'withGen' instead of repeating its 'withForeignPtr' call,
-- so the two generator accessors cannot drift apart.
managedGen :: Sig.Generator -> Managed (Ptr CGenerator)
managedGen g = managed (withGen g)
-- | Run a continuation with raw pointers to the global torch state and to
-- two dynamic tensors (@t0@, @t1@, passed in that order). All three
-- pointers remain valid until the continuation's 'IO' action has finished.
with2DynamicState
  :: Sig.Dynamic
  -> Sig.Dynamic
  -> (Ptr Sig.CState -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> IO x)
  -> IO x
with2DynamicState t0 t1 fn = withLift $ do
  statePtr <- managedState
  ptr0 <- managedTensor t0
  ptr1 <- managedTensor t1
  pure (fn statePtr ptr0 ptr1)
-- | Run a continuation with raw pointers to the global torch state and to
-- three dynamic tensors (@t0@, @t1@, @t2@, passed in that order). All four
-- pointers remain valid until the continuation's 'IO' action has finished.
with3DynamicState
  :: Sig.Dynamic
  -> Sig.Dynamic
  -> Sig.Dynamic
  -> (Ptr Sig.CState -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> IO x)
  -> IO x
with3DynamicState t0 t1 t2 fn = withLift $ do
  statePtr <- managedState
  ptr0 <- managedTensor t0
  ptr1 <- managedTensor t1
  ptr2 <- managedTensor t2
  pure (fn statePtr ptr0 ptr1 ptr2)
-- | Run a 'Managed' computation that produces an 'IO' action. The action is
-- executed inside the managed scope, so every acquired resource is still
-- held while it runs. Its result is returned once the scope closes.
withLift :: Managed (IO x) -> IO x
withLift managedAction = with managedAction id
-- | Run a 'Managed' computation that produces an 'IO' action returning a
-- raw tensor pointer. The action runs, and its pointer is wrapped with
-- 'mkDynamic', both while the managed resources are still held.
withDynamic :: Managed (IO (Ptr Sig.CTensor)) -> IO Sig.Dynamic
withDynamic managedAction = with managedAction (>>= mkDynamic)
-- | Run a 'Managed' computation that produces an 'IO' action returning a
-- raw storage pointer. The action runs, and its pointer is wrapped with
-- 'mkStorage', both while the managed resources are still held.
withStorage :: Managed (IO (Ptr Sig.CStorage)) -> IO Sig.Storage
withStorage managedAction = with managedAction (>>= mkStorage)
-- | Take ownership of a raw tensor pointer. It is wrapped in a 'ForeignPtr'
-- whose finalizer is 'SigTen.p_free', called with the torch state pointer
-- as its environment. The result is paired with 'Sig.torchstate' to build
-- a 'Sig.Dynamic'.
--
-- NOTE(review): the state pointer obtained inside 'with' is stored in the
-- finalizer environment and so outlives that scope. This is only safe if
-- 'Sig.torchstate' outlives every tensor created here — confirm.
mkDynamic :: Ptr Sig.CTensor -> IO Sig.Dynamic
mkDynamic t = with managedState $ \statePtr -> do
  tensorFP <- newForeignPtrEnv SigTen.p_free statePtr t
  pure (Sig.dynamic Sig.torchstate tensorFP)
-- | Run a continuation with raw pointers to the global torch state and to
-- the given storage. Both pointers remain valid until the continuation's
-- 'IO' action has finished.
--
-- Uses 'withLift' like 'with2DynamicState' and 'with3DynamicState', rather
-- than re-inlining its definition (@flip with pure . (liftIO =<<)@).
withStorageState :: Sig.Storage -> (Ptr Sig.CState -> Ptr Sig.CStorage -> IO x) -> IO x
withStorageState t fn = withLift $ fn
  <$> managedState
  <*> managedStorage t
-- | Take ownership of a raw storage pointer. It is wrapped in a 'ForeignPtr'
-- whose finalizer is 'SigStore.p_free', called with the torch state pointer
-- as its environment. The result is paired with 'Sig.torchstate' to build
-- a 'Sig.Storage'.
--
-- NOTE(review): the state pointer obtained inside 'with' is stored in the
-- finalizer environment and so outlives that scope. This is only safe if
-- 'Sig.torchstate' outlives every storage created here — confirm.
mkStorage :: Ptr Sig.CStorage -> IO Sig.Storage
mkStorage t = with managedState $ \statePtr -> do
  storageFP <- newForeignPtrEnv SigStore.p_free statePtr t
  pure (Sig.storage Sig.torchstate storageFP)