{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math where
import Foreign hiding (new, with)
import Foreign.Ptr
import Control.Monad.Managed
import Data.Foldable (foldrM, foldlM)
import Numeric.Dimensions
import System.IO.Unsafe
import qualified Foreign.Marshal as FM
import Debug.Trace
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.List.NonEmpty as NE
import Data.Vector (Vector)
import qualified Data.Vector as V
import Torch.Indef.Dynamic.Tensor
import Torch.Indef.Types
import qualified Torch.Indef.Index as Ix
import qualified Torch.Sig.Tensor.Math as Sig
import qualified Torch.Sig.Types as Sig
import qualified Torch.Sig.State as Sig
import qualified Torch.Types.TH as TH (IndexStorage)
-- | Run the C @fill@ routine on a tensor, in place, with the given value
-- (converted with 'hs2cReal'). Presumably every element is set to @v@.
fill_ :: Dynamic -> HsReal -> IO ()
fill_ t v = runManaged $
  liftIO =<< (Sig.c_fill <$> managedState <*> managedTensor t <*> pure (hs2cReal v))
-- | Run the C @zero@ routine on a tensor, in place.
zero_ :: Dynamic -> IO ()
zero_ t = runManaged $ liftIO =<< (Sig.c_zero <$> managedState <*> managedTensor t)
-- | Run the C @zeros@ routine on a tensor, in place, passing the given index
-- storage as the target size.
--
-- The size argument is marshalled exactly as in 'ones_' (via
-- 'Ix.withCPUIxStorage'). The earlier body ignored @ix@ and called @c_zero@,
-- which made this function indistinguishable from 'zero_'.
--
-- NOTE(review): the parameter type is 'TH.IndexStorage' (matching 'ones_'
-- and '_reshape') because that is what 'Ix.withCPUIxStorage' accepts.
-- This should coincide with 'IndexStorage' on backends where the index
-- storage is the CPU long storage; confirm for other backends.
zeros_ :: Dynamic -> TH.IndexStorage -> IO ()
zeros_ t ix = runManaged $ do
  s' <- managedState
  t' <- managedTensor t
  ix' <- managed (Ix.withCPUIxStorage ix)
  liftIO $ Sig.c_zeros s' t' ix'
-- | Run the C @zerosLike@ routine with the two tensors, in place.
--
-- '_tenLike' returns the first argument as its result, so the first tensor
-- is the one that is written to. The second presumably supplies the shape.
zerosLike_
  :: Dynamic -- ^ tensor that is written to
  -> Dynamic -- ^ second tensor handed to the C routine
  -> IO ()
zerosLike_ t0 t1 = with2DynamicState t0 t1 $ \st p0 p1 -> Sig.c_zerosLike st p0 p1
-- | Run the C @ones@ routine on a tensor, in place, passing the given index
-- storage (marshalled with 'Ix.withCPUIxStorage') as the target size.
ones_ :: Dynamic -> TH.IndexStorage -> IO ()
ones_ t ix = runManaged $
  managedState >>= \st ->
  managedTensor t >>= \tp ->
  managed (Ix.withCPUIxStorage ix) >>= \ixp ->
  liftIO (Sig.c_ones st tp ixp)
-- | Run the C @onesLike@ routine with the two tensors, in place.
--
-- '_tenLike' returns the first argument as its result, so the first tensor
-- is the one that is written to. The second presumably supplies the shape.
onesLike_
  :: Dynamic -- ^ tensor that is written to
  -> Dynamic -- ^ second tensor handed to the C routine
  -> IO ()
onesLike_ t0 t1 = with2DynamicState t0 t1 (\st p0 p1 -> Sig.c_onesLike st p0 p1)
-- | Number of elements reported by the C @numel@ routine.
--
-- Exposed as a pure function through 'unsafePerformIO' and kept NOINLINE.
-- The C call runs when the result is first forced.
numel :: Dynamic -> Integer
numel t = unsafePerformIO $ fromIntegral <$> withLift query
  where
    query = Sig.c_numel <$> managedState <*> managedTensor t
{-# NOINLINE numel #-}
-- | Run the C @reshape@ routine with two tensors and the sizes held in the
-- given index storage (marshalled with 'Ix.withCPUIxStorage').
_reshape :: Dynamic -> Dynamic -> TH.IndexStorage -> IO ()
_reshape t0 t1 ix =
  with2DynamicState t0 t1 $ \st p0 p1 ->
    Ix.withCPUIxStorage ix (Sig.c_reshape st p0 p1)
{-# NOINLINE catArray #-}
-- | Concatenate a non-empty list of tensors along dimension @dv@ into a
-- newly created tensor.
--
-- 'catDims' validates the shapes; its error message is passed through
-- unchanged in 'Left'. On success the C work runs through 'unsafePerformIO'
-- at the moment the resulting 'Either' is forced.
catArray :: NonEmpty Dynamic -> Word -> Either String Dynamic
catArray ts dv = either Left build (catDims ts dv)
  where
    build ds = unsafePerformIO $ do
      let r = new' (someDimsVal ds)
      _catArray r ts dv
      pure (Right r)
-- | Run the C @catArray@ routine.
--
-- The call receives @res@, a C array of pointers to the tensors in @ds@,
-- the number of tensors, and the dimension @d@.
--
-- Every tensor pointer is acquired through 'managedTensor', so each
-- underlying tensor stays alive until 'runManaged' finishes. The pointer
-- array is built with 'FM.withArray', so it is released afterwards. The
-- previous 'FM.newArray' was never freed, and its raw pointers escaped
-- their 'withForeignPtr' scope.
_catArray
  :: Dynamic          -- ^ tensor passed as the first tensor argument
  -> NonEmpty Dynamic -- ^ tensors to concatenate
  -> Word             -- ^ concatenation dimension
  -> IO ()
_catArray res ds d = runManaged $ do
  s' <- managedState
  r' <- managedTensor res
  tensorPtrs <- mapM managedTensor (NE.toList ds)
  ds' <- managed (FM.withArray tensorPtrs)
  liftIO $ Sig.c_catArray s' r' ds' (fromIntegral $ length ds) (fromIntegral d)
-- | Run the C @tril@ routine with two tensors and the diagonal offset @i0@
-- (converted with 'fromInteger').
_tril :: Dynamic -> Dynamic -> Integer -> IO ()
_tril t0 t1 i0 = withLift $ do
  st <- managedState
  p0 <- managedTensor t0
  p1 <- managedTensor t1
  pure (Sig.c_tril st p0 p1 (fromInteger i0))
-- | Run the C @triu@ routine with two tensors and the diagonal offset @i0@
-- (converted with 'fromInteger').
_triu :: Dynamic -> Dynamic -> Integer -> IO ()
_triu t0 t1 i0 = withLift $ do
  st <- managedState
  p0 <- managedTensor t0
  p1 <- managedTensor t1
  pure (Sig.c_triu st p0 p1 (fromInteger i0))
-- | Run the C @cat@ routine with three tensors (see 'cat': the first is the
-- freshly allocated result, the other two are the inputs) and dimension @i@.
_cat :: Dynamic -> Dynamic -> Dynamic
  -> Word
  -> IO ()
_cat t0 t1 t2 i = runManaged $
  managedState >>= \st ->
  managedTensor t0 >>= \p0 ->
  managedTensor t1 >>= \p1 ->
  managedTensor t2 >>= \p2 ->
  liftIO (Sig.c_cat st p0 p1 p2 (fromIntegral i))
{-# NOINLINE cat #-}
-- | Concatenate two tensors along dimension @dv@ into a newly created tensor.
--
-- 'catDims' validates the shapes; its error message is passed through
-- unchanged in 'Left'. On success the C work runs through 'unsafePerformIO'
-- at the moment the resulting 'Either' is forced. The unused @where@
-- bindings of the previous version were removed.
cat :: Dynamic -> Dynamic -> Word -> Either String Dynamic
cat t0 t1 dv =
  case catDims (t0 :| [t1]) dv of
    Left msg -> Left msg
    Right ds -> unsafePerformIO $ do
      -- result tensor sized to the shape computed by 'catDims'
      let r = new' (someDimsVal ds)
      _cat r t0 t1 dv
      pure $ Right r
-- | Compute the shape that results from concatenating the given tensors
-- along dimension @dv@, or explain why they cannot be concatenated.
--
-- Requirements:
--
-- * every tensor has the same rank;
-- * @dv@ is a valid dimension index for that rank;
-- * every dimension other than @dv@ has the same extent in all tensors.
--
-- The @dv@-th extent of the result is the sum of the @dv@-th extents of all
-- tensors, including the first one. The previous fold started from 0 over
-- the tail only, which dropped the first tensor's contribution.
catDims :: NonEmpty Dynamic -> Word -> Either String [Word]
catDims ts dv
  | any ((length s /=) . length) ss = Left "Dimensions must all be same length."
  -- Compare in 'Word' so that a huge @dv@ cannot wrap to a negative 'Int'
  -- and slip past this check into the indexing below.
  | dv >= fromIntegral (V.length s) = Left "Cat dimension must exist on tensors."
  | otherwise =
    case foldlM go (s V.! ix) ss of
      Nothing -> Left $
        "Dimensionality error: all dimensions must match except in the cat-dimensions. " ++
        "Dimensions include: " ++ intercalate ", " (show <$> s:ss) ++ "."
      Just cd -> Right (V.toList $ s V.// [(ix, cd)])
  where
    -- safe: the guards above ensure 0 <= dv < rank before this is used
    ix :: Int
    ix = fromIntegral dv

    s :: Vector Word
    ss :: [Vector Word]
    (s :| ss) = fmap (V.fromList . shape) ts

    -- Add the next tensor's cat-dimension extent to the running total,
    -- provided all of its other extents match the first tensor's.
    go :: Word -> Vector Word -> Maybe Word
    go acc nxt
      | V.and (V.izipWith (\i a b -> i == ix || a == b) s nxt) = Just (acc + nxt V.! ix)
      | otherwise = Nothing
-- | Run the C @nonzero@ routine with an index tensor and a tensor.
--
-- The index tensor's pointer comes from 'Ix.withDynamicState'. The state
-- pointer that function supplies is ignored; the one from 'managedState'
-- is used instead.
_nonzero :: IndexDynamic -> Dynamic -> IO ()
_nonzero ix t = runManaged $
  managedState >>= \st ->
  managedTensor t >>= \tp ->
  liftIO . Ix.withDynamicState ix $ \_ ixp ->
    Sig.c_nonzero st ixp tp
-- | Result of the C @trace@ routine on a tensor, converted with
-- 'c2hsAccReal'.
--
-- Exposed as a pure function through 'unsafePerformIO' and kept NOINLINE.
ttrace :: Dynamic -> HsAccReal
ttrace t = unsafePerformIO (with traceOf (pure . c2hsAccReal))
  where
    traceOf = do
      st <- managedState
      tp <- managedTensor t
      liftIO (Sig.c_trace st tp)
{-# NOINLINE ttrace #-}
-- | Run the C @eye@ routine on a tensor, in place, with the two sizes
-- converted via 'fromIntegral'.
eye_
  :: Dynamic
  -> Integer
  -> Integer
  -> IO ()
eye_ t0 l0 l1 = runManaged $ do
  call <- Sig.c_eye <$> managedState <*> managedTensor t0
  liftIO $ call (fromIntegral l0) (fromIntegral l1)
-- | Run the C @arange@ routine on a tensor, in place, with the three
-- accumulator-typed arguments converted via 'hs2cAccReal'.
_arange :: Dynamic -> HsAccReal -> HsAccReal -> HsAccReal -> IO ()
_arange t0 a0 a1 a2 = runManaged $ do
  call <- Sig.c_arange <$> managedState <*> managedTensor t0
  liftIO $ call (hs2cAccReal a0) (hs2cAccReal a1) (hs2cAccReal a2)
-- | Pure wrapper around '_arange': run it on @empty@ and return that tensor.
--
-- NOTE(review): @empty@ is referenced as a plain value with no arguments.
-- If it is a top-level constant, it is evaluated once and shared. In that
-- case this in-place '_arange' mutates that single shared tensor, and every
-- 'arange' result aliases it. Confirm against the definition of @empty@ and
-- allocate a fresh tensor per call if so.
arange :: HsAccReal -> HsAccReal -> HsAccReal -> Dynamic
arange a0 a1 a2 = unsafePerformIO (_arange t a0 a1 a2 >> pure t)
  where
    t = empty
{-# NOINLINE arange #-}
-- | Run the C @range@ routine on a tensor, in place, with the three
-- accumulator-typed arguments converted via 'hs2cAccReal'.
range_
  :: Dynamic
  -> HsAccReal
  -> HsAccReal
  -> HsAccReal
  -> IO ()
range_ t0 a0 a1 a2 = runManaged $
  managedState >>= \st ->
  managedTensor t0 >>= \tp ->
  liftIO (Sig.c_range st tp (hs2cAccReal a0) (hs2cAccReal a1) (hs2cAccReal a2))
-- | Pure wrapper around 'range_', run through 'withInplace' with dims @d@.
range
  :: Dims (d::[Nat])
  -> HsAccReal
  -> HsAccReal
  -> HsAccReal
  -> Dynamic
range d a b c = unsafePerformIO (withInplace fillRange d)
  where
    fillRange r = range_ r a b c
{-# NOINLINE range #-}
-- | A new tensor with dims @d@, passed through 'fill_' with value @v@.
constant :: Dims (d :: [Nat]) -> HsReal -> Dynamic
constant d v = unsafePerformIO $ do
  let r = new d
  fill_ r v
  pure r
{-# NOINLINE constant #-}
-- | Run the C @diag@ routine with two tensors and the diagonal index @k@
-- (converted via 'fromIntegral').
_diag :: Dynamic -> Dynamic -> Int -> IO ()
_diag r t k = with2DynamicState r t (\st rp tp -> Sig.c_diag st rp tp k')
  where
    k' = fromIntegral k
-- | In-place variant of 'diag': '_diag' with the same tensor in both
-- positions.
--
-- NOTE(review): output and input alias here. Confirm that the underlying C
-- @diag@ routine supports this; a routine that resizes its output before
-- reading its input would corrupt the data.
diag_ :: Dynamic -> Int -> IO ()
diag_ t = _diag t t
-- | Pure wrapper around '_diag'. The first argument is a new tensor created
-- from the input's dims (via 'getSomeDims'), and that new tensor is returned.
diag :: Dynamic -> Int -> Dynamic
diag t d = unsafePerformIO $ do
  let r = new' (getSomeDims t)
  _diag r t d
  pure r
{-# NOINLINE diag #-}
-- | 'diag' with diagonal index 1.
--
-- NOTE(review): if the C @diag@ routine follows the usual convention where
-- index 0 is the main diagonal, 1 selects the first superdiagonal. Confirm
-- that this is what "diag1d" is meant to do.
diag1d :: Dynamic -> Dynamic
diag1d = flip diag 1
-- | Create two tensors with dims @d@ and run the given operation on them.
-- The first tensor is the first argument and is returned; the second is
-- the second argument.
--
-- The two tensors come from separate @new d@ expressions. The module-level
-- @-fno-cse@ option presumably exists so GHC does not merge them into one.
_tenLike
  :: (Dynamic -> Dynamic -> IO ())
  -> Dims (d::[Nat]) -> IO Dynamic
_tenLike op d = op target template >> pure target
  where
    target = new d
    template = new d
{-# WARNING _tenLike "this should not be exported outside of hasktorch" #-}
-- | Pure wrapper: '_tenLike' with 'onesLike_' at dims @d@.
onesLike :: Dims (d::[Nat]) -> Dynamic
onesLike d = unsafePerformIO $ _tenLike onesLike_ d
{-# NOINLINE onesLike #-}
-- | Pure wrapper: '_tenLike' with 'zerosLike_' at dims @d@.
zerosLike :: Dims (d::[Nat]) -> Dynamic
zerosLike d = unsafePerformIO $ _tenLike zerosLike_ d
{-# NOINLINE zerosLike #-}