{-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# OPTIONS_GHC -fno-cse #-} module Torch.Types.TH ( module Torch.Types.TH.Structs , C'THState, C'THNNState, CState, State(..), asState, torchstate , CAllocator, Allocator , CGenerator, Generator(..), generatorToRng, Seed(..) , CDescBuff, DescBuff, descBuff -- for nn-packages , CNNState , CDim , CNNGenerator , CInt' , CMaskTensor, CIndexTensor, CIndexStorage, C'THIndexTensor, C'THIntegerTensor , MaskDynamic, IndexDynamic, MaskTensor, IndexTensor, IndexStorage -- * Unsigned types , CByteTensor, ByteDynamic(..), byteDynamic, ByteTensor(..), byteAsStatic , CByteStorage, ByteStorage(..), byteStorage , CCharTensor, CharDynamic(..), charDynamic, CharTensor(..), charAsStatic , CCharStorage, CharStorage(..), charStorage -- * Signed types , CLongTensor, LongDynamic(..), longDynamic, LongTensor(..), longAsStatic , CLongStorage, LongStorage(..), longStorage , CShortTensor, ShortDynamic(..), shortDynamic, ShortTensor(..), shortAsStatic , CShortStorage, ShortStorage(..), shortStorage , CIntTensor, IntDynamic(..), intDynamic, IntTensor(..), intAsStatic , CIntStorage, IntStorage(..), intStorage -- * Floating types , CFloatTensor, FloatDynamic(..), floatDynamic, FloatTensor(..), floatAsStatic , CFloatStorage, FloatStorage(..), floatStorage , CDoubleTensor, DoubleDynamic(..), doubleDynamic, DoubleTensor(..), doubleAsStatic , CDoubleStorage, DoubleStorage(..), doubleStorage , C'THHalfTensor, C'THHalfStorage, C'THFile, C'THHalf ) where import Foreign import Foreign.C.Types import GHC.TypeLits import Data.Char (chr) import System.IO.Unsafe (unsafePerformIO) import Torch.Types.TH.Structs type CDescBuff = C'THDescBuff type DescBuff = String descBuff :: Ptr CDescBuff -> IO DescBuff descBuff p = (map (chr . fromIntegral) . c'THDescBuff'str) <$> peek p foreign import ccall "&free_CTHState" state_free :: FunPtr (Ptr C'THState -> IO ()) -- | 'torchstate' is just a foreign pointer wrapping around a null pointer with a noop -- finalizer. This is to keep the API unified with THC. torchstate :: ForeignPtr C'THState torchstate = unsafePerformIO $ newForeignPtr state_free nullPtr {-# NOINLINE torchstate #-} type C'THState = () type C'THNNState = C'THState type CState = C'THState newtype State = State { asForeign :: ForeignPtr C'THState } deriving (Eq, Show) asState = State type CAllocator = C'THAllocator newtype Allocator = Allocator { callocator :: ForeignPtr CAllocator } deriving (Eq, Show) type CGenerator = C'THGenerator -- ^ Backpack type alias for TH's CPU generator newtype Generator = Generator { rng :: ForeignPtr CGenerator } deriving (Eq, Show) -- ^ Representation of a CPU-bound random number generator generatorToRng :: ForeignPtr CGenerator -> Generator generatorToRng = Generator -- | Representation of a CPU-bound random seed newtype Seed = Seed { unSeed :: Word64 } deriving (Bounded, Enum, Eq, Integral, Num, Ord, Read, Real, Show) -- for nn-package type CNNState = CState type CDim = CLLong type CNNGenerator = CGenerator -- data CDoubleTensor type CInt' = CInt type Int' = Int -- Some type alias' type C'THIndexTensor = CLongTensor type C'THIntegerTensor = CLongTensor -- TH-specific for THNN. type CMaskTensor = CByteTensor type CIndexTensor = CLongTensor type CIndexStorage = CLongStorage type MaskDynamic = ByteDynamic type MaskTensor = ByteTensor type IndexDynamic = LongDynamic type IndexTensor = LongTensor type IndexStorage = LongStorage -- | A C-level representation of the Byte Tensor type. These need to be wrapped in 'Ptr' type CByteTensor = C'THByteTensor -- | A memory-managed representation of TH's Byte Tensor type. These carry a reference to the 'CState' newtype ByteDynamic = ByteDynamic { byteDynamicState :: (ForeignPtr CState, ForeignPtr CByteTensor) } deriving (Eq) -- | smart constructor for 'ByteDynamic'. byteDynamic = curry ByteDynamic -- | A newtype wrapper around 'ByteDynamic' which imbues a 'ByteDynamic' with static tensor dimensions. newtype ByteTensor (ds :: [Nat]) = ByteTensor { byteAsDynamic :: ByteDynamic } deriving (Eq) -- | smart constructor for 'ByteTensor'. byteAsStatic = ByteTensor type CByteStorage = C'THByteStorage newtype ByteStorage = ByteStorage { byteStorageState :: (ForeignPtr CState, ForeignPtr CByteStorage) } deriving (Eq) byteStorage = curry ByteStorage type CCharTensor = C'THCharTensor newtype CharDynamic = CharDynamic { charDynamicState :: (ForeignPtr CState, ForeignPtr CCharTensor) } deriving (Eq) charDynamic = curry CharDynamic newtype CharTensor (ds :: [Nat]) = CharTensor { charAsDynamic :: CharDynamic } deriving (Eq) charAsStatic = CharTensor type CCharStorage = C'THCharStorage newtype CharStorage = CharStorage { charStorageState :: (ForeignPtr CState, ForeignPtr CCharStorage) } deriving (Eq) charStorage = curry CharStorage -- Signed types type CLongTensor = C'THLongTensor newtype LongDynamic = LongDynamic { longDynamicState :: (ForeignPtr CState, ForeignPtr CLongTensor) } deriving (Eq) longDynamic = curry LongDynamic newtype LongTensor (ds :: [Nat]) = LongTensor { longAsDynamic :: LongDynamic } deriving (Eq) longAsStatic = LongTensor type CLongStorage = C'THLongStorage newtype LongStorage = LongStorage { longStorageState :: (ForeignPtr CState, ForeignPtr CLongStorage) } deriving (Eq) longStorage = curry LongStorage type CShortTensor = C'THShortTensor newtype ShortDynamic = ShortDynamic { shortDynamicState :: (ForeignPtr CState, ForeignPtr CShortTensor) } deriving (Eq) shortDynamic = curry ShortDynamic newtype ShortTensor (ds :: [Nat]) = ShortTensor { shortAsDynamic :: ShortDynamic } deriving (Eq) shortAsStatic = ShortTensor type CShortStorage = C'THShortStorage newtype ShortStorage = ShortStorage { shortStorageState :: (ForeignPtr CState, ForeignPtr CShortStorage) } deriving (Eq) shortStorage = curry ShortStorage type CIntTensor = C'THIntTensor newtype IntDynamic = IntDynamic { intDynamicState :: (ForeignPtr CState, ForeignPtr CIntTensor) } deriving (Eq) intDynamic = curry IntDynamic newtype IntTensor (ds :: [Nat]) = IntTensor { intAsDynamic :: IntDynamic } deriving (Eq) intAsStatic = IntTensor type CIntStorage = C'THIntStorage newtype IntStorage = IntStorage { intStorageState :: (ForeignPtr CState, ForeignPtr CIntStorage) } deriving (Eq) intStorage = curry IntStorage -- Floating types type CFloatTensor = C'THFloatTensor newtype FloatDynamic = FloatDynamic { floatDynamicState :: (ForeignPtr CState, ForeignPtr CFloatTensor) } deriving (Eq) floatDynamic = curry FloatDynamic newtype FloatTensor (ds :: [Nat]) = FloatTensor { floatAsDynamic :: FloatDynamic } deriving (Eq) floatAsStatic = FloatTensor type CFloatStorage = C'THFloatStorage newtype FloatStorage = FloatStorage { floatStorageState :: (ForeignPtr CState, ForeignPtr CFloatStorage) } deriving (Eq) floatStorage = curry FloatStorage type CDoubleTensor = C'THDoubleTensor newtype DoubleDynamic = DoubleDynamic { doubleDynamicState :: (ForeignPtr CState, ForeignPtr CDoubleTensor) } deriving (Eq) doubleDynamic = curry DoubleDynamic newtype DoubleTensor (ds :: [Nat]) = DoubleTensor { doubleAsDynamic :: DoubleDynamic } deriving (Eq) doubleAsStatic = DoubleTensor type CDoubleStorage = C'THDoubleStorage newtype DoubleStorage = DoubleStorage { doubleStorageState :: (ForeignPtr CState, ForeignPtr CDoubleStorage) } deriving (Eq) doubleStorage = curry DoubleStorage {- data CHalfTensor data HalfDynTensor halfCTensor :: HalfDynTensor -> ForeignPtr CHalfTensor halfDynTensor :: ForeignPtr CHalfTensor -> HalfDynTensor -} type C'THHalfTensor = () type C'THHalfStorage = () type C'THFile = () type C'THHalf = Ptr ()