{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.Math where
import Numeric.Dimensions
import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import System.IO.Unsafe
import Data.Singletons.Prelude (fromSing)
import Data.List.NonEmpty (NonEmpty)
import Data.Either (fromRight)
import qualified Data.Singletons.Prelude.List as Sing hiding (All, type (++))
import qualified Torch.Types.TH as TH
import qualified Torch.Indef.Dynamic.Tensor.Math as Dynamic
-- | Static counterpart of 'Dynamic.fill_'.
--
-- The static tensor is converted with 'asDynamic'; whatever further
-- arguments 'Dynamic.fill_' expects are supplied by the caller unchanged.
fill_ tensor = Dynamic.fill_ $ asDynamic tensor
-- | Static counterpart of 'Dynamic.zero_': hands the 'asDynamic' view of the
-- given tensor to the dynamic implementation.
zero_ tensor = Dynamic.zero_ dyn
  where
    dyn = asDynamic tensor
-- | Static counterpart of 'Dynamic.zeros_'.
--
-- The tensor is converted with 'asDynamic'; the index storage is forwarded
-- to 'Dynamic.zeros_' untouched.
zeros_ :: Tensor d -> IndexStorage -> IO ()
zeros_ tensor = Dynamic.zeros_ (asDynamic tensor)
-- | Static counterpart of 'Dynamic.zerosLike_'.
--
-- Both tensors are converted with 'asDynamic' and passed on in the same
-- order. The two shapes @d@ and @d'@ are not related by this signature.
zerosLike_
  :: Tensor d   -- ^ first tensor handed to 'Dynamic.zerosLike_'
  -> Tensor d'  -- ^ second tensor handed to 'Dynamic.zerosLike_'
  -> IO ()
zerosLike_ x y = Dynamic.zerosLike_ dynX dynY
  where
    dynX = asDynamic x
    dynY = asDynamic y
-- | Static counterpart of 'Dynamic.ones_'.
--
-- The tensor is converted with 'asDynamic'; the index storage is forwarded
-- to 'Dynamic.ones_' untouched.
ones_ :: Tensor d -> TH.IndexStorage -> IO ()
ones_ tensor = Dynamic.ones_ (asDynamic tensor)
-- | Static counterpart of 'Dynamic.onesLike_'.
--
-- Both tensors are converted with 'asDynamic' and passed on in the same
-- order. The two shapes @d@ and @d'@ are not related by this signature.
onesLike_
  :: Tensor d   -- ^ first tensor handed to 'Dynamic.onesLike_'
  -> Tensor d'  -- ^ second tensor handed to 'Dynamic.onesLike_'
  -> IO ()
onesLike_ x y = Dynamic.onesLike_ dynX dynY
  where
    dynX = asDynamic x
    dynY = asDynamic y
-- | Static counterpart of 'Dynamic.numel': applies the dynamic function to
-- the 'asDynamic' view of the tensor.
numel tensor = Dynamic.numel $ asDynamic tensor
-- | Static counterpart of 'Dynamic._reshape'.
--
-- Both tensor arguments are converted with 'asDynamic' and forwarded in the
-- order given; any further arguments go straight to the dynamic function.
_reshape a b = Dynamic._reshape dynA dynB
  where
    dynA = asDynamic a
    dynB = asDynamic b
-- | Static counterpart of 'Dynamic._catArray': only the first argument is
-- converted (with 'asDynamic'); the rest are supplied by the caller as the
-- dynamic function expects them.
_catArray tensor = Dynamic._catArray $ asDynamic tensor
-- | Static counterpart of 'Dynamic._nonzero'.
--
-- The first argument is converted with 'longAsDynamic' and the second with
-- 'asDynamic' before both are handed to the dynamic function.
_nonzero a b = Dynamic._nonzero dynA dynB
  where
    dynA = longAsDynamic a
    dynB = asDynamic b
-- | Static counterpart of 'Dynamic._tril'.
--
-- Both tensor arguments are converted with 'asDynamic'; remaining arguments
-- are passed through to the dynamic function.
_tril a b = Dynamic._tril (asDynamic a) $ asDynamic b
-- | Static counterpart of 'Dynamic._triu'.
--
-- Both tensor arguments are converted with 'asDynamic'; remaining arguments
-- are passed through to the dynamic function.
_triu a b = Dynamic._triu (asDynamic a) $ asDynamic b
-- | Static counterpart of 'Dynamic.eye_': applies the dynamic function to the
-- 'asDynamic' view of the tensor.
eye_ tensor = Dynamic.eye_ $ asDynamic tensor
-- | Static counterpart of 'Dynamic.ttrace': applies the dynamic function to
-- the 'asDynamic' view of the tensor.
ttrace tensor = Dynamic.ttrace dyn
  where
    dyn = asDynamic tensor
-- | Static counterpart of 'Dynamic._arange': the tensor is converted with
-- 'asDynamic'; all other arguments are supplied directly to the dynamic
-- function by the caller.
_arange tensor = Dynamic._arange $ asDynamic tensor
-- | Static counterpart of 'Dynamic.range_': the tensor is converted with
-- 'asDynamic'; all other arguments are supplied directly to the dynamic
-- function by the caller.
range_ tensor = Dynamic.range_ $ asDynamic tensor
-- | Build a statically-shaped tensor from a single scalar.
--
-- The shape is taken from the type-level dimensions @d@ (via 'dims') and
-- passed, together with the scalar, to 'Dynamic.constant'; the result is
-- wrapped with 'asStatic'.
constant :: forall d . Dimensions d => HsReal -> Tensor d
constant value = asStatic (Dynamic.constant (dims :: Dims d) value)
-- | Static counterpart of 'Dynamic.diag_'.
--
-- Runs 'Dynamic.diag_' on the 'asDynamic' view of the input and then returns
-- that same underlying tensor re-wrapped with 'asStatic' at the
-- caller-chosen shape @d'@. Nothing here checks that @d'@ matches the
-- tensor's actual dimensions after the call.
diag_ :: All Dimensions '[d, d'] => Tensor d -> Int -> IO (Tensor d')
diag_ t k =
  Dynamic.diag_ (asDynamic t) k >> pure (asStatic (asDynamic t))
-- | Static counterpart of 'Dynamic.diag'.
--
-- Converts the input with 'asDynamic', forwards it with the 'Int' argument,
-- and wraps the result with 'asStatic' at the caller-chosen shape @d'@.
diag :: All Dimensions '[d, d'] => Tensor d -> Int -> Tensor d'
diag t k = asStatic (Dynamic.diag (asDynamic t) k)
-- | Specialisation of 'diag' that takes a rank-1 tensor of length @n@ to an
-- @n@-by-@n@ tensor, always passing @0@ as the 'Int' argument.
diag1d :: (KnownDim n) => Tensor '[n] -> Tensor '[n, n]
diag1d vec = vec `diag` 0
-- | Static wrapper around 'Dynamic._cat' that reuses the first tensor.
--
-- The 'asDynamic' view of @a@ is passed to 'Dynamic._cat' twice (first and
-- second argument), followed by @b@ and the 'Word' argument. The same
-- underlying tensor @a@ is then returned, re-wrapped with 'asStatic' at the
-- caller-chosen shape @d''@.
cat_
  :: All Dimensions '[d, d', d'']
  => Tensor d
  -> Tensor d'
  -> Word
  -> IO (Tensor d'')
cat_ a b d =
  Dynamic._cat (asDynamic a) (asDynamic a) (asDynamic b) d
    >> pure (asStatic (asDynamic a))
{-# WARNING cat_ "this function is impure and the dimensions can fall out of sync with the type, if used incorrectly" #-}
-- | Type-checked concatenation built on 'Dynamic.cat'.
--
-- The two equality constraints require that splitting @d@ and @d'@ at index
-- @i@ yields the same prefix @ls@ and the same tail @rs@, differing only in
-- the sizes @r0@ and @r1@ at position @i@; the result has @r0 + r1@ there.
--
-- 'Dynamic.cat' returns an 'Either'. A 'Left' is treated as unreachable and
-- raises 'error'.
cat
  :: '(ls, r0:+rs) ~ Sing.SplitAt i d
  => '(ls, r1:+rs) ~ Sing.SplitAt i d'
  => Tensor d
  -> Tensor d'
  -> Dim (i::Nat)
  -> Tensor (ls ++ '[r0 + r1] ++ rs)
cat a b d =
  either
    (const (error "impossible: cat type should not allow this branch"))
    asStatic
    (Dynamic.cat (asDynamic a) (asDynamic b) (fromIntegral (dimVal d)))
-- | 'cat' specialised to two rank-1 tensors, joined along dimension 0.
-- The result length @n@ is constrained to be the sum of @n1@ and @n2@.
cat1d
  :: (All KnownDim '[n1,n2,n], n ~ Sing.Sum [n1, n2])
  => Tensor '[n1]
  -> Tensor '[n2]
  -> Tensor '[n]
cat1d a b = cat a b axis
  where
    axis :: Dim 0
    axis = dim
-- | 'cat' specialised to two rank-2 tensors, joined along dimension 0.
-- The second dimension @m@ is shared; @n@ is the sum of @n0@ and @n1@.
cat2d0
  :: (All KnownDim '[n,m,n0,n1], n ~ Sing.Sum [n0, n1])
  => Tensor '[n0, m]
  -> Tensor '[n1, m]
  -> Tensor '[n, m]
cat2d0 a b = cat a b axis
  where
    axis :: Dim 0
    axis = dim
-- | Combine two length-@m@ vectors into a @2 x m@ tensor.
--
-- Each input is passed through 'unsqueeze1d' at dimension 0 and the two
-- results are joined with 'cat2d0'.
stack1d0 :: KnownDim m => Tensor '[m] -> Tensor '[m] -> (Tensor '[2, m])
stack1d0 a b = cat2d0 a' b'
  where
    a' = unsqueeze1d (dim :: Dim 0) a
    b' = unsqueeze1d (dim :: Dim 0) b
-- | 'cat' specialised to two rank-2 tensors, joined along dimension 1.
-- The first dimension @n@ is shared; @m@ is the sum of @m0@ and @m1@.
cat2d1
  :: (All KnownDim '[n,m,m0,m1], m ~ Sing.Sum [m0, m1])
  => Tensor '[n, m0]
  -> Tensor '[n, m1]
  -> (Tensor '[n, m])
cat2d1 a b = cat a b axis
  where
    axis :: Dim 1
    axis = dim
-- | Combine two length-@n@ vectors into an @n x 2@ tensor.
--
-- Each input is passed through 'unsqueeze1d' at dimension 1 and the two
-- results are joined with 'cat2d1'.
stack1d1 :: KnownDim n => Tensor '[n] -> Tensor '[n] -> (Tensor '[n, 2])
stack1d1 a b = cat2d1 a' b'
  where
    a' = unsqueeze1d (dim :: Dim 1) a
    b' = unsqueeze1d (dim :: Dim 1) b
-- | 'cat' specialised to two rank-3 tensors, joined along dimension 0.
-- Dimensions @y@ and @z@ are shared; @x@ is the sum of @x0@ and @x1@.
cat3d0
  :: (All KnownDim '[x,y,x0,x1,z], x ~ Sing.Sum [x0, x1])
  => Tensor '[x0, y, z]
  -> Tensor '[x1, y, z]
  -> (Tensor '[x, y, z])
cat3d0 a b = cat a b axis
  where
    axis :: Dim 0
    axis = dim
-- | 'cat' specialised to two rank-3 tensors, joined along dimension 1.
-- Dimensions @x@ and @z@ are shared; @y@ is the sum of @y0@ and @y1@.
cat3d1
  :: (All KnownDim '[x,y,y0,y1,z], y ~ Sing.Sum [y0, y1])
  => Tensor '[x, y0, z]
  -> Tensor '[x, y1, z]
  -> (Tensor '[x, y, z])
cat3d1 a b = cat a b axis
  where
    axis :: Dim 1
    axis = dim
-- | 'cat' specialised to two rank-3 tensors, joined along dimension 2.
-- Dimensions @x@ and @y@ are shared; @z@ is the sum of @z0@ and @z1@.
cat3d2
  :: (All KnownDim '[x,y,z0,z1,z], z ~ Sing.Sum [z0, z1])
  => Tensor '[x, y, z0]
  -> Tensor '[x, y, z1]
  -> (Tensor '[x, y, z])
cat3d2 a b = cat a b axis
  where
    axis :: Dim 2
    axis = dim
-- | Static wrapper around 'Dynamic.catArray'.
--
-- The non-empty list of dynamic tensors and the 'Word' argument are forwarded
-- unchanged. A successful ('Right') result is wrapped with 'asStatic' at the
-- caller-chosen shape @d@; a 'Left' message is returned as-is.
catArray
  :: (Dimensions d)
  => NonEmpty Dynamic
  -> Word
  -> Either String (Tensor d)
catArray tensors axis = fmap asStatic (Dynamic.catArray tensors axis)
-- | Variant of 'catArray' that accepts static tensors and a type-level
-- dimension index.
--
-- Every tensor in the list shares the prefix @ls@ and suffix @rs@ of the
-- result shape; the size at position @i@ is left universally quantified in
-- the argument type. The tensors are converted with 'asDynamic' and the index
-- is read with 'dimVal' before delegating to 'catArray'.
catArray'
  :: forall d ls rs r0 r1 i
  .  Dimensions d
  => '(ls, r0:+rs) ~ Sing.SplitAt i d
  => d ~ (ls ++ '[r0] ++ rs)
  => (forall _i . NonEmpty (Tensor (ls ++ '[_i] ++ rs)))
  -> Dim i
  -> Either String (Tensor d)
catArray' ts dv = catArray (fmap asDynamic ts) (dimVal dv)
-- | 'catArray' over static tensors of a common shape @d2@, always passing @0@
-- as the 'Word' argument. The tensors are converted with 'asDynamic' first;
-- the result shape @d@ is chosen by the caller.
catArray0
  :: (Dimensions d, Dimensions d2)
  => NonEmpty (Tensor d2)
  -> Either String (Tensor d)
catArray0 ts = catArray (fmap asDynamic ts) 0
-- | Static counterpart of 'Dynamic.onesLike'.
--
-- The shape comes from the type-level dimensions @d@ (via 'dims'); the
-- dynamic result is wrapped with 'asStatic'.
onesLike :: forall d . Dimensions d => (Tensor d)
onesLike = asStatic (Dynamic.onesLike (dims :: Dims d))
-- | Static counterpart of 'Dynamic.zerosLike'.
--
-- The shape comes from the type-level dimensions @d@ (via 'dims'); the
-- dynamic result is wrapped with 'asStatic'.
zerosLike :: forall d . Dimensions d => (Tensor d)
zerosLike = asStatic (Dynamic.zerosLike (dims :: Dims d))