------------------------------------------------------------------------------- -- | -- Module : Torch.Indef.Dynamic.Types -- Copyright : (c) Sam Stites 2017 -- License : BSD3 -- Maintainer: sam@stites.io -- Stability : experimental -- Portability: non-portable ------------------------------------------------------------------------------- {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE Strict #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} module Torch.Indef.Types ( module Sig , DimVal(..) , Step(..), Stride(..), StorageOffset(..), Size(..), KeepDim(..), fromKeepDim, keep, ignore, SortOrder(..), TopKOrder(..) , AllocatorContext(..) -- , StorageSize(..), Index(..) , (.:) -- manage arguments to c functions , managedState , managedStorage , managedTensor , managedGen -- lift managed IO actions , withLift , withDynamic , withStorage -- monadic patterns to be replaced with applicative-style , with2DynamicState , with3DynamicState -- helper functions for monadic construction , mkDynamic , mkStorage ) where import Foreign hiding (with) import Foreign.C.Types import Foreign.Ptr import GHC.Int (Int64(..), Int32(..)) import Control.Monad.Managed import Numeric.Dimensions import qualified Foreign.Marshal.Array as FM import Control.Arrow import Control.Monad import Control.Monad.IO.Class import Control.Monad.Reader.Class import Torch.Types.TH (C'THState) import GHC.TypeLits import Torch.Sig.State as Sig import Torch.Sig.Types as Sig import Torch.Sig.Types.Global as Sig import qualified Numeric.Dimensions (KnownDim) import qualified Torch.Types.TH as TH import qualified Torch.FFI.TH.Long.Storage as TH import qualified Torch.Sig.Tensor.Memory as SigTen import qualified Torch.Sig.Storage.Memory as SigStore ------------------------------------------------------------------------------- -- | From -- https://github.com/pytorch/pytorch/blob/c61f0217a536d19c9ff3290067ddcbb9ce3a5c6a/aten/src/THNN/Reduction.h -- -- NB: Keep this in sync with Reduction class in torch/nn/modules/functional.py -- These constants control the reduction behavior of loss functions. -- Ideally, this would be a scoped enum, but jit doesn't support that data Reduction = NoReduce -- ^ Do not reduce | ElementwiseMean -- ^ Sum losses and take mean over each individually computed loss element | Sum -- ^ Sum losses ------------------------------------------------------------------------------- -- helpers for dimensions: -- | term-level representation of an index. newtype DimVal = DimVal Int32 deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | term-level representation of an index. newtype Index = Index Int64 deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) {-# DEPRECATED DimVal, Index "Use dimensions package's Idx instead" #-} {- transferDims :: Proxy (ds::[Nat]) -> Dim ds transferDims p = undefined where go :: forall f m . Proxy (m::[Nat]) -> Dim (f :: [Nat]) go _ = if null (fromSing (sing :: Sing m)) then (D :: Dim f) else (Dn :: (x:xs) ~ m => Dim (x::Nat)) :* (go (Proxy :: (x:xs) ~ m => Proxy xs)) -- -} -- Helper function to debug dimensions package. We return @Integral i@ in case we need to cast directly to C-level types. ------------------------------------------------------------------------------- -- | newtype wrapper around the C-level representation of a tensor's internal -- 'Storage' stride for each dimension. newtype Stride = Stride CLLong deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | newtype wrapper around the C-level representation of a dimension's size newtype Size = Size CLLong deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | newtype wrapper around the C-level representation of a storage offset newtype StorageOffset = Offset CPtrdiff deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | Represents the size of storage, should be CPtrdiff to match with the C internals newtype StorageSize = StorageSize CPtrdiff deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | newtype wrapper around the C-level representation of a step size newtype Step = Step CLong deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- | haskell representation of a CInt which determines whether or not to return dimensions newtype KeepDim = KeepDim { keepIt :: Bool } deriving (Bounded, Enum, Eq, Ord, Read, Show) -- | cast a 'KeepDim' to a numerical representation. -- -- NOTE: don't bind the @i@ in case there are some differences between THC and TH fromKeepDim :: Integral i => Maybe KeepDim -> i fromKeepDim = maybe 0 (fromIntegral . fromEnum) -- | smart constructors for keepdim since we don't get inference for free like Num keep, ignore :: KeepDim (keep, ignore) = (KeepDim True, KeepDim False) -- | Simple datatype to represent sort order arguments which torch provides to us. data SortOrder = Ascending | Descending deriving (Eq, Show, Ord, Enum, Bounded) -- | Simple datatype to represent arguments for a topk function. -- -- See https://github.com/torch/torch7/blob/75a86469aa9e2f5f04e11895b269ec22eb0e4687/lib/TH/generic/THTensorMath.c#L2545 data TopKOrder = KAscending | KNone | KDescending deriving (Eq, Show, Ord, Enum, Bounded) -- | this is supposed to represent the AllocatorContext, but it should not be exposed to a user. newtype AllocatorContext = AllocatorContext (Ptr ()) {-# WARNING AllocatorContext "this should not be used or referenced -- we are still figuring out what to do with this." #-} ------------------------------------------------------------------------------- ptrArray2hs :: (Ptr a -> IO (Ptr CReal)) -> (Ptr a -> IO Int) -> ForeignPtr a -> IO [HsReal] ptrArray2hs updPtrArray toSize fp = do sz <- withForeignPtr fp toSize creals <- withForeignPtr fp updPtrArray (fmap.fmap) c2hsReal (FM.peekArray sz creals) -- | The blackbird combinator. -- -- (stites): This happens often enough that I'm pulling in the blackbird -- -- FIXME(stites): remove this (.:) :: (b -> c) -> (a0 -> a1 -> b) -> a0 -> a1 -> c (.:) = (.) . (.) infixl 5 .: -- | run a function against the internal reference of a torch generator. withGen :: Sig.Generator -> (Ptr CGenerator -> IO x) -> IO x withGen g fn = withForeignPtr (Sig.rng g) fn -- | run a function with a managed state's raw internal pointer. managedState :: Managed (Ptr Sig.CState) managedState = managed (withForeignPtr Sig.torchstate) managedTensor :: Sig.Dynamic -> Managed (Ptr Sig.CTensor) managedTensor t = managed (withForeignPtr (Sig.ctensor t)) managedStorage :: Sig.Storage -> Managed (Ptr Sig.CStorage) managedStorage t = managed (withForeignPtr (Sig.cstorage t)) managedGen :: Sig.Generator -> Managed (Ptr CGenerator) managedGen g = managed (withForeignPtr (Sig.rng g)) -- | run a function with two tensors with reference to the first tensor's underlying state. with2DynamicState :: Sig.Dynamic -> Sig.Dynamic -> (Ptr Sig.CState -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> IO x) -> IO x with2DynamicState t0 t1 fn = withLift $ fn <$> managedState <*> managedTensor t0 <*> managedTensor t1 -- | run a function with three tensors with reference to the first tensor's underlying state. with3DynamicState :: Sig.Dynamic -> Sig.Dynamic -> Sig.Dynamic -> (Ptr Sig.CState -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> Ptr Sig.CTensor -> IO x) -> IO x with3DynamicState t0 t1 t2 fn = withLift $ fn <$> managedState <*> managedTensor t0 <*> managedTensor t1 <*> managedTensor t2 withLift :: Managed (IO x) -> IO x withLift = flip with pure . (liftIO =<<) -- | smart constructor for a Managed 'Sig.Dynamic' tensor withDynamic :: Managed (IO (Ptr Sig.CTensor)) -> IO Sig.Dynamic withDynamic = flip with mkDynamic . (liftIO =<<) -- | smart constructor for a Managed 'Sig.Storage' tensor withStorage :: Managed (IO (Ptr Sig.CStorage)) -> IO Sig.Storage withStorage = flip with mkStorage . (liftIO =<<) -- | smart constructor for a 'Sig.Dynamic' tensor mkDynamic :: Ptr Sig.CTensor -> IO Sig.Dynamic mkDynamic t = with managedState $ \s -> Sig.dynamic Sig.torchstate <$> newForeignPtrEnv SigTen.p_free s t -- | run a function with access to a 'Sig.Storage's underlying state and C-reference. withStorageState :: Sig.Storage -> (Ptr Sig.CState -> Ptr Sig.CStorage -> IO x) -> IO x withStorageState t fn = flip with pure . (liftIO =<<) $ fn <$> managedState <*> managedStorage t -- | smart constructor for 'Sig.Storage' mkStorage :: Ptr Sig.CStorage -> IO Sig.Storage mkStorage t = with managedState $ \s -> Sig.storage Sig.torchstate <$> newForeignPtrEnv SigStore.p_free s t