{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE CPP #-}
#if MIN_VERSION_base(4,12,0)
{-# LANGUAGE NoStarIsType #-}
#endif
{-# OPTIONS_GHC -fno-cse -Wno-deprecations #-}
module Torch.Indef.Static.Tensor where
import Control.Exception.Safe
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Coerce
import Data.Maybe
import Data.List
import Data.Singletons.Prelude.List hiding (All, type (++))
import Data.Proxy
import GHC.Natural
import System.IO.Unsafe
import GHC.TypeLits
import Numeric.Dimensions
import Control.Monad.Trans.Except
import Torch.Indef.Types
import Torch.Indef.Index
import Torch.Indef.Static.Tensor.Copy
import qualified Torch.Indef.Dynamic.Tensor as Dynamic
import qualified Torch.Types.TH as TH
import qualified Torch.FFI.TH.Long.Storage as TH
import qualified Torch.Sig.Types as Sig
-- | Static tensors are shown exactly as their untyped ('asDynamic')
-- representation is shown.
instance Show (Tensor d) where
  show = show . asDynamic
-- | Build a one-element tensor holding the given value.
--
-- The singleton list always has length 1, so the length check performed by
-- 'unsafeVector' (via 'vector') cannot fail here.
scalar :: HsReal -> Tensor '[1]
scalar x = unsafeDupablePerformIO (unsafeVector [x])
{-# NOINLINE scalar #-}
-- | Build a rank-1 tensor whose static length is @n@ from a list.
--
-- Fails in 'ExceptT' when the list length differs from @n@.
vector :: forall n . KnownDim n => KnownNat n => [HsReal] -> ExceptT String IO (Tensor '[n])
vector rs =
  if lengthMatches
    then asStatic <$> Dynamic.vectorEIO rs
    else ExceptT (pure (Left "Vector dimension does not match length of list"))
  where
    lengthMatches = genericLength rs == dimVal (dim :: Dim n)
-- | Like 'vector', but a length mismatch becomes an 'error'.
--
-- The 'error' is placed lazily in the result (the mapping is done with
-- '<$>'), so the IO action itself completes and the exception surfaces only
-- when the returned tensor is forced.
unsafeVector :: (KnownDim n, KnownNat n) => [HsReal] -> IO (Tensor '[n])
unsafeVector rs = either error id <$> runExceptT (vector rs)
-- | Forward to 'Dynamic.newExpand' with the given index storage (presumably
-- the target sizes — TODO confirm) and re-tag the result with a
-- caller-chosen static shape @d'@.
--
-- NOTE(review): @d'@ is not checked against the storage contents here.
newExpand :: Tensor d -> TH.IndexStorage -> Tensor d'
newExpand t ixs = asStatic (Dynamic.newExpand (asDynamic t) ixs)
-- Untyped pass-through wrappers: each unwraps its static tensor argument(s)
-- with 'asDynamic' and forwards every other argument unchanged to the
-- same-named function in "Torch.Indef.Dynamic.Tensor".

-- | Forward to 'Dynamic._expand' on the two unwrapped tensors.
_expand r t = Dynamic._expand (asDynamic r) (asDynamic t)
-- | Forward to 'Dynamic._expandNd', unwrapping every tensor in both
-- containers with 'fmap'.
_expandNd rs os = Dynamic._expandNd (fmap asDynamic rs) (fmap asDynamic os)

-- The resize wrappers below run the dynamic in-place resize on the argument
-- and then return that same tensor re-wrapped with 'asStatic . asDynamic',
-- so the result's static shape is whatever the caller's type demands.
-- NOTE(review): that static shape is not checked against the sizes passed
-- to the resize; confirm callers keep them in agreement.
_resize t a b = Dynamic._resize (asDynamic t) a b >> pure ((asStatic . asDynamic) t)
resize1d_ t a = Dynamic.resize1d_ (asDynamic t) a >> pure ((asStatic . asDynamic) t)
resize2d_ t a b = Dynamic.resize2d_ (asDynamic t) a b >> pure ((asStatic . asDynamic) t)
resize3d_ t a b c = Dynamic.resize3d_ (asDynamic t) a b c >> pure ((asStatic . asDynamic) t)
resize4d_ t a b c d = Dynamic.resize4d_ (asDynamic t) a b c d >> pure ((asStatic . asDynamic) t)
resize5d_ t a b c d e = Dynamic.resize5d_ (asDynamic t) a b c d e >> pure ((asStatic . asDynamic) t)
-- | Resize @src@ like @tar@ (via 'Dynamic.resizeAs_') and return @src@
-- re-tagged with a new static shape.
resizeAsT_ src tar = Dynamic.resizeAs_ (asDynamic src) (asDynamic tar) >> pure ((asStatic . asDynamic) src)
resizeNd_ src a b c = Dynamic.resizeNd_ (asDynamic src) a b c >> pure ((asStatic . asDynamic) src)
-- Untyped pass-through wrappers: each unwraps its static tensor argument(s)
-- with 'asDynamic' and forwards the rest unchanged to the same-named
-- function in "Torch.Indef.Dynamic.Tensor".
retain t = Dynamic.retain (asDynamic t)
_clearFlag t = Dynamic._clearFlag (asDynamic t)
-- 'tensordata' is only defined when HASKTORCH_CORE_CUDA is not set.
#ifndef HASKTORCH_CORE_CUDA
tensordata t = Dynamic.tensordata (asDynamic t)
#endif
-- Element getters; index arguments are passed through untouched.
get1d t = Dynamic.get1d (asDynamic t)
get2d t = Dynamic.get2d (asDynamic t)
get3d t = Dynamic.get3d (asDynamic t)
get4d t = Dynamic.get4d (asDynamic t)
isContiguous t = Dynamic.isContiguous (asDynamic t)
isSetTo t0 t1 = Dynamic.isSetTo (asDynamic t0) (asDynamic t1)
isSize t = Dynamic.isSize (asDynamic t)
nDimension t = Dynamic.nDimension (asDynamic t)
nElement t = Dynamic.nElement (asDynamic t)
_narrow t0 t1 = Dynamic._narrow (asDynamic t0) (asDynamic t1)
-- | 'Dynamic.empty' re-tagged with a caller-chosen static shape.
-- NOTE(review): the static shape is not checked against the (presumably
-- empty) runtime tensor.
empty = asStatic Dynamic.empty
-- | Clone a tensor via 'Dynamic.newClone', keeping its static shape.
--
-- The pure wrappers in this module (e.g. 'squeeze1d', 'resizeAs') clone
-- their input with this before running an in-place operation on it.
newClone :: Tensor d -> Tensor d
newClone = asStatic . Dynamic.newClone . asDynamic
-- | Forward to 'Dynamic.newContiguous' and re-tag the (pure) result.
newContiguous t = asStatic $ Dynamic.newContiguous (asDynamic t)
-- | Forward to 'Dynamic.newNarrow' and re-tag the result inside 'IO'.
-- NOTE(review): the result's static shape is chosen by the caller and is
-- not checked against the narrow arguments.
newNarrow t a b c = asStatic <$> Dynamic.newNarrow (asDynamic t) a b c
-- | Select a slice of a tensor with 'Dynamic.newSelect'.
--
-- The 'Dim' is passed as the first integer argument and the 'Idx' (via
-- 'fromEnum') as the second. The static result type drops the dimension at
-- position @i@ (@ls ++ rs@).
newSelect
  :: KnownDim i
  => '(ls, r:+rs) ~ SplitAt i d
  => Tensor d
  -> (Dim i, Idx i)
  -> IO (Tensor (ls ++ rs))
newSelect t (d, i) = asStatic <$> Dynamic.newSelect (asDynamic t) dimArg sliceArg
  where
    dimArg = fromIntegral (dimVal d)
    sliceArg = fromIntegral (fromEnum i)
-- Size/stride queries forward straight to the dynamic API (results are not
-- re-tagged); transpose/unfold re-tag their pure result with 'asStatic'.
-- NOTE(review): the static shape of the re-tagged results is chosen by the
-- caller and is not checked here.
newSizeOf t = Dynamic.newSizeOf (asDynamic t)
newStrideOf t = Dynamic.newStrideOf (asDynamic t)
newTranspose t a b = asStatic $ Dynamic.newTranspose (asDynamic t) a b
newUnfold t a b c = asStatic $ Dynamic.newUnfold (asDynamic t) a b c
-- | Reinterpret a tensor's data under the target static shape @d'@.
--
-- The sizes handed to 'Dynamic.newView' come from the /target/ shape @d'@,
-- so the returned tensor's runtime sizes agree with its static type. Using
-- the source shape @d@ would produce a tensor typed as @d'@ but shaped as
-- @d@ at runtime.
--
-- NOTE(review): no @Product d ~ Product d'@ constraint is imposed, so an
-- element-count mismatch is left to the backend to detect at runtime.
view :: forall d d' . (Dimensions d, Dimensions d') => Tensor d -> IO (Tensor d')
view src = do
  longs <- ixCPUStorage $ fromIntegral <$> listDims (dims :: Dims d')
  asStatic <$> Dynamic.newView (asDynamic src) longs
-- Constructors: forward every argument to the same-named function in
-- "Torch.Indef.Dynamic.Tensor" and re-tag the pure result with 'asStatic'.
-- NOTE(review): the static shape is chosen by the caller's type and is not
-- checked against the size/storage arguments.
newWithSize a0 a1 = asStatic $ Dynamic.newWithSize a0 a1
newWithSize1d a0 = asStatic $ Dynamic.newWithSize1d a0
newWithSize2d a0 a1 = asStatic $ Dynamic.newWithSize2d a0 a1
newWithSize3d a0 a1 a2 = asStatic $ Dynamic.newWithSize3d a0 a1 a2
newWithSize4d a0 a1 a2 a3 = asStatic $ Dynamic.newWithSize4d a0 a1 a2 a3
newWithStorage a0 a1 a2 a3 = asStatic $ Dynamic.newWithStorage a0 a1 a2 a3
newWithStorage1d a0 a1 a2 = asStatic $ Dynamic.newWithStorage1d a0 a1 a2
newWithStorage2d a0 a1 a2 a3 = asStatic $ Dynamic.newWithStorage2d a0 a1 a2 a3
newWithStorage3d a0 a1 a2 a3 a4 = asStatic $ Dynamic.newWithStorage3d a0 a1 a2 a3 a4
newWithStorage4d a0 a1 a2 a3 a4 a5 = asStatic $ Dynamic.newWithStorage4d a0 a1 a2 a3 a4 a5
-- | Forward to 'Dynamic.newWithTensor' and re-tag the result inside 'IO'.
newWithTensor t = asStatic <$> Dynamic.newWithTensor (asDynamic t)
-- Untyped pass-through wrappers: unwrap the static tensor argument(s) with
-- 'asDynamic' and forward the rest unchanged to "Torch.Indef.Dynamic.Tensor".
_select t0 t1 = Dynamic._select (asDynamic t0) (asDynamic t1)
_set t0 t1 = Dynamic._set (asDynamic t0) (asDynamic t1)
set1d_ t = Dynamic.set1d_ (asDynamic t)
set2d_ t = Dynamic.set2d_ (asDynamic t)
set3d_ t = Dynamic.set3d_ (asDynamic t)
set4d_ t = Dynamic.set4d_ (asDynamic t)
setFlag_ t = Dynamic.setFlag_ (asDynamic t)
setStorage_ t = Dynamic.setStorage_ (asDynamic t)
setStorage1d_ t = Dynamic.setStorage1d_ (asDynamic t)
setStorage2d_ t = Dynamic.setStorage2d_ (asDynamic t)
setStorage3d_ t = Dynamic.setStorage3d_ (asDynamic t)
setStorage4d_ t = Dynamic.setStorage4d_ (asDynamic t)
setStorageNd_ t = Dynamic.setStorageNd_ (asDynamic t)
size t = Dynamic.size (asDynamic t)
sizeDesc t = Dynamic.sizeDesc (asDynamic t)
_squeeze t0 t1 = Dynamic._squeeze (asDynamic t0) (asDynamic t1)
-- | Pure counterpart of 'squeeze1d_': squeeze dimension @n@ (statically
-- known to have size 1) out of a /clone/ of the input, leaving the argument
-- untouched.
squeeze1d
  :: Dimensions d
  => '(rs, 1:+ls) ~ (SplitAt n d)
  => Dim n
  -> Tensor d
  -> Tensor (rs ++ ls)
squeeze1d n t =
  let copy = newClone t
  in unsafeDupablePerformIO (squeeze1d_ n copy)
{-# NOINLINE squeeze1d #-}
-- | In-place squeeze of dimension @n@ (statically known to have size 1) via
-- 'Dynamic.squeeze1d_'. Returns the same underlying tensor under its new
-- static shape.
squeeze1d_
  :: Dimensions d
  => '(rs, 1:+ls) ~ (SplitAt n d)
  => Dim n
  -> Tensor d
  -> IO (Tensor (rs ++ ls))
squeeze1d_ n t =
  Dynamic.squeeze1d_ dyn (fromIntegral (dimVal n)) >> pure (asStatic dyn)
  where
    dyn = asDynamic t
-- Untyped pass-through wrappers: unwrap the static tensor argument(s) with
-- 'asDynamic' and forward the rest unchanged to "Torch.Indef.Dynamic.Tensor".
storage t = Dynamic.storage (asDynamic t)
storageOffset t = Dynamic.storageOffset (asDynamic t)
stride t = Dynamic.stride (asDynamic t)
_transpose t0 t1 = Dynamic._transpose (asDynamic t0) (asDynamic t1)
_unfold t0 t1 = Dynamic._unfold (asDynamic t0) (asDynamic t1)
-- | Pure counterpart of 'unsqueeze1d_': insert a size-1 dimension at
-- position @n@ of a /clone/ of the input, leaving the argument untouched.
unsqueeze1d
  :: Dimensions d
  => '(rs, ls) ~ (SplitAt n d)
  => Dim n
  -> Tensor d
  -> Tensor (rs ++ '[1] ++ ls)
unsqueeze1d n t = unsafeDupablePerformIO action
  where
    action = unsqueeze1d_ n (newClone t)
{-# NOINLINE unsqueeze1d #-}
-- | In-place insertion of a size-1 dimension at position @n@ via
-- 'Dynamic.unsqueeze1d_'. Returns the same underlying tensor under its new
-- static shape.
unsqueeze1d_
  :: Dimensions d
  => '(rs, ls) ~ (SplitAt n d)
  => Dim n
  -> Tensor d
  -> IO (Tensor (rs ++ '[1] ++ ls))
unsqueeze1d_ n t =
  Dynamic.unsqueeze1d_ (asDynamic t) axis >> pure (asStatic (asDynamic t))
  where
    axis = fromIntegral (dimVal n)
-- | Runtime sizes of a tensor, as reported by 'Dynamic.shape'.
shape :: Tensor d -> [Word]
shape = Dynamic.shape . asDynamic
-- | The tensor's runtime 'shape', packaged as an existential 'SomeDims'.
getSomeDims :: Tensor d -> SomeDims
getSomeDims t = someDimsVal (shape t)
-- | Run a binary in-place operation with the tensor as both arguments, then
-- return that tensor.
withInplace :: (Dimensions d) => Tensor d -> (Tensor d -> Tensor d -> IO ()) -> IO (Tensor d)
withInplace t op = do
  op t t
  pure t
{-# DEPRECATED withInplace "this is a trivial function with a bad API" #-}
-- | Throw (via 'throwString') a message with a FIXME note appended as
-- @msg (FIXME: fixme)@.
throwFIXME :: MonadThrow io => String -> String -> io x
throwFIXME fixme msg = throwString (concat [msg, " (FIXME: ", fixme, ")"])

-- | 'throwFIXME' with a fixed note about non-empty shape lists.
throwNE :: MonadThrow io => String -> io x
throwNE msg = throwFIXME "make this function only take a non-empty [Nat]" msg

-- | 'throwFIXME' reporting that the named function was used with a tensor
-- of rank greater than 4.
throwGT4 :: MonadThrow io => String -> io x
throwGT4 fnname = throwFIXME fixmeNote errMsg
  where
    fixmeNote = "review how TH supports `" ++ fnname ++ "` operations on > rank-4 tensors"
    errMsg = fnname ++ " with >4 rank"
-- | Point the tensor at the given storage/offset with the given
-- (size, stride) pairs, via 'Dynamic.setStorageDim_'.
setStorageDim_ :: Tensor d -> Storage -> StorageOffset -> [(Size, Stride)] -> IO ()
setStorageDim_ t storage' offset sizesStrides =
  Dynamic.setStorageDim_ (asDynamic t) storage' offset sizesStrides
-- | Write a value at the position given by a 'Dims' index, via
-- 'Dynamic.setDim_'.
setDim_ :: Tensor d -> Dims (d'::[Nat]) -> HsReal -> IO ()
setDim_ t ix v = Dynamic.setDim_ (asDynamic t) ix v

-- | 'setDim_' for an existentially wrapped index.
setDim'_ :: Tensor d -> SomeDims -> HsReal -> IO ()
setDim'_ t someIx v =
  case someIx of
    SomeDims ix -> setDim_ t ix v
-- | Read the value at the position given by a non-empty 'Dims' index, via
-- 'Dynamic.getDim'; 'Nothing' is whatever that call reports as a failed read.
getDim
  :: forall d i d'
  . All Dimensions '[d, i:+d']
  => Tensor (d::[Nat])
  -> Dims ((i:+d')::[Nat])
  -> Maybe HsReal
getDim t ix = Dynamic.getDim dyn ix
  where
    dyn = asDynamic t
-- | Index into a tensor with a type-level 'Dim'.
--
-- Behaviour depends on the runtime rank ('nDimension'):
--
-- * rank 0: returns 'empty';
-- * rank 1: reads element @dimVal i@ with 'get1d' and copies it into a fresh
--   one-element tensor; falls back to 'empty' if the bounds guard fails or
--   'get1d' returns 'Nothing';
-- * any other rank: calls 'newSelect' with the pair @(i, Idx 1)@.
--
-- NOTE(review): the higher-rank branch always passes @Idx 1@, and the
-- rank-1 branch treats @i@ as an element index rather than a dimension —
-- confirm these semantics are intended.
-- Runs through 'unsafePerformIO' and is marked NOINLINE.
(!!)
  :: forall d ls r rs i
  . '(ls, r:+rs) ~ SplitAt i d
  => KnownDim i
  => Dimensions d
  => Tensor d
  -> Dim i
  -> Tensor (ls ++ rs)
t !! i = unsafePerformIO $
  case nDimension t of
    0 -> pure empty
    1 -> fromMaybe empty <$> runMaybeT selectVal
    _ -> newSelect t (i, Idx 1)
  where
    -- Rank-1 path: bounds-check, read the element, and copy it into a new
    -- single-element tensor. Any failure short-circuits to Nothing.
    selectVal :: MaybeT IO (Tensor (ls ++ rs))
    selectVal = do
      guard (dimVal i < size t (dimVal i))
      v <- MaybeT . pure $ get1d t (fromIntegral $ dimVal i)
      lift $ do
        let r = newWithSize1d 1
        set1d_ r 0 v
        pure r
{-# NOINLINE (!!) #-}
-- | A tensor whose runtime sizes are taken from its static shape @d@, built
-- with 'Dynamic.new'.
new :: forall d . Dimensions d => Tensor d
new = asStatic (Dynamic.new ds)
  where
    ds = dims :: Dims d
-- | Resize a tensor in place to the static shape @d'@ via
-- 'Dynamic.resizeDim_', returning the same underlying tensor re-tagged as
-- @d'@.
_resizeDim :: forall d d' . (Dimensions d') => Tensor d -> IO (Tensor d')
_resizeDim t = Dynamic.resizeDim_ dyn (dims :: Dims d') >> pure (asStatic dyn)
  where
    dyn = asDynamic t
-- | In-place resize of @src@ to the shape @d'@, using a freshly built
-- @'new' :: Tensor d'@ as the size template. The element count is preserved
-- at the type level.
resizeAs_ :: forall d d' . (All Dimensions '[d, d'], Product d ~ Product d') => Tensor d -> IO (Tensor d')
resizeAs_ src = resizeAsT_ src template
  where
    template = new :: Tensor d'
-- | Pure counterpart of 'resizeAs_': resizes a /clone/ of the input to shape
-- @d'@, leaving the argument untouched.
resizeAs :: forall d d' . (All Dimensions [d,d'], Product d ~ Product d') => Tensor d -> Tensor d'
resizeAs src = unsafeDupablePerformIO (resizeAsT_ copy template)
  where
    copy = newClone src :: Tensor d
    template = new :: Tensor d'
{-# NOINLINE resizeAs #-}
-- | Flatten a tensor into a rank-1 tensor of all its elements ('resizeAs'
-- on a clone, so the input is untouched).
flatten :: (Dimensions d, KnownDim (Product d)) => Tensor d -> Tensor '[Product d]
flatten t = resizeAs t
-- | Build a tensor of static shape @d@ from a flat list of elements.
--
-- Returns 'Nothing' when the list length differs from @Product d@. That
-- check is performed by 'vector', so no separate length guard is needed
-- here. The flat vector is then resized in place to @d@.
fromList
  :: forall d . Dimensions d
  => KnownNat (Product d)
  => KnownDim (Product d)
  => [HsReal] -> IO (Maybe (Tensor d))
fromList l = runMaybeT $ do
  -- 'vector' already rejects lists whose length is not Product d.
  vec :: Tensor '[Product d] <-
    MaybeT (either (const Nothing) Just <$> runExceptT (vector l))
  lift (_resizeDim vec)
{-# NOINLINE fromList #-}
-- | Build an @n@-by-@m@ matrix from a list of rows.
--
-- The outer list holds the rows and must have length @n@; every inner list
-- is one row and must have length @m@. Empty, jagged, and wrongly sized
-- inputs are rejected in 'ExceptT' with a message naming the offending
-- length.
matrix
  :: forall n m
  . (All KnownDim '[n, m], All KnownNat '[n, m])
#if MIN_VERSION_singletons(2,4,0)
  => KnownDim (n*m) => KnownNat (n*m)
#else
  => KnownDim (n*:m) => KnownNat (n*:m)
#endif
  => [[HsReal]] -> ExceptT String IO (Tensor '[n, m])
matrix ls
  | null ls = ExceptT . pure . Left $ "no support for empty lists"
  | colLen /= mVal =
    ExceptT . pure . Left $ "length of inner list "++show colLen++" must match type-level columns " ++ show mVal
  | any (/= colLen) (fmap length ls) =
    ExceptT . pure . Left $ "can't build a matrix from jagged lists: " ++ show (fmap length ls)
  | rowLen /= nVal =
    ExceptT . pure . Left $ "outer list length " ++ show rowLen ++ " must match type-level rows " ++ show nVal
  | otherwise = asStatic <$> Dynamic.matrix ls
  where
    -- number of rows: the length of the outer list
    rowLen :: Integral l => l
    rowLen = genericLength ls
    -- number of columns, read from the first row (total; the 'null' guard
    -- above already handles the empty case)
    colLen :: Integral l => l
    colLen = maybe 0 genericLength (listToMaybe ls)
    nVal = dimVal (dim :: Dim n)
    mVal = dimVal (dim :: Dim m)
-- | Like 'matrix', but any validation failure becomes an 'error'.
--
-- The constraints mirror 'matrix' exactly, including its singletons-version
-- switch for the product constraint, so this wrapper compiles on every
-- configuration 'matrix' does. The 'error' sits lazily in the result and
-- surfaces only when the tensor is forced.
unsafeMatrix
  :: forall n m
  . (All KnownDim '[n, m], All KnownNat '[n, m])
#if MIN_VERSION_singletons(2,4,0)
  => KnownDim (n*m) => KnownNat (n*m)
#else
  => KnownDim (n*:m) => KnownNat (n*:m)
#endif
  => [[HsReal]] -> IO (Tensor '[n, m])
unsafeMatrix ls = either error id <$> runExceptT (matrix ls)
-- | Build a @c@-by-@h@-by-@w@ tensor from nested lists
-- (channels, then rows, then columns).
--
-- Rejected in 'ExceptT' when any level is empty, when the number of
-- channels is not @c@, when any channel does not have @h@ rows, or when any
-- row does not have @w@ columns.
cuboid
  :: forall c h w
  . (All KnownDim '[c, h, w], All KnownNat '[c, h, w])
  => [[[HsReal]]] -> ExceptT String IO (Tensor '[c, h, w])
cuboid ls
  | isEmpty = ExceptT . pure . Left $ "no support for empty lists"
  | chan /= length ls =
    ExceptT . pure . Left $ "number of channels " ++ show (length ls) ++ " must match type-level channels " ++ show chan
  | any (/= rows) (lens ls) = ExceptT . pure . Left $ "rows are not all of length " ++ show rows
  | any (/= cols) (lens (concat ls)) = ExceptT . pure . Left $ "columns are not all of length " ++ show cols
  | otherwise = asStatic <$> Dynamic.cuboid ls
  where
    -- empty at any nesting level: no channels, an empty channel, or an empty row
    isEmpty = null ls || any null ls || any (any null) ls
    chan, rows, cols :: Int
    chan = fromIntegral $ dimVal (dim :: Dim c)
    rows = fromIntegral $ dimVal (dim :: Dim h)
    cols = fromIntegral $ dimVal (dim :: Dim w)
    lens = fmap length
-- | Like 'cuboid', but any validation failure becomes an 'error', placed
-- lazily in the result (it surfaces only when the tensor is forced).
unsafeCuboid
  :: forall c h w
  . All KnownDim '[c, h, w]
  => All KnownNat '[c, h, w]
  => [[[HsReal]]] -> IO (Tensor '[c, h, w])
unsafeCuboid xs = either error id <$> runExceptT (cuboid xs)
-- | Swap the two axes of a matrix ('newTranspose' on dimensions 1 and 0).
transpose2d :: (All KnownDim '[r,c]) => Tensor '[r, c] -> Tensor '[c, r]
transpose2d t = newTranspose t d1 d0
  where
    d1 = 1
    d0 = 0
-- | Expand a length-@x@ vector into a @y@-by-@x@ matrix with '_expand'.
--
-- The target sizes are written into a freshly allocated CPU index storage
-- as @[y, x]@. The result tensor is built with 'new' and filled in place.
expand2d
  :: forall x y . (All KnownDim '[x, y])
  => Tensor '[x] -> Tensor '[y, x]
expand2d t = unsafeDupablePerformIO $ do
  sizes <- mkCPUIxStorage =<< TH.c_newWithSize2_ ySize xSize
  let out = new :: Tensor '[y, x]
  _expand out t sizes
  pure out
  where
    xSize = fromIntegral $ dimVal (dim :: Dim x)
    ySize = fromIntegral $ dimVal (dim :: Dim y)
{-# NOINLINE expand2d #-}
-- | Read element @(r, c)@ of an @n@-by-@m@ matrix.
--
-- Indices are 0-based, so valid rows are @0 .. n-1@ and valid columns are
-- @0 .. m-1@. Anything at or past the bound returns 'Nothing' without
-- calling 'get2d'.
getElem2d
  :: forall (n::Nat) (m::Nat) . (All KnownDim '[n, m])
  => Tensor '[n, m] -> Word -> Word -> Maybe (HsReal)
getElem2d t r c
  | r >= rows || c >= cols = Nothing
  | otherwise = get2d t (fromIntegral r) (fromIntegral c)
  where
    rows = fromIntegral (dimVal (dim :: Dim n))
    cols = fromIntegral (dimVal (dim :: Dim m))
{-# DEPRECATED getElem2d "use getDim instead" #-}
-- | Write @v@ at element @(r, c)@ of an @n@-by-@m@ matrix in place.
--
-- Indices are 0-based. Any index at or past its bound (@r >= n@ or
-- @c >= m@) throws "Indices out of bounds" instead of reaching 'set2d_',
-- which prevents a write one element past the end.
setElem2d
  :: forall (n::Nat) (m::Nat) ns . (All KnownDim '[n, m])
  => Tensor '[n, m] -> Word -> Word -> HsReal -> IO ()
setElem2d t r c v
  | r >= rows || c >= cols = throwString "Indices out of bounds"
  | otherwise = set2d_ t (fromIntegral r) (fromIntegral c) v
  where
    rows = fromIntegral (dimVal (dim :: Dim n))
    cols = fromIntegral (dimVal (dim :: Dim m))
{-# DEPRECATED setElem2d "use setDim_ instead" #-}