-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Static.Tensor.Math
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.Math where

import Numeric.Dimensions -- hiding (Length)

import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import System.IO.Unsafe
import Data.Singletons.Prelude (fromSing)
import Data.List.NonEmpty (NonEmpty)
import Data.Either (fromRight)
import qualified Data.Singletons.Prelude.List as Sing hiding (All, type (++))
import qualified Torch.Types.TH as TH
import qualified Torch.Indef.Dynamic.Tensor.Math as Dynamic

-- | Static call to 'Dynamic.fill_', applied to the tensor's underlying
-- dynamic representation.
fill_ tensor = Dynamic.fill_ $ asDynamic tensor
-- | Static call to 'Dynamic.zero_', applied to the tensor's underlying
-- dynamic representation.
zero_ tensor = Dynamic.zero_ $ asDynamic tensor

-- | Mutate a tensor inplace, resizing it to the size given by the
-- 'IndexStorage' argument and replacing its values with zeros.
--
-- Static call to 'Dynamic.zeros_'. The type-level dimensions @d@ are not
-- updated by this call, so they may no longer describe the runtime shape.
zeros_ :: Tensor d -> IndexStorage -> IO ()
zeros_ tensor = Dynamic.zeros_ (asDynamic tensor)

-- | Mutate a tensor inplace, resizing it to the shape of the second tensor
-- argument and replacing the first tensor's values with zeros.
--
-- Static call to 'Dynamic.zerosLike_'. The two tensors are allowed to have
-- different type-level dimensions.
zerosLike_
  :: Tensor d  -- ^ tensor to mutate inplace and replace contents with zeros
  -> Tensor d'  -- ^ tensor to extract shape information from.
  -> IO ()
zerosLike_ target template =
  Dynamic.zerosLike_ (asDynamic target) $ asDynamic template

-- | Mutate a tensor inplace, resizing it to the size given by the
-- 'TH.IndexStorage' argument and replacing its values with ones.
--
-- Static call to 'Dynamic.ones_'. The type-level dimensions @d@ are not
-- updated by this call, so they may no longer describe the runtime shape.
ones_ :: Tensor d -> TH.IndexStorage -> IO ()
ones_ tensor = Dynamic.ones_ (asDynamic tensor)

-- | Mutate a tensor inplace, resizing it to the shape of the second tensor
-- argument and replacing the first tensor's values with ones.
--
-- Static call to 'Dynamic.onesLike_'. The two tensors are allowed to have
-- different type-level dimensions.
onesLike_
  :: Tensor d  -- ^ tensor to mutate inplace and replace contents with ones
  -> Tensor d'  -- ^ tensor to extract shape information from.
  -> IO ()
onesLike_ target template =
  Dynamic.onesLike_ (asDynamic target) $ asDynamic template


-- | Static call to 'Dynamic.numel' on the tensor's dynamic representation.
numel tensor = Dynamic.numel $ asDynamic tensor
-- | Static call to 'Dynamic._reshape'; both tensors are passed as their
-- dynamic representations.
_reshape dst src = Dynamic._reshape (asDynamic dst) $ asDynamic src
-- | Static call to 'Dynamic._catArray'; only the first argument is converted,
-- the remaining arguments are passed through unchanged.
_catArray dst = Dynamic._catArray $ asDynamic dst
-- | Static call to 'Dynamic._nonzero'; the first argument is a long tensor.
_nonzero dst src = Dynamic._nonzero (longAsDynamic dst) $ asDynamic src
-- | Static call to 'Dynamic._tril'.
_tril dst src = Dynamic._tril (asDynamic dst) $ asDynamic src
-- | Static call to 'Dynamic._triu'.
_triu dst src = Dynamic._triu (asDynamic dst) $ asDynamic src
-- | Static call to 'Dynamic.eye_'.
eye_ dst = Dynamic.eye_ $ asDynamic dst

-- | Returns the trace (sum of the diagonal elements) of a matrix @x@. For a
-- square matrix this equals the sum of its eigenvalues.
--
-- Static call to 'Dynamic.ttrace'.
ttrace tensor = Dynamic.ttrace $ asDynamic tensor
-- | Identical to a direct C call to @arange@, or @range@ with special
-- consideration for floating precision types.
--
-- Static call to 'Dynamic._arange'.
_arange dst = Dynamic._arange $ asDynamic dst

-- | Static call to 'Dynamic.range_'.
range_ dst = Dynamic.range_ $ asDynamic dst

-- | Static call to 'Dynamic.constant', using the type-level dimensions @d@ as
-- the shape and the 'HsReal' argument as the value.
constant :: forall d . Dimensions d => HsReal -> Tensor d
constant value = asStatic (Dynamic.constant shape value)
  where
    shape = dims :: Dims d

-- | Static call to 'Dynamic.diag_'. The argument's underlying tensor is passed
-- to 'Dynamic.diag_' (presumably mutated inplace), and that same underlying
-- tensor is then returned re-typed with the dimensions @d'@.
--
-- NOTE(review): @d'@ is chosen by the caller and is not checked against the
-- runtime shape left behind by 'Dynamic.diag_'.
diag_ :: All Dimensions '[d, d'] => Tensor d -> Int -> IO (Tensor d')
diag_ t k = asStatic (asDynamic t) <$ Dynamic.diag_ (asDynamic t) k

-- | Static call to 'Dynamic.diag'. The result is re-typed with the
-- caller-chosen dimensions @d'@.
diag :: All Dimensions '[d, d'] => Tensor d -> Int -> Tensor d'
diag t k = asStatic (Dynamic.diag (asDynamic t) k)

-- | Create a diagonal matrix from a 1D vector by calling 'diag' with @0@ as
-- its 'Int' argument.
diag1d :: (KnownDim n) => Tensor '[n] -> Tensor '[n, n]
diag1d = (`diag` 0)

-- | Static call to 'Dynamic._cat'. The underlying tensor of @a@ is used as the
-- destination, and @a@ and @b@ are passed as the inputs along with the
-- dimension @d@. That same underlying tensor is then returned, unsafely
-- re-typed with the new dimensions @d''@.
--
-- NOTE(review): @a@ is passed as both the output and the first input of
-- 'Dynamic._cat'. Confirm that the backing C routine tolerates this aliasing.
cat_
  :: All Dimensions '[d, d', d'']
  => Tensor d -> Tensor d' -> Word -> IO (Tensor d'')
cat_ a b d = asStatic out <$ Dynamic._cat out out (asDynamic b) d
  where
    out = asDynamic a
{-# WARNING cat_ "this function is impure and the dimensions can fall out of sync with the type, if used incorrectly" #-}

-- | Static call to 'Dynamic.cat'.
--
-- The constraints split @d@ and @d'@ at index @i@ into the same prefix @ls@
-- and suffix @rs@, so the inputs may differ only at dimension @i@ (sizes @r0@
-- and @r1@). The result has size @r0 + r1@ at that position. Because the
-- shapes are checked at the type level, a 'Left' from 'Dynamic.cat' is treated
-- as unreachable and raises an 'error'.
cat
  :: '(ls, r0:+rs) ~ Sing.SplitAt i d
  => '(ls, r1:+rs) ~ Sing.SplitAt i d'
  => Tensor d
  -> Tensor d'
  -> Dim (i::Nat)
  -> Tensor (ls ++ '[r0 + r1] ++ rs)
cat a b d = fromRight impossible (asStatic <$> joined)
  where
    joined = Dynamic.cat (asDynamic a) (asDynamic b) (fromIntegral (dimVal d))
    impossible = error "impossible: cat type should not allow this branch"

-- | Type-safe 'cat' of two rank-1 tensors along dimension 0. The result length
-- @n@ is the sum of the input lengths @n1@ and @n2@.
cat1d
  :: (All KnownDim '[n1,n2,n], n ~ Sing.Sum [n1, n2])
  => Tensor '[n1] -> Tensor '[n2] -> Tensor '[n]
cat1d a b = cat a b axis
  where
    axis = dim :: Dim 0

-- | Type-safe 'cat' of two rank-2 tensors along dimension 0. The sizes @n0@ and
-- @n1@ at dimension 0 are summed, and the size @m@ at dimension 1 is shared.
cat2d0 :: (All KnownDim '[n,m,n0,n1], n ~ Sing.Sum [n0, n1]) => Tensor '[n0, m] -> Tensor '[n1, m] -> Tensor '[n, m]
cat2d0 a b = cat a b axis
  where
    axis = dim :: Dim 0

-- | Stack two rank-1 tensors along dimension 0. Each vector is unsqueezed at
-- dimension 0 to shape @[1, m]@, and the two results are joined with 'cat2d0'.
stack1d0 :: KnownDim m => Tensor '[m] -> Tensor '[m] -> (Tensor '[2, m])
stack1d0 a b = cat2d0 (unsqueeze1d axis a) (unsqueeze1d axis b)
  where
    axis = dim :: Dim 0

-- | Type-safe 'cat' of two rank-2 tensors along dimension 1. The sizes @m0@ and
-- @m1@ at dimension 1 are summed, and the size @n@ at dimension 0 is shared.
cat2d1 :: (All KnownDim '[n,m,m0,m1], m ~ Sing.Sum [m0, m1]) => Tensor '[n, m0] -> Tensor '[n, m1] -> (Tensor '[n, m])
cat2d1 a b = cat a b axis
  where
    axis = dim :: Dim 1

-- | Stack two rank-1 tensors along dimension 1. Each vector is unsqueezed at
-- dimension 1 to shape @[n, 1]@, and the two results are joined with 'cat2d1'.
stack1d1 :: KnownDim n => Tensor '[n] -> Tensor '[n] -> (Tensor '[n, 2])
stack1d1 a b = cat2d1 (unsqueeze1d axis a) (unsqueeze1d axis b)
  where
    axis = dim :: Dim 1

-- | Type-safe 'cat' of two rank-3 tensors along dimension 0. The sizes @x0@ and
-- @x1@ at dimension 0 are summed, and @y@ and @z@ are shared.
cat3d0
  :: (All KnownDim '[x,y,x0,x1,z], x ~ Sing.Sum [x0, x1])
  => Tensor '[x0, y, z]
  -> Tensor '[x1, y, z]
  -> (Tensor '[x, y, z])
cat3d0 a b = cat a b axis
  where
    axis = dim :: Dim 0

-- | Type-safe 'cat' of two rank-3 tensors along dimension 1. The sizes @y0@ and
-- @y1@ at dimension 1 are summed, and @x@ and @z@ are shared.
cat3d1
  :: (All KnownDim '[x,y,y0,y1,z], y ~ Sing.Sum [y0, y1])
  => Tensor '[x, y0, z]
  -> Tensor '[x, y1, z]
  -> (Tensor '[x, y, z])
cat3d1 a b = cat a b axis
  where
    axis = dim :: Dim 1

-- | Type-safe 'cat' of two rank-3 tensors along dimension 2. The sizes @z0@ and
-- @z1@ at dimension 2 are summed, and @x@ and @y@ are shared.
cat3d2
  :: (All KnownDim '[x,y,z0,z1,z], z ~ Sing.Sum [z0, z1])
  => Tensor '[x, y, z0]
  -> Tensor '[x, y, z1]
  -> (Tensor '[x, y, z])
cat3d2 a b = cat a b axis
  where
    axis = dim :: Dim 2

-- | Concatenate all tensors in a given non-empty list of dynamic tensors along
-- the given dimension, re-typing a successful result with the caller-chosen
-- dimensions @d@. A 'Left' from 'Dynamic.catArray' is passed through unchanged.
--
-- NOTE: In C, if the dimension is not specified or if it is -1, it is the maximum
-- last dimension over all input tensors, except if all tensors are empty, then it is 1.
catArray
  :: (Dimensions d)
  => NonEmpty Dynamic
  -> Word
  -> Either String (Tensor d)
catArray ts dv = fmap asStatic (Dynamic.catArray ts dv)

-- Disabled draft of a list-based 'catArray0', kept for reference. It uses plain
-- comments (no Haddock marker) so it is not attached to the next declaration.
--
-- -- Concatenate all tensors in a given list of dynamic tensors along the given dimension.
-- --
-- -- NOTE: In C, if the dimension is not specified or if it is -1, it is the maximum
-- -- last dimension over all input tensors, except if all tensors are empty, then it is 1.
-- catArray0
--   :: forall d ls rs r0 r1 i
--   .  Dimensions d
--   => '([], r0:+rs) ~ Sing.SplitAt i d
--   => (forall _i . [Tensor (_i+:rs)])
--   -> IO (Tensor (r0+:rs))
-- catArray0 ts dv = catArray (asDynamic <$> ts) (dimVal dv)


-- | Concatenate a non-empty list of static tensors along dimension @i@.
--
-- The inputs are typed with the same prefix @ls@ and suffix @rs@ as the result
-- @d@. The concatenation dimension passed to 'catArray' is the runtime value
-- of the 'Dim' @i@.
--
-- NOTE: In C, if the dimension is not specified or if it is -1, it is the maximum
-- last dimension over all input tensors, except if all tensors are empty, then it is 1.
catArray'
  :: forall d ls rs r0 r1 i
  .  Dimensions d
  => '(ls, r0:+rs) ~ Sing.SplitAt i d
  => d ~ (ls ++ '[r0] ++ rs)
  => (forall _i . NonEmpty (Tensor (ls ++ '[_i] ++ rs)))
  -> Dim i
  -> Either String (Tensor d)
catArray' ts dv = catArray (fmap asDynamic ts) (dimVal dv)

-- | Concatenate a non-empty collection of static tensors along dimension 0,
-- using 'catArray'.
--
-- All inputs share the type-level shape @d2@. The result shape @d@ is chosen
-- by the caller and is not related to @d2@ at the type level. Failures reported
-- by 'catArray' are returned as 'Left'.
catArray0 :: (Dimensions d, Dimensions d2) => NonEmpty (Tensor d2) -> Either String (Tensor d)
catArray0 = flip catArray 0 . fmap asDynamic

{-
catArray_
  :: forall d ls rs out n
  .  All Dimensions '[out]
  => out ~ (rs ++ '[Length '[Tensor d]] ++ ls)
  => '(ls, rs) ~ Sing.SplitAt n d

  => Sing.SList '[Tensor d]
  -> Dim n
  -> IO (Tensor out)
catArray_ ts dv
  = -- fmap asStatic
    catArray
    (asDynamic <$> (fromSing ts :: [Tensor d]))
    (fromIntegral $ dimVal dv)

-- data Sing (z :: [a]) where
--     SNil :: Sing ([] :: [k])
--     SCons :: Sing (n ': n)

singToList :: forall k ks k2 x . Sing.SList '[x] -> [x]
singToList sl = go [] sl
 where
  go :: [x] -> Sing.SList '[x] -> [x]
  -- go acc Sing.SNil = acc
  -- go acc (Sing.SNil :: Sing.SList ('[] :: [x])) = acc
  -- go acc (Sing.SConst :: Sing.Sing '[]) = acc
  go acc (Sing.SCons k ks) = go acc ks

    -- | fromSing (Sing.sNull sl) = reverse acc
    -- | otherwise = go (fromSing (Sing.sHead sl):acc) (Sing.sTail acc)
  -- Sing.SNil = reverse acc
  -- go acc (Sing.SCons sval rest) = go (fromSing sval:acc) rest
-}

-- | Static call to 'Dynamic.onesLike', using the type-level dimensions @d@ as
-- the shape of the resulting tensor.
onesLike :: forall d . Dimensions d => (Tensor d)
onesLike = asStatic (Dynamic.onesLike shape)
  where
    shape = dims :: Dims d

-- | Static call to 'Dynamic.zerosLike', using the type-level dimensions @d@ as
-- the shape of the resulting tensor.
zerosLike :: forall d . Dimensions d => (Tensor d)
zerosLike = asStatic (Dynamic.zerosLike shape)
  where
    shape = dims :: Dims d