-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Copy
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Functions to copy (and cast) tensors into different types.
-- This is a pure module.
-------------------------------------------------------------------------------
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Copy
  ( copy
  , copyByte
  , copyChar
  , copyShort
  , copyInt
  , copyLong
  , copyFloat
  , copyDouble
  ) where

import Foreign hiding (with)
import Foreign.C.Types
import Data.List (intercalate)
import Control.Exception.Safe (throwString)
import Control.Monad.Managed
import System.IO.Unsafe

import Torch.Types.TH (C'THState)
import qualified Torch.Types.TH as TH
import qualified Torch.Sig.Tensor as Sig
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Tensor.Copy as Sig

import qualified Torch.FFI.TH.Byte.Tensor as B
import qualified Torch.FFI.TH.Short.Tensor as S
import qualified Torch.FFI.TH.Int.Tensor as I
import qualified Torch.FFI.TH.Long.Tensor as L
import qualified Torch.FFI.TH.Char.Tensor as C
import qualified Torch.FFI.TH.Float.Tensor as F
import qualified Torch.FFI.TH.Double.Tensor as D

import Torch.Indef.Types

-- | Shared driver behind every @copy<Type>@ function in this module.
--
-- The arguments are, in order:
--
--   1. an allocator, called with the size and stride storages obtained from
--      the source tensor, which yields the raw target pointer;
--   2. the finalizer attached to that target pointer via 'newForeignPtr';
--   3. a wrapper that receives 'TH.torchstate' and the finalized pointer and
--      builds the value handed back to the caller;
--   4. the C copy routine, invoked as @routine state source target@;
--   5. the source tensor.
--
-- The IO is run through 'unsafePerformIO' so that callers see a pure
-- function; the definition is marked @NOINLINE@.
--
-- NOTE(review): the storages returned by 'Sig.c_newSizeOf' and
-- 'Sig.c_newStrideOf' are not released anywhere in this function. Confirm
-- whether the allocator takes ownership of them or whether they leak.
copyType
  :: (Ptr TH.C'THLongStorage -> Ptr TH.C'THLongStorage -> IO (Ptr a))
  -> FinalizerPtr a
  -> (ForeignPtr C'THState -> ForeignPtr a -> b)
  -> (Ptr CState -> Ptr CTensor -> Ptr a -> IO ())
  -> Dynamic
  -> b
copyType allocate finalizer wrap castInto src =
  unsafePerformIO $ with acquire (pure . wrap TH.torchstate)
  where
    -- Borrow the state and source pointers, then build the target while
    -- both are still held.
    acquire = do
      st <- managedState
      srcPtr <- managedTensor src
      liftIO $ do
        sizes <- Sig.c_newSizeOf st srcPtr
        strides <- Sig.c_newStrideOf st srcPtr
        dst <- allocate sizes strides
        castInto st srcPtr dst
        newForeignPtr finalizer dst
{-# NOINLINE copyType #-}

-- | Copy a tensor into a freshly created tensor of the same type.
--
-- A new tensor is created with 'Sig.c_new', resized against the source with
-- 'Sig.c_resizeAs', filled with 'Sig.c_copy' and finally wrapped with
-- 'mkDynamic'.
copy :: Dynamic -> Dynamic
copy src = unsafePerformIO $ with acquire pure
  where
    acquire = do
      st <- managedState
      srcPtr <- managedTensor src
      liftIO $ do
        dst <- Sig.c_new st
        Sig.c_resizeAs st dst srcPtr
        Sig.c_copy st dst srcPtr
        mkDynamic dst
{-# NOINLINE copy #-}

-- | Copy a tensor into a byte tensor. *Use at your own discretion.*
copyByte :: Dynamic -> TH.ByteDynamic
copyByte =
  copyType B.c_newWithSize_ B.p_free TH.byteDynamic Sig.c_copyByte

-- | Copy a tensor into a char tensor. *Use at your own discretion.*
copyChar :: Dynamic -> TH.CharDynamic
copyChar =
  copyType C.c_newWithSize_ C.p_free TH.charDynamic Sig.c_copyChar

-- | Copy a tensor into a short tensor. *Use at your own discretion.*
copyShort :: Dynamic -> TH.ShortDynamic
copyShort =
  copyType S.c_newWithSize_ S.p_free TH.shortDynamic Sig.c_copyShort

-- | Copy a tensor into an int tensor. *Use at your own discretion.*
copyInt :: Dynamic -> TH.IntDynamic
copyInt =
  copyType I.c_newWithSize_ I.p_free TH.intDynamic Sig.c_copyInt

-- | Copy a tensor into a long tensor. *Use at your own discretion.*
copyLong :: Dynamic -> TH.LongDynamic
copyLong =
  copyType L.c_newWithSize_ L.p_free TH.longDynamic Sig.c_copyLong

-- | Copy a tensor into a float tensor. *Use at your own discretion.*
copyFloat :: Dynamic -> TH.FloatDynamic
copyFloat =
  copyType F.c_newWithSize_ F.p_free TH.floatDynamic Sig.c_copyFloat

-- | Copy a tensor into a double tensor. *Use at your own discretion.*
copyDouble :: Dynamic -> TH.DoubleDynamic
copyDouble =
  copyType D.c_newWithSize_ D.p_free TH.doubleDynamic Sig.c_copyDouble

-- copyDouble :: Dynamic -> TH.DoubleDynamic
-- copyDouble = copyType D.c_new_ D.p_free TH.doubleDynamic Sig.c_copyDouble D.c_resize

-- copyDouble :: Dynamic -> TH.DoubleDynamic
-- copyDouble t = unsafePerformIO . withDynamicState t $ \s' t' -> do
--   withForeignPtr TH.torchstate $ \ths' -> do
--     sizes <- Sig.c_newSizeOf s' t'
--     strides <- Sig.c_newStrideOf s' t'
--     target <- D.c_newWithSize ths' sizes strides
--
--     -- mapM (size t . fromIntegral) [0.. nDimension t - 1] >>=
--     Sig.c_copyDouble s' t' target
--
--     out <- TH.doubleDynamic TH.torchstate <$> newForeignPtr D.p_free target
--     pure out

-- FIXME: reintroduce Half
-- copyHalf :: t -> io H.Dynamic

-- #if CUDA
-- class GPUTensorCopy gpu cpu | gpu -> cpu where
--   copyCuda :: gpu -> io gpu
--   copyIgnoringOverlaps :: gpu -> io gpu
--
--   copyCudaByte :: gpu -> IO Cuda.ByteDynamic
--   copyCudaChar :: gpu -> IO Cuda.CharDynamic
--   copyCudaShort :: gpu -> IO Cuda.ShortDynamic
--   copyCudaInt :: gpu -> IO Cuda.IntDynamic
--   copyCudaLong :: gpu -> IO Cuda.LongDynamic
--   copyCudaDouble :: gpu -> IO Cuda.DoubleDynamic
--
--   copyCPU :: gpu -> IO cpu
--   copyAsyncCPU :: gpu -> IO cpu
--
--   thCopyCuda :: cpu -> IO gpu
--   thCopyAsyncCuda :: cpu -> IO gpu
-- #endif