{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Storage.Copy
( copy
, copyByte
, copyChar
, copyShort
, copyInt
, copyLong
, copyFloat
, copyDouble
) where
import Foreign hiding (new, with)
import Foreign.Ptr
import Control.Monad.Managed
import System.IO.Unsafe
import qualified Torch.Types.TH as TH
import qualified Foreign.Marshal.Array as FM
import qualified Torch.Sig.Types as Sig
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Storage as Sig
import qualified Torch.Sig.Storage.Memory as Sig
import qualified Torch.Sig.Storage.Copy as Sig
import qualified Torch.FFI.TH.Long.Storage as L
import qualified Torch.FFI.TH.Float.Storage as F
import qualified Torch.FFI.TH.Byte.Storage as B
import qualified Torch.FFI.TH.Char.Storage as C
import qualified Torch.FFI.TH.Short.Storage as S
import qualified Torch.FFI.TH.Int.Storage as I
import qualified Torch.FFI.TH.Double.Storage as D
import Torch.Indef.Types
-- | Copy a 'Storage' into a freshly allocated storage of another element
-- type.
--
-- The target is obtained from the supplied allocator and gets the supplied
-- finalizer attached. It is then wrapped, together with 'TH.torchstate', by
-- the supplied builder.
--
-- NOTE(review): the target comes straight from @newPtr@ and is handed to
-- @cfun@ as its last argument without being resized to the source's length.
-- Whether this produces a full copy depends on which argument @cfun@ treats
-- as the destination and how it sizes its loop. Confirm against the bound C
-- functions.
copyType
  :: IO (Ptr a)                                      -- ^ allocate the target storage
  -> FinalizerPtr a                                  -- ^ finalizer attached to the target
  -> (ForeignPtr TH.C'THState -> ForeignPtr a -> b)  -- ^ wrap state and target into the result
  -> (Ptr CState -> Ptr CStorage -> Ptr a -> IO ())  -- ^ C copy routine, called as @cfun state source target@
  -> Storage                                         -- ^ source storage
  -> b
copyType newPtr fin builder cfun t =
  unsafeDupablePerformIO (with resources finish)
  where
    -- Borrow the raw state and source pointers for the duration of the copy.
    resources = do
      st  <- managedState
      src <- managedStorage t
      liftIO (fill st src)
    -- Allocate the target, run the C routine, then hand ownership to the GC.
    fill st src = do
      dst <- newPtr
      cfun st src dst
      newForeignPtr fin dst
    finish = pure . builder TH.torchstate
{-# NOINLINE copyType #-}
-- | Read the elements of a 'Storage' into a Haskell list, converting each one
-- with 'c2hsReal'.
--
-- The element count comes from 'Sig.c_size'. A scratch buffer of that length
-- is passed to 'Sig.c_rawCopy' and then read back with 'FM.peekArray'.
--
-- NOTE(review): this assumes 'Sig.c_rawCopy' writes the storage's contents
-- /into/ the pointer argument. If the binding instead copies /from/ the
-- pointer into the storage, this reads back uninitialised memory. Confirm
-- the argument direction of the C binding.
rawCopy :: Storage -> [HsReal]
rawCopy t = unsafeDupablePerformIO . flip with (pure . fmap c2hsReal) $ do
  s' <- managedState
  t' <- managedStorage t
  liftIO $ do
    n <- fromIntegral <$> Sig.c_size s' t'
    -- 'allocaArray' releases the scratch buffer once the list has been read.
    -- 'peekArray' reads every element eagerly inside IO, so nothing refers to
    -- the buffer after it is freed. The previous 'mallocArray' was never
    -- freed and leaked on every call.
    FM.allocaArray n $ \buf -> do
      Sig.c_rawCopy s' t' buf
      FM.peekArray n buf
{-# NOINLINE rawCopy #-}
-- | Duplicate a 'Storage'.
--
-- A new storage is allocated with 'Sig.c_new'. 'Sig.c_copy' is then run on
-- the source and the new storage, and the new storage is wrapped with
-- 'mkStorage'.
--
-- NOTE(review): 'Sig.c_copy' is called as @c_copy state source fresh@, and
-- the fresh storage is not resized first. If the binding takes its
-- destination before its source, data flows from the empty storage into the
-- source rather than the other way round. Confirm against the C binding.
copy :: Storage -> Storage
copy t = unsafeDupablePerformIO (with acquire mkStorage)
  where
    acquire = do
      st  <- managedState
      src <- managedStorage t
      liftIO (duplicate st src)
    duplicate st src = do
      fresh <- Sig.c_new st
      fresh <$ Sig.c_copy st src fresh
{-# NOINLINE copy #-}
-- | Convert a 'Storage' into a 'TH.LongStorage' via 'copyType'. Uses the
-- Long allocator and finalizer together with 'Sig.c_copyLong'.
copyLong :: Storage -> TH.LongStorage
copyLong src = copyType L.c_new_ L.p_free TH.longStorage Sig.c_copyLong src
-- | Convert a 'Storage' into a 'TH.FloatStorage' via 'copyType'. Uses the
-- Float allocator and finalizer together with 'Sig.c_copyFloat'.
copyFloat :: Storage -> TH.FloatStorage
copyFloat src = copyType F.c_new_ F.p_free TH.floatStorage Sig.c_copyFloat src
-- | Convert a 'Storage' into a 'TH.ByteStorage' via 'copyType'. Uses the
-- Byte allocator and finalizer together with 'Sig.c_copyByte'.
copyByte :: Storage -> TH.ByteStorage
copyByte src = copyType B.c_new_ B.p_free TH.byteStorage Sig.c_copyByte src
-- | Convert a 'Storage' into a 'TH.CharStorage' via 'copyType'. Uses the
-- Char allocator and finalizer together with 'Sig.c_copyChar'.
copyChar :: Storage -> TH.CharStorage
copyChar src = copyType C.c_new_ C.p_free TH.charStorage Sig.c_copyChar src
-- | Convert a 'Storage' into a 'TH.ShortStorage' via 'copyType'. Uses the
-- Short allocator and finalizer together with 'Sig.c_copyShort'.
copyShort :: Storage -> TH.ShortStorage
copyShort src = copyType S.c_new_ S.p_free TH.shortStorage Sig.c_copyShort src
-- | Convert a 'Storage' into a 'TH.IntStorage' via 'copyType'. Uses the
-- Int allocator and finalizer together with 'Sig.c_copyInt'.
copyInt :: Storage -> TH.IntStorage
copyInt src = copyType I.c_new_ I.p_free TH.intStorage Sig.c_copyInt src
-- | Convert a 'Storage' into a 'TH.DoubleStorage' via 'copyType'. Uses the
-- Double allocator and finalizer together with 'Sig.c_copyDouble'.
copyDouble :: Storage -> TH.DoubleStorage
copyDouble src = copyType D.c_new_ D.p_free TH.doubleStorage Sig.c_copyDouble src