{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-cse -Wno-deprecations #-}
module Torch.Indef.Dynamic.Tensor where
import Foreign hiding (with, new)
import Foreign.Ptr
import Control.Applicative ((<|>))
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Managed
import Control.Exception.Safe
import Control.DeepSeq
import Data.Coerce (coerce)
import Data.Typeable
import Data.Maybe (fromMaybe, fromJust)
import Data.List (intercalate, genericLength)
import Data.List.NonEmpty (NonEmpty(..))
import Foreign.C.Types
import GHC.ForeignPtr (ForeignPtr)
import GHC.Int
import GHC.Exts (IsList(..))
import Numeric.Dimensions
import System.IO.Unsafe
import Control.Concurrent
import Control.Monad.Trans.Except
import Text.Printf
import qualified Data.List as List ((!!))
import qualified Data.List.NonEmpty as NE
import qualified Torch.Types.TH as TH
import qualified Foreign.Marshal.Array as FM
import qualified Torch.Sig.State as Sig
import qualified Torch.Sig.Types as Sig
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Tensor as Sig
import qualified Torch.Sig.Tensor.Memory as Sig
import qualified Torch.Sig.Storage as StorageSig (c_size)
import Torch.Indef.Dynamic.Print (showTensor, describeTensor)
import Torch.Indef.Types
import Torch.Indef.Internal
import Torch.Indef.Index hiding (withDynamicState)
import qualified Torch.Indef.Storage as Storage
-- | Call the C @clearFlag@ routine on a tensor, passing @cc@ as the flag value.
_clearFlag :: Dynamic -> Int8 -> IO ()
_clearFlag t cc = runManaged $ do
  st <- managedState
  ptr <- managedTensor t
  liftIO (Sig.c_clearFlag st ptr (CChar cc))
-- | Read every element of a tensor into a Haskell list.
--
-- A tensor whose 'shape' is empty yields @[]@. Otherwise @product (shape t)@
-- elements are copied out of the data pointer.
--
-- The data pointer only lists the elements in order when the tensor is
-- contiguous. A non-contiguous tensor is therefore first copied with
-- 'newContiguous'. The elements are peeked directly from the tensor's
-- buffer; the previous version copied them into a @mallocArray@ temporary
-- that was never freed, leaking on every call.
tensordata :: Dynamic -> [HsReal]
tensordata t0 =
  case shape t0 of
    [] -> []
    ds ->
      unsafeDupablePerformIO . flip with (pure . fmap c2hsReal) $ do
        st <- managedState
        -- 'managedTensor' keeps the tensor alive for the whole 'with' scope,
        -- so the pointer stays valid while 'peekArray' copies from it.
        t' <- managedTensor t
        liftIO $ Sig.c_data st t' >>= FM.peekArray (fromIntegral (product ds))
  where
    t = if isContiguous t0 then t0 else newContiguous t0
{-# NOINLINE tensordata #-}
-- | @indicesInBounds t ix@ holds when @t@ has exactly @length ix@ dimensions
-- and every 0-based index in @ix@ is strictly below the size of its
-- dimension. The @getNd@ accessors use it so that no out-of-range index is
-- ever handed to the C getters.
indicesInBounds :: Dynamic -> [Word] -> Bool
indicesInBounds t ix =
  fromIntegral (nDimension t) == length ix
    && and (zipWith (<) ix (shape t))

-- | Element at index @d1@ of a 1-dimensional tensor.
-- Returns 'Nothing' if the tensor is not 1-d or the index is out of range.
-- (The previous check @size t 0 < d1@ let @d1 == size@ through.)
get1d :: Dynamic -> Word -> Maybe HsReal
get1d t d1
  | not (indicesInBounds t [d1]) = Nothing
  | otherwise = unsafeDupablePerformIO . flip with (pure . Just . c2hsReal) . (liftIO =<<) $ Sig.c_get1d
      <$> managedState
      <*> managedTensor t
      <*> pure (fromIntegral d1)
{-# NOINLINE get1d #-}

-- | Partial 'get1d': calls 'fromJust', so it errors when 'get1d' returns 'Nothing'.
unsafeGet1d :: Dynamic -> Word -> HsReal
unsafeGet1d t d1 = fromJust $ get1d t d1

-- | Element at @(d1, d2)@ of a 2-dimensional tensor.
-- Returns 'Nothing' if the tensor is not 2-d or any index is out of range.
get2d :: Dynamic -> Word -> Word -> Maybe HsReal
get2d t d1 d2
  | not (indicesInBounds t [d1, d2]) = Nothing
  | otherwise = unsafeDupablePerformIO . flip with (pure . Just . c2hsReal) . (liftIO =<<) $ Sig.c_get2d
      <$> managedState
      <*> managedTensor t
      <*> pure (fromIntegral d1)
      <*> pure (fromIntegral d2)
{-# NOINLINE get2d #-}

-- | Partial 'get2d': calls 'fromJust', so it errors when 'get2d' returns 'Nothing'.
unsafeGet2d :: Dynamic -> Word -> Word -> HsReal
unsafeGet2d t d1 d2 = fromJust $ get2d t d1 d2

-- | Element at @(d1, d2, d3)@ of a 3-dimensional tensor.
-- Returns 'Nothing' if the tensor is not 3-d or any index is out of range.
get3d :: Dynamic -> Word -> Word -> Word -> Maybe HsReal
get3d t d1 d2 d3
  | not (indicesInBounds t [d1, d2, d3]) = Nothing
  | otherwise = unsafeDupablePerformIO . flip with (pure . Just . c2hsReal) . (liftIO =<<) $ Sig.c_get3d
      <$> managedState
      <*> managedTensor t
      <*> pure (fromIntegral d1)
      <*> pure (fromIntegral d2)
      <*> pure (fromIntegral d3)
{-# NOINLINE get3d #-}

-- | Partial 'get3d': calls 'fromJust', so it errors when 'get3d' returns 'Nothing'.
unsafeGet3d :: Dynamic -> Word -> Word -> Word -> HsReal
unsafeGet3d t d1 d2 d3 = fromJust $ get3d t d1 d2 d3

-- | Element at @(d1, d2, d3, d4)@ of a 4-dimensional tensor.
-- Returns 'Nothing' if the tensor is not 4-d or any index is out of range.
get4d :: Dynamic -> Word -> Word -> Word -> Word -> Maybe HsReal
get4d t d1 d2 d3 d4
  | not (indicesInBounds t [d1, d2, d3, d4]) = Nothing
  | otherwise = unsafeDupablePerformIO . flip with (pure . Just . c2hsReal) . (liftIO =<<) $ Sig.c_get4d
      <$> managedState
      <*> managedTensor t
      <*> pure (fromIntegral d1)
      <*> pure (fromIntegral d2)
      <*> pure (fromIntegral d3)
      <*> pure (fromIntegral d4)
{-# NOINLINE get4d #-}

-- | Partial 'get4d': calls 'fromJust', so it errors when 'get4d' returns 'Nothing'.
unsafeGet4d :: Dynamic -> Word -> Word -> Word -> Word -> HsReal
unsafeGet4d t d1 d2 d3 d4 = fromJust $ get4d t d1 d2 d3 d4
-- | Look up an element using a type-level dimension list of length 1 to 4.
-- Dispatches to 'get1d' .. 'get4d'; calls 'error' for more than 4 indices.
getDim :: Dynamic -> Dims ((i:+ds)::[Nat]) -> Maybe HsReal
getDim t d = case fromIntegral <$> listDims d of
  [] -> error "[impossible] pattern match fail, `Dims ((i:+ds)::[Nat])` prevents this"
  [x] -> get1d t x
  [x, y] -> get2d t x y
  [x, y, z] -> get3d t x y z
  [x, y, z, q] -> get4d t x y z q
  _ -> error "[incomplete] getDim doesn't have support for dimensions > 4"
-- | 'True' when the C @isContiguous@ routine returns 1 for this tensor.
isContiguous :: Dynamic -> Bool
isContiguous t = unsafeDupablePerformIO $ with query (pure . (== 1))
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_isContiguous st ptr)
{-# NOINLINE isContiguous #-}
-- | 'True' when the C @isSameSizeAs@ routine returns 1 for the two tensors.
isSameSizeAs :: Dynamic -> Dynamic -> Bool
isSameSizeAs t0 t1 = unsafeDupablePerformIO $ with query (pure . (== 1))
  where
    query = do
      st <- managedState
      p0 <- managedTensor t0
      p1 <- managedTensor t1
      liftIO (Sig.c_isSameSizeAs st p0 p1)
{-# NOINLINE isSameSizeAs #-}
-- | 'True' when the C @isSetTo@ routine returns 1 for the two tensors.
isSetTo :: Dynamic -> Dynamic -> Bool
isSetTo t0 t1 = unsafeDupablePerformIO $ with query (pure . (== 1))
  where
    query = do
      st <- managedState
      p0 <- managedTensor t0
      p1 <- managedTensor t1
      liftIO (Sig.c_isSetTo st p0 p1)
{-# NOINLINE isSetTo #-}
-- | 'True' when the C @isSize@ routine returns 1 for the tensor and the given
-- long storage of sizes.
isSize :: Dynamic -> TH.LongStorage -> Bool
isSize t ls = unsafeDupablePerformIO $ with query (pure . (== 1))
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      sizes <- managed (withForeignPtr (snd (TH.longStorageState ls)))
      liftIO (Sig.c_isSize st ptr sizes)
{-# NOINLINE isSize #-}
-- | Number of dimensions, as reported by the C @nDimension@ routine.
nDimension :: Dynamic -> Word
nDimension t = unsafeDupablePerformIO $ with query (pure . fromIntegral)
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_nDimension st ptr)
{-# NOINLINE nDimension #-}
-- | Total element count, as reported by the C @nElement@ routine.
nElement :: Dynamic -> Word64
nElement t = unsafeDupablePerformIO $ with query (pure . fromIntegral)
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_nElement st ptr)
{-# NOINLINE nElement #-}
-- | Call the C @narrow@ routine with @t0@, @t1@, a dimension @a@, a start
-- offset @b@ and a size @c@, in that order.
_narrow
  :: Dynamic
  -> Dynamic
  -> Word
  -> Int64
  -> Size
  -> IO ()
_narrow t0 t1 a b c = withLift $ do
  st <- managedState
  p0 <- managedTensor t0
  p1 <- managedTensor t1
  pure $ Sig.c_narrow st p0 p1 (fromIntegral a) (fromIntegral b) (fromIntegral c)
{-# WARNING _narrow "hasktorch devs have not yet made this safe. You are warned." #-}
-- | A tensor produced by the C @new@ constructor.
--
-- NOTE: this is a top-level value (a CAF) marked NOINLINE, so it is built
-- once and shared by every use site. Never mutate it in place.
empty :: Dynamic
empty = unsafeDupablePerformIO (withDynamic (fmap Sig.c_new managedState))
{-# NOINLINE empty #-}
-- | New tensor from the C @newExpand@ routine, given a source tensor and an
-- index storage of sizes.
newExpand :: Dynamic -> TH.IndexStorage -> Dynamic
newExpand r ix = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  src <- managedTensor r
  sizes <- managed (withForeignPtr (snd (TH.longStorageState ix)))
  pure (Sig.c_newExpand st src sizes)
{-# NOINLINE newExpand #-}
-- | Call the C @expand@ routine with tensors @r@ and @t@ and an index storage
-- of sizes.
_expand
  :: Dynamic
  -> Dynamic
  -> TH.IndexStorage
  -> IO ()
_expand r t ix = withLift $ do
  st <- managedState
  pr <- managedTensor r
  pt <- managedTensor t
  sizes <- managed (withForeignPtr (snd (TH.longStorageState ix)))
  pure (Sig.c_expand st pr pt sizes)
-- | Call the C @expandNd@ routine on two non-empty lists of tensors, passed
-- as C arrays of tensor pointers, plus a count @i@.
_expandNd :: NonEmpty Dynamic -> NonEmpty Dynamic -> Int -> IO ()
_expandNd rets@(_ :| _) ops i = runManaged $ do
  st <- managedState
  retsArr <- pointerArray rets
  opsArr <- pointerArray ops
  liftIO $ Sig.c_expandNd st retsArr opsArr (fromIntegral i)
  where
    -- Pin each tensor, then marshal its raw pointer into a temporary C array
    -- that lives for the rest of the surrounding 'Managed' scope.
    pointerArray :: NonEmpty Dynamic -> Managed (Ptr (Ptr CTensor))
    pointerArray xs = do
      ptrs <- traverse (\d -> managed (withForeignPtr (Sig.ctensor d))) (NE.toList xs)
      managed (FM.withArray ptrs)
-- | New tensor from the C @newClone@ routine applied to @t@.
newClone :: Dynamic -> Dynamic
newClone t = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure (Sig.c_newClone st src)
{-# NOINLINE newClone #-}
-- | New tensor from the C @newContiguous@ routine applied to @t@.
newContiguous :: Dynamic -> Dynamic
newContiguous t = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure (Sig.c_newContiguous st src)
{-# NOINLINE newContiguous #-}
-- | New tensor from the C @newNarrow@ routine: dimension @a@, start @b@,
-- size @c@.
newNarrow
  :: Dynamic
  -> Word
  -> Int64
  -> Size
  -> IO Dynamic
newNarrow t a b c = withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure $ Sig.c_newNarrow st src (fromIntegral a) (fromIntegral b) (fromIntegral c)
{-# WARNING newNarrow "hasktorch devs have not yet made this safe. You are warned." #-}
-- | New tensor from the C @newSelect@ routine: dimension @a@, index @b@.
newSelect
  :: Dynamic
  -> Word
  -> Int64
  -> IO Dynamic
newSelect t a b = withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure $ Sig.c_newSelect st src (fromIntegral a) (fromIntegral b)
{-# WARNING newSelect "hasktorch devs have not yet made this safe. You are warned." #-}
-- | Index storage holding the result of the C @newSizeOf@ routine for @t@.
newSizeOf :: Dynamic -> TH.IndexStorage
newSizeOf t = unsafeDupablePerformIO $ with sizes mkCPUIxStorage
  where
    sizes = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_newSizeOf st ptr)
{-# NOINLINE newSizeOf #-}
-- | Index storage holding the result of the C @newStrideOf@ routine for @t@.
newStrideOf :: Dynamic -> TH.IndexStorage
newStrideOf t = unsafeDupablePerformIO $ with strides mkCPUIxStorage
  where
    strides = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_newStrideOf st ptr)
{-# NOINLINE newStrideOf #-}
-- | New tensor from the C @newTranspose@ routine, swapping dimensions @a@ and @b@.
newTranspose :: Dynamic -> Word -> Word -> Dynamic
newTranspose t a b = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure $ Sig.c_newTranspose st src (fromIntegral a) (fromIntegral b)
{-# NOINLINE newTranspose #-}
-- | New tensor from the C @newUnfold@ routine with arguments @a@, @b@, @c@
-- passed through in order.
newUnfold
  :: Dynamic
  -> Word
  -> Int64
  -> Int64
  -> Dynamic
newUnfold t a b c = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure $ Sig.c_newUnfold st src (fromIntegral a) (fromIntegral b) (fromIntegral c)
{-# NOINLINE newUnfold #-}
-- | New tensor from the C @newView@ routine with the given index storage of sizes.
newView :: Dynamic -> TH.IndexStorage -> IO Dynamic
newView t ix = withDynamic $ do
  st <- managedState
  src <- managedTensor t
  sizes <- managed (withCPUIxStorage ix)
  pure (Sig.c_newView st src sizes)
{-# WARNING newView "hasktorch devs have not yet made this safe. You are warned." #-}
-- | New tensor from the C @newWithSize@ routine, given two index storages
-- (presumably sizes and strides — NOTE(review): confirm against the C API).
newWithSize :: TH.IndexStorage -> TH.IndexStorage -> Dynamic
newWithSize l0 l1 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  p0 <- managed (withCPUIxStorage l0)
  p1 <- managed (withCPUIxStorage l1)
  pure (Sig.c_newWithSize st p0 p1)
{-# NOINLINE newWithSize #-}
-- | New 1-d tensor of size @a0@ (C @newWithSize1d@).
newWithSize1d :: Word -> Dynamic
newWithSize1d a0 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  pure $ Sig.c_newWithSize1d st (fromIntegral a0)
{-# NOINLINE newWithSize1d #-}

-- | New 2-d tensor of sizes @a0 x a1@ (C @newWithSize2d@).
newWithSize2d :: Word -> Word -> Dynamic
newWithSize2d a0 a1 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  pure $ Sig.c_newWithSize2d st (fromIntegral a0) (fromIntegral a1)
{-# NOINLINE newWithSize2d #-}

-- | New 3-d tensor of sizes @a0 x a1 x a2@ (C @newWithSize3d@).
newWithSize3d :: Word -> Word -> Word -> Dynamic
newWithSize3d a0 a1 a2 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  pure $ Sig.c_newWithSize3d st (fromIntegral a0) (fromIntegral a1) (fromIntegral a2)
{-# NOINLINE newWithSize3d #-}

-- | New 4-d tensor of sizes @a0 x a1 x a2 x a3@ (C @newWithSize4d@).
newWithSize4d :: Word -> Word -> Word -> Word -> Dynamic
newWithSize4d a0 a1 a2 a3 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  pure $ Sig.c_newWithSize4d st
    (fromIntegral a0) (fromIntegral a1) (fromIntegral a2) (fromIntegral a3)
{-# NOINLINE newWithSize4d #-}
-- | New tensor viewing storage @s@ at offset @pd@, with two index storages
-- passed to the C @newWithStorage@ routine.
newWithStorage :: Storage -> StorageOffset -> TH.IndexStorage -> TH.IndexStorage -> Dynamic
newWithStorage s pd l0 l1 = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  sp <- managedStorage s
  let offset = fromIntegral pd
  p0 <- managed (withForeignPtr (snd (TH.longStorageState l0)))
  p1 <- managed (withForeignPtr (snd (TH.longStorageState l1)))
  pure (Sig.c_newWithStorage st sp offset p0 p1)
{-# NOINLINE newWithStorage #-}
-- | New 1-d tensor over storage @s@ starting at offset @pd@, with one
-- (size, stride) pair.
newWithStorage1d
  :: Storage
  -> StorageOffset
  -> (Size, Stride)
  -> Dynamic
newWithStorage1d s pd (d00,d01) = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  sp <- managedStorage s
  pure $ Sig.c_newWithStorage1d st sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
{-# NOINLINE newWithStorage1d #-}

-- | New 2-d tensor over storage @s@ starting at offset @pd@, with one
-- (size, stride) pair per dimension.
newWithStorage2d
  :: Storage
  -> StorageOffset
  -> (Size, Stride)
  -> (Size, Stride)
  -> Dynamic
newWithStorage2d s pd (d00,d01) (d10,d11) = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  sp <- managedStorage s
  pure $ Sig.c_newWithStorage2d st sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
{-# NOINLINE newWithStorage2d #-}

-- | New 3-d tensor over storage @s@ starting at offset @pd@, with one
-- (size, stride) pair per dimension.
newWithStorage3d
  :: Storage
  -> StorageOffset
  -> (Size, Stride)
  -> (Size, Stride)
  -> (Size, Stride)
  -> Dynamic
newWithStorage3d s pd (d00,d01) (d10,d11) (d20,d21) = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  sp <- managedStorage s
  pure $ Sig.c_newWithStorage3d st sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
    (fromIntegral d20) (fromIntegral d21)
{-# NOINLINE newWithStorage3d #-}

-- | New 4-d tensor over storage @s@ starting at offset @pd@, with one
-- (size, stride) pair per dimension.
newWithStorage4d
  :: Storage
  -> StorageOffset
  -> (Size, Stride)
  -> (Size, Stride)
  -> (Size, Stride)
  -> (Size, Stride)
  -> Dynamic
newWithStorage4d s pd (d00,d01) (d10,d11) (d20,d21) (d30,d31) = unsafeDupablePerformIO . withDynamic $ do
  st <- managedState
  sp <- managedStorage s
  pure $ Sig.c_newWithStorage4d st sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
    (fromIntegral d20) (fromIntegral d21)
    (fromIntegral d30) (fromIntegral d31)
{-# NOINLINE newWithStorage4d #-}
-- | New tensor from the C @newWithTensor@ routine applied to @t@.
newWithTensor :: Dynamic -> IO Dynamic
newWithTensor t = withDynamic $ do
  st <- managedState
  src <- managedTensor t
  pure (Sig.c_newWithTensor st src)
{-# NOINLINE newWithTensor #-}
{-# WARNING newWithTensor "this function causes the input tensor to be impure" #-}
-- | Call the C @resize@ routine on @t@ with two index storages.
_resize
  :: Dynamic -> TH.IndexStorage -> TH.IndexStorage -> IO ()
_resize t l0 l1 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  p0 <- managed (withCPUIxStorage l0)
  p1 <- managed (withCPUIxStorage l1)
  pure (Sig.c_resize st ptr p0 p1)
-- | Resize @t@ in place to 1 dimension (C @resize1d@).
resize1d_ :: Dynamic -> Word -> IO ()
resize1d_ t l0 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_resize1d st ptr (fromIntegral l0)

-- | Resize @t@ in place to 2 dimensions (C @resize2d@).
resize2d_ :: Dynamic -> Word -> Word -> IO ()
resize2d_ t l0 l1 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_resize2d st ptr (fromIntegral l0) (fromIntegral l1)

-- | Resize @t@ in place to 3 dimensions (C @resize3d@).
resize3d_ :: Dynamic -> Word -> Word -> Word -> IO ()
resize3d_ t l0 l1 l2 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_resize3d st ptr (fromIntegral l0) (fromIntegral l1) (fromIntegral l2)

-- | Resize @t@ in place to 4 dimensions (C @resize4d@).
resize4d_ :: Dynamic -> Word -> Word -> Word -> Word -> IO ()
resize4d_ t l0 l1 l2 l3 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_resize4d st ptr
    (fromIntegral l0) (fromIntegral l1) (fromIntegral l2) (fromIntegral l3)

-- | Resize @t@ in place to 5 dimensions (C @resize5d@).
resize5d_ :: Dynamic -> Word -> Word -> Word -> Word -> Word -> IO ()
resize5d_ t l0 l1 l2 l3 l4 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_resize5d st ptr
    (fromIntegral l0) (fromIntegral l1) (fromIntegral l2) (fromIntegral l3) (fromIntegral l4)
-- | Call the C @resizeAs@ routine with @t0@ and @t1@ (in that order).
resizeAs_
  :: Dynamic
  -> Dynamic
  -> IO ()
resizeAs_ t0 t1 = with2DynamicState t0 t1 $ \st p0 p1 -> Sig.c_resizeAs st p0 p1
-- | Resize @t@ in place via the C @resizeNd@ routine, given the dimension
-- count @i@ and per-dimension size and stride lists.
--
-- The size/stride lists are marshalled with 'FM.withArray', so the
-- temporary C arrays are freed when the 'Managed' scope ends. The previous
-- @FM.newArray@ version leaked both arrays on every call.
-- NOTE(review): this assumes the C routine copies the arrays rather than
-- keeping the pointers — confirm against the C API.
resizeNd_
  :: Dynamic
  -> Int32
  -> [Size]
  -> [Stride]
  -> IO ()
resizeNd_ t i l0' l1' = withLift $ Sig.c_resizeNd
  <$> managedState
  <*> managedTensor t
  <*> pure (fromIntegral i)
  <*> managed (FM.withArray (coerce l0' :: [CLLong]))
  <*> managed (FM.withArray (coerce l1' :: [CLLong]))
-- | Call the C @retain@ routine on @t@ (presumably bumps its reference
-- count — NOTE(review): confirm).
retain :: Dynamic -> IO ()
retain t = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure (Sig.c_retain st ptr)
-- | Call the C @select@ routine with @t0@, @t1@, a dimension @a@ and an index @b@.
_select
  :: Dynamic
  -> Dynamic
  -> Word
  -> Word
  -> IO ()
_select t0 t1 a b = with2DynamicState t0 t1 go
  where
    go st p0 p1 = Sig.c_select st p0 p1 (fromIntegral a) (fromIntegral b)
{-# WARNING _select "hasktorch devs have not yet made this safe. You are warned." #-}
-- | Call the C @set@ routine with @t0@ and @t1@ (in that order).
_set
  :: Dynamic
  -> Dynamic
  -> IO ()
_set t0 t1 = with2DynamicState t0 t1 (\st p0 p1 -> Sig.c_set st p0 p1)
{-# WARNING _set "hasktorch devs have not yet made this safe. You are warned." #-}
-- | Write @v@ at index @l0@ of a 1-d tensor (C @set1d@).
set1d_
  :: Dynamic
  -> Word
  -> HsReal
  -> IO ()
set1d_ t l0 v = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_set1d st ptr (fromIntegral l0) (hs2cReal v)

-- | Write @v@ at @(l0, l1)@ of a 2-d tensor (C @set2d@).
set2d_
  :: Dynamic
  -> Word
  -> Word
  -> HsReal
  -> IO ()
set2d_ t l0 l1 v = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_set2d st ptr (fromIntegral l0) (fromIntegral l1) (hs2cReal v)

-- | Write @v@ at @(l0, l1, l2)@ of a 3-d tensor (C @set3d@).
set3d_
  :: Dynamic
  -> Word
  -> Word
  -> Word
  -> HsReal
  -> IO ()
set3d_ t l0 l1 l2 v = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_set3d st ptr
    (fromIntegral l0) (fromIntegral l1) (fromIntegral l2) (hs2cReal v)

-- | Write @v@ at @(l0, l1, l2, l3)@ of a 4-d tensor (C @set4d@).
set4d_
  :: Dynamic
  -> Word
  -> Word
  -> Word
  -> Word
  -> HsReal
  -> IO ()
set4d_ t l0 l1 l2 l3 v = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_set4d st ptr
    (fromIntegral l0) (fromIntegral l1) (fromIntegral l2) (fromIntegral l3) (hs2cReal v)
-- | Call the C @setFlag@ routine on @t@ with flag value @l0@.
setFlag_ :: Dynamic -> Int8 -> IO ()
setFlag_ t l0 = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  pure $ Sig.c_setFlag st ptr (CChar l0)
-- | Point @t@ at storage @s@ from offset @a@, with two index storages passed
-- to the C @setStorage@ routine.
setStorage_ :: Dynamic -> Storage -> StorageOffset -> TH.IndexStorage -> TH.IndexStorage -> IO ()
setStorage_ t s a b c = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  sp <- managed (withForeignPtr (Sig.cstorage s))
  let offset = fromIntegral a
  pb <- managed (withCPUIxStorage b)
  pc <- managed (withCPUIxStorage c)
  pure (Sig.c_setStorage st ptr sp offset pb pc)
{-# WARNING setStorage_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}
-- | Point @t@ at storage @s@ from offset @pd@ as a 1-d view.
setStorage1d_ :: Dynamic -> Storage -> StorageOffset -> (Size, Stride) -> IO ()
setStorage1d_ t s pd (d00,d01) = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  sp <- managed (withForeignPtr (Sig.cstorage s))
  pure $ Sig.c_setStorage1d st ptr sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
{-# WARNING setStorage1d_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}

-- | Point @t@ at storage @s@ from offset @pd@ as a 2-d view.
setStorage2d_ :: Dynamic -> Storage -> StorageOffset -> (Size, Stride) -> (Size, Stride) -> IO ()
setStorage2d_ t s pd (d00,d01) (d10,d11) = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  sp <- managed (withForeignPtr (Sig.cstorage s))
  pure $ Sig.c_setStorage2d st ptr sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
{-# WARNING setStorage2d_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}

-- | Point @t@ at storage @s@ from offset @pd@ as a 3-d view.
setStorage3d_ :: Dynamic -> Storage -> StorageOffset -> (Size, Stride) -> (Size, Stride) -> (Size, Stride) -> IO ()
setStorage3d_ t s pd (d00,d01) (d10,d11) (d20,d21) = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  sp <- managed (withForeignPtr (Sig.cstorage s))
  pure $ Sig.c_setStorage3d st ptr sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
    (fromIntegral d20) (fromIntegral d21)
{-# WARNING setStorage3d_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}

-- | Point @t@ at storage @s@ from offset @pd@ as a 4-d view.
setStorage4d_ :: Dynamic -> Storage -> StorageOffset -> (Size, Stride) -> (Size, Stride) -> (Size, Stride) -> (Size, Stride) -> IO ()
setStorage4d_ t s pd (d00,d01) (d10,d11) (d20,d21) (d30,d31) = withLift $ do
  st <- managedState
  ptr <- managedTensor t
  sp <- managed (withForeignPtr (Sig.cstorage s))
  pure $ Sig.c_setStorage4d st ptr sp (fromIntegral pd)
    (fromIntegral d00) (fromIntegral d01)
    (fromIntegral d10) (fromIntegral d11)
    (fromIntegral d20) (fromIntegral d21)
    (fromIntegral d30) (fromIntegral d31)
{-# WARNING setStorage4d_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}
-- | Point @t@ at storage @s@ from offset @a@ via the C @setStorageNd@
-- routine, given a dimension count @b@ and per-dimension size and stride lists.
--
-- The size/stride lists are marshalled with 'FM.withArray' so that the
-- temporary C arrays are released when the 'Managed' scope ends. The
-- previous @FM.newArray@ version leaked both arrays on every call.
-- NOTE(review): this assumes the C routine copies the arrays rather than
-- keeping the pointers — confirm against the C API.
setStorageNd_
  :: Dynamic
  -> Storage
  -> StorageOffset
  -> Word
  -> [Size]
  -> [Stride]
  -> IO ()
setStorageNd_ t s a b hsc hsd = withLift $ Sig.c_setStorageNd
  <$> managedState
  <*> managedTensor t
  <*> managed (withForeignPtr (Sig.cstorage s))
  <*> pure (fromIntegral a)
  <*> pure (fromIntegral b)
  <*> managed (FM.withArray (coerce hsc :: [CLLong]))
  <*> managed (FM.withArray (coerce hsd :: [CLLong]))
{-# WARNING setStorageNd_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}
-- | Size of dimension @d@ of @t@, as reported by the C @size@ routine.
--
-- Marked NOINLINE like every other 'unsafeDupablePerformIO' query in this
-- module. Without it, the call can be inlined into another module and shared
-- across an in-place resize of the same tensor.
size
  :: Dynamic
  -> Word
  -> Word
size t d = unsafeDupablePerformIO . flip with (pure . fromIntegral) . (liftIO =<<) $ Sig.c_size
  <$> managedState
  <*> managedTensor t
  <*> pure (fromIntegral d)
{-# NOINLINE size #-}
-- | Human-readable size description from the C @sizeDesc@ routine.
sizeDesc :: Dynamic -> IO DescBuff
sizeDesc t = with query Sig.descBuff
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_sizeDesc st ptr)
-- | Call the C @squeeze@ routine.
--
-- NOTE(review): the pointers are passed as (@t1@, @t0@), the reverse of the
-- (@t0@, @t1@) order used by '_narrow', '_transpose' and '_unsqueeze1d'.
-- Confirm which argument is the destination before relying on it.
_squeeze :: Dynamic -> Dynamic -> IO ()
_squeeze t0 t1 = withLift $ do
  st <- managedState
  p1 <- managedTensor t1
  p0 <- managedTensor t0
  pure (Sig.c_squeeze st p1 p0)
-- | In-place variant of '_squeeze1d' that uses @t@ for both tensor arguments.
squeeze1d_
  :: Dynamic
  -> Word
  -> IO ()
squeeze1d_ t = _squeeze1d t t

-- | Call the C @squeeze1d@ routine on dimension @d@.
--
-- NOTE(review): as in '_squeeze', pointers are passed as (@t1@, @t0@).
_squeeze1d
  :: Dynamic
  -> Dynamic
  -> Word
  -> IO ()
_squeeze1d t0 t1 d = withLift $ do
  st <- managedState
  p1 <- managedTensor t1
  p0 <- managedTensor t0
  pure $ Sig.c_squeeze1d st p1 p0 (fromIntegral d)
-- | The storage object returned by the C @storage@ routine for @t@.
storage :: Dynamic -> Storage
storage t = unsafeDupablePerformIO . withStorage $ do
  st <- managedState
  ptr <- managedTensor t
  pure (Sig.c_storage st ptr)
{-# NOINLINE storage #-}
{-# WARNING storage "extracting and using a tensor's storage can make your program unsafe. You are warned." #-}
-- | Offset of @t@ into its storage, as reported by the C @storageOffset@ routine.
storageOffset :: Dynamic -> StorageOffset
storageOffset t = fromIntegral (unsafeDupablePerformIO (withLift query))
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      pure (Sig.c_storageOffset st ptr)
{-# NOINLINE storageOffset #-}
-- | Stride of dimension @a@ of @t@, as reported by the C @stride@ routine.
stride
  :: Dynamic
  -> Word
  -> IO Stride
stride t a = with query (pure . fromIntegral)
  where
    query = do
      st <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_stride st ptr (fromIntegral a))
-- | Call the C @transpose@ routine with @t0@, @t1@ and dimensions @a@, @b@.
_transpose
  :: Dynamic
  -> Dynamic
  -> Word
  -> Word
  -> IO ()
_transpose t0 t1 a b = withLift $ do
  st <- managedState
  p0 <- managedTensor t0
  p1 <- managedTensor t1
  pure $ Sig.c_transpose st p0 p1 (fromIntegral a) (fromIntegral b)
-- | Call the C @unfold@ routine with @t0@, @t1@, dimension @a@, size @b@ and step @c@.
_unfold
  :: Dynamic
  -> Dynamic
  -> Word
  -> Size
  -> Step
  -> IO ()
_unfold t0 t1 a b c = with2DynamicState t0 t1 go
  where
    go st p0 p1 = Sig.c_unfold st p0 p1 (fromIntegral a) (fromIntegral b) (fromIntegral c)
-- | In-place variant of '_unsqueeze1d' that uses @t@ for both tensor arguments.
unsqueeze1d_
  :: Dynamic
  -> Word
  -> IO ()
unsqueeze1d_ t d = _unsqueeze1d t t d

-- | Call the C @unsqueeze1d@ routine with @t0@, @t1@ and dimension @d@.
_unsqueeze1d
  :: Dynamic
  -> Dynamic
  -> Word
  -> IO ()
_unsqueeze1d t0 t1 d = withLift $ do
  st <- managedState
  p0 <- managedTensor t0
  p1 <- managedTensor t1
  pure $ Sig.c_unsqueeze1d st p0 p1 (fromIntegral d)
-- | Size of every dimension of @t@, outermost first; @[]@ for a 0-dimensional tensor.
shape :: Dynamic -> [Word]
shape t
  | n == 0 = []
  | otherwise = map (size t) [0 .. n - 1]
  where
    n = nDimension t
-- | Dispatch to 'setStorage1d_' .. 'setStorage4d_' based on how many
-- (size, stride) pairs are given. Throws for zero pairs or more than four.
setStorageDim_ :: Dynamic -> Storage -> StorageOffset -> [(Size, Stride)] -> IO ()
setStorageDim_ t s o dims = case dims of
  [] -> throwNE "can't setStorage on an empty dimension."
  [x] -> setStorage1d_ t s o x
  [x, y] -> setStorage2d_ t s o x y
  [x, y, z] -> setStorage3d_ t s o x y z
  [x, y, z, q] -> setStorage4d_ t s o x y z q
  _ -> throwGT4 "setStorage"
{-# WARNING setStorageDim_ "mutating a tensor's storage can make your program unsafe. You are warned." #-}
-- | Write @v@ at the position given by a type-level dimension list of
-- length 1 to 4, dispatching to 'set1d_' .. 'set4d_'. Throws for an empty
-- list or more than four indices.
--
-- The previous version slept for 1 ms (@threadDelay 1000@) before every
-- write, with no stated reason. That made element-wise initialisation very
-- slow, so the sleep has been removed.
setDim_ :: Dynamic -> Dims (d::[Nat]) -> HsReal -> IO ()
setDim_ t d !v = case fromIntegral <$> listDims d of
  [] -> throwNE "can't set on an empty dimension."
  [x] -> set1d_ t x v
  [x, y] -> set2d_ t x y v
  [x, y, z] -> set3d_ t x y z v
  [x, y, z, q] -> set4d_ t x y z q v
  _ -> throwGT4 "set"
-- | Resize @t@ in place to the shape given by a type-level dimension list of
-- length 1 to 5. Throws for an empty list or more than five dimensions.
resizeDim_ :: Dynamic -> Dims (d::[Nat]) -> IO ()
resizeDim_ t d = dispatch (fromIntegral <$> listDims d)
  where
    dispatch = \case
      [] -> throwNE "can't resize to an empty dimension."
      [x] -> resize1d_ t x
      [x, y] -> resize2d_ t x y
      [x, y, z] -> resize3d_ t x y z
      [x, y, z, q] -> resize4d_ t x y z q
      [x, y, z, q, w] -> resize5d_ t x y z q w
      _ -> throwFIXME "this should be doable with resizeNd" "resizeDim"
-- | Build a 1-d tensor over a fresh storage holding the list's elements.
-- Never fails; the 'ExceptT' wrapper is kept for interface symmetry.
vectorEIO :: [HsReal] -> ExceptT String IO Dynamic
vectorEIO xs = pure (newWithStorage1d (fromList xs) 0 (genericLength xs, 1))

-- | 'vectorEIO' run via 'unsafePerformIO'.
vectorE :: [HsReal] -> Either String Dynamic
vectorE xs = unsafePerformIO (runExceptT (vectorEIO xs))
{-# NOINLINE vectorE #-}

-- | 'vectorE' with the error message discarded.
vector :: [HsReal] -> Maybe Dynamic
vector xs = case vectorE xs of
  Left _ -> Nothing
  Right t -> Just t
-- | Build a 2-d tensor from a list of rows, stored row-major (row stride =
-- number of columns, column stride = 1).
--
-- Returns the shared 'empty' tensor for @[]@. Returns 'Left' when the rows
-- have differing lengths. The unused @go@ helper and the partial 'head'
-- have been removed; the first row is taken by pattern match instead.
matrix :: [[HsReal]] -> ExceptT String IO Dynamic
matrix [] = lift (pure empty)
matrix ls@(row0 : _)
  | any ((ncols /=) . length) ls = ExceptT . pure $ Left "rows are not all the same length"
  | otherwise = lift . pure $
      newWithStorage2d (fromList (concat ls)) 0 (nrows, ncols) (ncols, 1)
  where
    ncols :: Integral i => i
    ncols = genericLength row0
    nrows :: Integral i => i
    nrows = genericLength ls
{-# NOINLINE cuboid #-}
-- | Build a 3-d tensor from a nested list indexed as @ls !! row !! col !! depth@,
-- stored row-major over a fresh storage.
--
-- Returns the shared 'empty' tensor for @[]@, @[[]]@ and @[[[]]]@. Returns
-- 'Left' when any other level is empty or when the nesting is ragged.
-- Every inner list is length-checked, not just those in the first slice, so
-- the storage always holds exactly @nrows * ncols * ndepth@ elements.
-- Previously, a ragged later slice made the storage shorter than the sizes
-- passed to 'newWithStorage3d'.
cuboid :: [[[HsReal]]] -> ExceptT String IO Dynamic
cuboid ls
  | isEmpty ls = lift (pure empty)
  | null ls || any null ls || any (any null) ls
    = ExceptT . pure . Left $ "can't accept empty lists"
  | innerDimCheck ncols ls = ExceptT . pure . Left $ "rows are not all the same length"
  | any (innerDimCheck ndepth) ls = ExceptT . pure . Left $ "columns are not all the same length"
  | otherwise = lift . pure $
      newWithStorage3d (fromList l) 0 (nrows, ncols * ndepth) (ncols, ndepth) (ndepth, 1)
  where
    l = concat (concat ls)
    isEmpty = \case
      [] -> True
      [[]] -> True
      [[[]]] -> True
      _ -> False
    -- True when some element of the list does not have length @d@.
    innerDimCheck :: Int -> [[x]] -> Bool
    innerDimCheck d = any ((/= d) . length)
    ndepth :: Integral i => i
    ndepth = genericLength (head (head ls))
    ncols :: Integral i => i
    ncols = genericLength (head ls)
    nrows :: Integral i => i
    nrows = genericLength ls
{-# NOINLINE hyper #-}
-- | Build a 4-d tensor from a nested list indexed as
-- @ls !! row !! col !! depth !! time@, stored row-major over a fresh storage.
--
-- Returns the shared 'empty' tensor for the fully-empty shapes matched by
-- @isEmpty@. Returns 'Left' when any other level is empty or when the
-- nesting is ragged.
-- Every sub-list at every level is length-checked. Previously only the
-- first cube / first slice was inspected, so ragged inputs produced a
-- storage shorter than the sizes passed to 'newWithStorage4d'.
hyper :: [[[[HsReal]]]] -> ExceptT String IO Dynamic
hyper ls
  | isEmpty ls = lift (pure empty)
  | null ls
    || any null ls
    || any (any null) ls
    || any (any (any null)) ls = ExceptT . pure . Left $ "can't accept empty lists"
  | any (any (innerDimCheck ntime)) ls = ExceptT . pure . Left $ "rows are not all the same length"
  | any (innerDimCheck ndepth) ls = ExceptT . pure . Left $ "cols are not all the same length"
  | innerDimCheck ncols ls = ExceptT . pure . Left $ "depths are not all the same length"
  | otherwise = lift . pure $ newWithStorage4d (fromList l) 0
      (nrows, ncols * ndepth * ntime)
      (ncols, ndepth * ntime)
      (ndepth, ntime)
      (ntime, 1)
  where
    l = concat (concat (concat ls))
    isEmpty = \case
      [] -> True
      [[]] -> True
      [[[]]] -> True
      [[[[]]]] -> True
      _ -> False
    -- True when some element of the list does not have length @d@.
    innerDimCheck :: Int -> [[x]] -> Bool
    innerDimCheck d = any ((/= d) . length)
    ntime :: Integral i => i
    ntime = genericLength (head (head (head ls)))
    ndepth :: Integral i => i
    ndepth = genericLength (head (head ls))
    ncols :: Integral i => i
    ncols = genericLength (head ls)
    nrows :: Integral i => i
    nrows = genericLength ls
-- | Size of every dimension of @t@, converted to any integral type.
--
-- A 0-dimensional tensor yields @[]@. The old form @[0 .. nDimension t - 1]@
-- underflowed 'Word' to 'maxBound' in that case, producing an enormous range
-- of out-of-range 'size' calls. This was reached, for example, by 'show' on
-- an empty tensor.
getDimsList :: Integral i => Dynamic -> [i]
getDimsList t = case nDimension t of
  0 -> []
  n -> map (fromIntegral . size t) [0 .. n - 1]
-- | The tensor's dimensions wrapped as an existential 'SomeDims'.
getSomeDims :: Dynamic -> SomeDims
getSomeDims t = someDimsVal (getDimsList t)
-- | Allocate a tensor with the given type-level dimensions.
--
-- 1 to 4 dimensions use 'newWithSize1d' .. 'newWithSize4d'. An empty
-- dimension list returns the shared 'empty' tensor.
--
-- For more dimensions, a fresh tensor is allocated and then resized with
-- 'resizeDim_'. Previously this branch resized 'empty' itself. 'empty' is a
-- top-level NOINLINE value shared by the whole program, so that corrupted
-- every other user of it.
new :: Dims (d::[Nat]) -> Dynamic
new d = case fromIntegral <$> listDims d of
  [] -> empty
  [x] -> newWithSize1d x
  [x, y] -> newWithSize2d x y
  [x, y, z] -> newWithSize3d x y z
  [x, y, z, q] -> newWithSize4d x y z q
  _ -> unsafeDupablePerformIO $ do
    t <- withDynamic (Sig.c_new <$> managedState)
    resizeDim_ t d
    pure t
{-# NOINLINE new #-}
-- | 'setDim_' for existentially-wrapped dimensions.
setDim'_ :: Dynamic -> SomeDims -> HsReal -> IO ()
setDim'_ t sd v = case sd of
  SomeDims d -> setDim_ t d v

-- | 'resizeDim_' for existentially-wrapped dimensions.
resizeDim'_ :: Dynamic -> SomeDims -> IO ()
resizeDim'_ t sd = case sd of
  SomeDims d -> resizeDim_ t d

-- | 'new' for existentially-wrapped dimensions.
new' :: SomeDims -> Dynamic
new' sd = case sd of
  SomeDims d -> new d
-- | Clone @src@ with 'newClone', resize the clone to match @target@ via
-- 'resizeAs_', and return the clone. @src@ itself is not passed to 'resizeAs_'.
resizeAs
  :: Dynamic
  -> Dynamic
  -> IO Dynamic
resizeAs src target = resizeAs_ copy target >> pure copy
  where
    copy = newClone src
withInplace :: (Dynamic -> IO ()) -> Dims (d::[Nat]) -> IO Dynamic
withInplace op d =
let
r = new d
in op r >> pure r
withInplace' :: (Dynamic -> IO ()) -> SomeDims -> IO Dynamic
withInplace' op (SomeDims d) = withInplace op d
-- | Run a two-argument operation with @t@ in both positions, then return @t@.
twice :: Dynamic -> (Dynamic -> Dynamic -> IO ()) -> IO Dynamic
twice t op = do
  op t t
  pure t
withEmpty' :: (Dynamic -> IO ()) -> IO Dynamic
withEmpty' op = let r = empty in op r >> pure r
-- | List conversion. 'fromList' builds a 1-d tensor over a fresh storage;
-- 'toList' reads every element via 'tensordata'.
instance IsList Dynamic where
  type Item Dynamic = HsReal
  fromList xs = newWithStorage1d (fromList xs) 0 (genericLength xs, 1)
  toList t = tensordata t
-- | Render the elements (via 'showTensor', using the partial @unsafeGetNd@
-- accessors), then a newline, then the 'describeTensor' summary of the
-- dimensions and element type.
instance Show Dynamic where
  show t = intercalate "\n" [rendered, summary]
    where
      dims = getDimsList t
      rendered = showTensor
        (unsafeGet1d t)
        (unsafeGet2d t)
        (unsafeGet3d t)
        (unsafeGet4d t)
        dims
      summary = describeTensor dims (Proxy @HsReal)