{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Storage
( IsList(..)
, Storage(..)
, cstorage
, storage
, storageState
, storageStateRef
, storagedata
, size
, set
, get
, empty
, newWithSize
, newWithSize1
, newWithSize2
, newWithSize3
, newWithSize4
, newWithMapping
, newWithData
, setFlag
, clearFlag
, retain
, resize
, fill
) where
import Torch.Indef.Storage.Copy as X
import Foreign hiding (with, new)
import Foreign.C.Types
import GHC.ForeignPtr (ForeignPtr)
import GHC.Int
import GHC.Word
import Control.Monad
import Control.Monad.Managed
import Control.DeepSeq
import System.IO.Unsafe
import GHC.Exts (IsList(..))
import Control.Monad.ST
import Data.STRef
import Torch.Indef.Types
import Torch.Indef.Internal
import qualified Torch.Sig.Types as Sig
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Storage as Sig
import qualified Torch.Sig.Storage.Memory as Sig
import qualified Foreign.Marshal.Array as FM
-- | Read every element of a 'Storage' into a Haskell list.
--
-- The element count comes from @Sig.c_size@ and the elements are read
-- directly from the pointer returned by @Sig.c_data@, then converted with
-- 'c2hsReal'.
--
-- 'FM.peekArray' reads all @sz@ elements inside 'IO' before returning, so
-- the list does not alias the C buffer afterwards. No temporary buffer is
-- allocated, so nothing is left unfreed after the call.
--
-- NOTE(review): this is exposed as a pure function over a storage that other
-- functions in this module mutate in 'IO'; callers should confirm that the
-- storage is not modified concurrently.
storagedata :: Storage -> [HsReal]
storagedata s = unsafeDupablePerformIO . flip with (pure . fmap c2hsReal) $ do
  st <- managedState
  s' <- managedStorage s
  liftIO $ do
    sz <- fromIntegral <$> Sig.c_size st s'
    creals <- Sig.c_data st s'
    -- 'FM.peekArray' returns [] without touching the pointer when sz <= 0.
    FM.peekArray sz creals
{-# NOINLINE storagedata #-}
-- | Number of elements in the storage, as reported by @Sig.c_size@ and
-- converted to 'Int'.
size :: Storage -> Int
size s = unsafeDupablePerformIO . fmap fromIntegral . withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure (Sig.c_size st ptr)
{-# NOINLINE size #-}
-- | Write the value @v@ at position @pd@ of the storage through @Sig.c_set@.
--
-- NOTE(review): no bounds check is performed on the Haskell side; confirm
-- whether the C call validates @pd@.
set :: Storage -> Word -> HsReal -> IO ()
set s pd v = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_set st ptr (fromIntegral pd) (hs2cReal v)
-- | Read the value at position @pd@ of the storage through @Sig.c_get@ and
-- convert it with 'c2hsReal'.
--
-- NOTE(review): no bounds check is performed on the Haskell side; confirm
-- whether the C call validates @pd@.
get :: Storage -> Word -> HsReal
get s pd = unsafeDupablePerformIO . fmap c2hsReal . withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_get st ptr (fromIntegral pd)
{-# NOINLINE get #-}
-- | A storage created by @Sig.c_new@.
--
-- NOTE(review): this is a top-level constant marked @NOINLINE@, so the same
-- underlying storage is presumably shared by every use site — confirm this is
-- intended before mutating it with the 'IO' functions in this module.
empty :: Storage
empty = unsafeDupablePerformIO (withStorage (fmap Sig.c_new managedState))
{-# NOINLINE empty #-}
-- | Create a storage by passing @pd@ to @Sig.c_newWithSize@.
--
-- NOTE(review): the initial contents are not established by this wrapper;
-- check the C API before relying on them.
newWithSize :: Word -> Storage
newWithSize pd = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  pure $ Sig.c_newWithSize st (fromIntegral pd)
{-# NOINLINE newWithSize #-}
-- | Create a storage from a single value via @Sig.c_newWithSize1@.
-- The value is converted with 'hs2cReal' before the call.
newWithSize1 :: HsReal -> Storage
newWithSize1 a0 = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  pure $ Sig.c_newWithSize1 st (hs2cReal a0)
{-# NOINLINE newWithSize1 #-}
-- | Create a storage from two values via @Sig.c_newWithSize2@.
-- Each value is converted with 'hs2cReal' before the call.
newWithSize2 :: HsReal -> HsReal -> Storage
newWithSize2 a0 a1 = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  pure $ Sig.c_newWithSize2 st (hs2cReal a0) (hs2cReal a1)
{-# NOINLINE newWithSize2 #-}
-- | Create a storage from three values via @Sig.c_newWithSize3@.
-- Each value is converted with 'hs2cReal' before the call.
newWithSize3 :: HsReal -> HsReal -> HsReal -> Storage
newWithSize3 a0 a1 a2 = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  pure $ Sig.c_newWithSize3 st (hs2cReal a0) (hs2cReal a1) (hs2cReal a2)
{-# NOINLINE newWithSize3 #-}
-- | Create a storage from four values via @Sig.c_newWithSize4@.
-- Each value is converted with 'hs2cReal' before the call.
newWithSize4 :: HsReal -> HsReal -> HsReal -> HsReal -> Storage
newWithSize4 a0 a1 a2 a3 = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  pure $
    Sig.c_newWithSize4
      st
      (hs2cReal a0)
      (hs2cReal a1)
      (hs2cReal a2)
      (hs2cReal a3)
{-# NOINLINE newWithSize4 #-}
-- | Create a storage through @Sig.c_newWithMapping@.
--
-- The byte list is marshalled into a freshly allocated C array with
-- 'FM.newArray'; the two numeric arguments are passed through
-- 'fromIntegral'.
--
-- NOTE(review): 'FM.newArray' does not append a terminator and the buffer is
-- never freed here. Presumably the bytes are a C string (file name) — confirm
-- whether callers must supply the trailing NUL and whether the C side takes
-- ownership of the buffer.
newWithMapping
  :: [Int8]
  -> Word64
  -> Int32
  -> IO Storage
newWithMapping pcc' pd ci = withStorage $ do
  st <- managedState
  buf <- liftIO (FM.newArray (map fromIntegral pcc'))
  pure $ Sig.c_newWithMapping st buf (fromIntegral pd) (fromIntegral ci)
-- | Create a storage from a list of values through @Sig.c_newWithData@.
--
-- The list is converted element-wise with 'hs2cReal' and marshalled into a
-- freshly allocated C array with 'FM.newArray'; @pd@ is passed alongside it.
--
-- NOTE(review): @pd@ is not checked against @length pr@ here, and the buffer
-- is not freed on the Haskell side. Presumably @pd@ is the element count and
-- the C storage takes ownership of the buffer — confirm against the C API.
newWithData
  :: [HsReal]
  -> Word64
  -> Storage
newWithData pr pd = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  buf <- liftIO (FM.newArray (map hs2cReal pr))
  pure $ Sig.c_newWithData st buf (fromIntegral pd)
{-# NOINLINE newWithData #-}
-- | Forward the flag value @cc@ to @Sig.c_setFlag@ for this storage.
setFlag :: Storage -> Int8 -> IO ()
setFlag s cc = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_setFlag st ptr (fromIntegral cc)
-- | Forward the flag value @cc@ to @Sig.c_clearFlag@ for this storage.
clearFlag :: Storage -> Int8 -> IO ()
clearFlag s cc = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_clearFlag st ptr (fromIntegral cc)
-- | Call @Sig.c_retain@ on this storage.
--
-- NOTE(review): presumably this bumps the C-side reference count — confirm.
retain :: Storage -> IO ()
retain s = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure (Sig.c_retain st ptr)
-- | Resize the storage in place by passing @pd@ to @Sig.c_resize@.
resize :: Storage -> Word32 -> IO ()
resize s pd = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_resize st ptr (fromIntegral pd)
-- | Pass the value @v@ (converted with 'hs2cReal') to @Sig.c_fill@ for this
-- storage.
fill :: Storage -> HsReal -> IO ()
fill s v = withLift $ do
  st <- managedState
  ptr <- managedStorage s
  pure $ Sig.c_fill st ptr (hs2cReal v)
-- | List conversions for 'Storage': 'toList' reads the contents with
-- 'storagedata', and 'fromList' builds a storage with 'newWithData', passing
-- the list's length as the second argument.
instance IsList Storage where
  type Item Storage = HsReal

  toList = storagedata

  fromList xs = newWithData xs (fromIntegral (length xs))
-- | A storage is shown as the list of its elements (see 'storagedata').
instance Show Storage where
  show s = show (storagedata s)