-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Torch provides MATLAB-like functions for manipulating Tensor objects.
-- Functions fall into several types of categories:
--
--   * Constructors like zeros, ones;
--   * Extractors like diag and triu;
--   * Element-wise mathematical operations like abs and pow;
--   * BLAS operations;
--   * Column or row-wise operations like sum and max;
--   * Matrix-wide operations like trace and norm;
--   * Convolution and cross-correlation operations like conv2;
--   * Basic linear algebra operations like eig;
--   * Logical operations on Tensors.
--
-- Unfortunately, the list above comes from the Lua docs. Hasktorch doesn't
-- mimic it exactly and (FIXME) we will have to restructure this module
-- header to reflect the differences.
-------------------------------------------------------------------------------
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math where

import Foreign hiding (new, with)
import Foreign.Ptr
import Control.Monad.Managed
import Data.Foldable (foldrM, foldlM)
import Numeric.Dimensions
import System.IO.Unsafe
import qualified Foreign.Marshal as FM
import Debug.Trace
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.List.NonEmpty as NE
import Data.Vector (Vector)
import qualified Data.Vector as V

import Torch.Indef.Dynamic.Tensor
import Torch.Indef.Types
import qualified Torch.Indef.Index     as Ix
import qualified Torch.Sig.Tensor.Math as Sig
import qualified Torch.Sig.Types       as Sig
import qualified Torch.Sig.State       as Sig
import qualified Torch.Types.TH        as TH (IndexStorage)


-- | Overwrite every element of a dynamic tensor, inplace, with the supplied
-- value.
fill_ :: Dynamic -> HsReal -> IO ()
fill_ t v = with handles $ \(state, tensor) ->
  Sig.c_fill state tensor (hs2cReal v)
 where
  -- Torch state and tensor pointer, both held for the duration of the C call.
  handles = (,) <$> managedState <*> managedTensor t

-- | Set every element of the given tensor to zero, inplace.
zero_ :: Dynamic -> IO ()
zero_ t = runManaged $
  managedState    >>= \state ->
  managedTensor t >>= \tensor ->
  liftIO (Sig.c_zero state tensor)

-- | Mutate a tensor, inplace, resizing it to the size described by the given
-- IndexStorage and replacing its contents with zeros.
--
-- This mirrors 'ones_': the size storage is marshalled to a CPU index-storage
-- pointer and handed to the C @zeros@ routine, so the @ix@ argument actually
-- determines the resulting shape.
zeros_ :: Dynamic -> IndexStorage -> IO ()
zeros_ t ix = runManaged $ do
  s'  <- managedState
  t'  <- managedTensor t
  -- NOTE(review): assumes the unqualified 'IndexStorage' coincides with
  -- 'TH.IndexStorage' (as consumed by 'Ix.withCPUIxStorage' in 'ones_') and
  -- that 'Sig.c_zeros' is exposed next to 'Sig.c_ones' -- confirm against the
  -- signature package.
  ix' <- managed $ Ix.withCPUIxStorage ix
  liftIO $ Sig.c_zeros s' t' ix'

-- | Resize the first tensor, inplace, to the shape of the second tensor and
-- replace the first tensor's values with zeros.
zerosLike_
  :: Dynamic  -- ^ tensor to mutate inplace and replace contents with zeros
  -> Dynamic  -- ^ tensor to extract shape information from.
  -> IO ()
zerosLike_ t0 t1 =
  with2DynamicState t0 t1 $ \state dst src ->
    Sig.c_zerosLike state dst src

-- | Resize a tensor, inplace, to the size held in the given IndexStorage and
-- replace its values with ones.
ones_ :: Dynamic -> TH.IndexStorage -> IO ()
ones_ t ix = with resources $ \(state, tensor, sizes) ->
  Sig.c_ones state tensor sizes
 where
  -- Torch state, tensor pointer and CPU index-storage pointer, acquired in
  -- that order.
  resources =
    (,,) <$> managedState
         <*> managedTensor t
         <*> managed (Ix.withCPUIxStorage ix)

-- | Resize the first tensor, inplace, to the shape of the second tensor and
-- replace the first tensor's values with ones.
onesLike_
  :: Dynamic  -- ^ tensor to mutate inplace and replace contents with ones
  -> Dynamic  -- ^ tensor to extract shape information from.
  -> IO ()
onesLike_ t0 t1 =
  with2DynamicState t0 t1 $ \state dst src ->
    Sig.c_onesLike state dst src

-- | Number of elements contained in the tensor, as reported by the C
-- @numel@ call.
numel :: Dynamic -> Integer
numel t = fromIntegral (unsafePerformIO countElems)
 where
  -- Build the C call inside 'Managed' and let 'withLift' run it.
  countElems = withLift $ do
    state  <- managedState
    tensor <- managedTensor t
    pure (Sig.c_numel state tensor)
{-# NOINLINE numel #-}

-- |
-- @
--   _reshape y x (Ix.newStorage [m, n, k, l, o])
-- @
--
-- Reshapes the @y@ dynamic tensor, inplace, into a @m × n × k × l × o@ tensor
-- whose elements are taken rowwise from @x@; @x@ must hold
-- @m * n * k * l * o@ elements. The elements are copied into the new Tensor.
_reshape :: Dynamic -> Dynamic -> TH.IndexStorage -> IO ()
_reshape t0 t1 ix =
  with2DynamicState t0 t1 $ \state dst src ->
    Ix.withCPUIxStorage ix (Sig.c_reshape state dst src)

-- | Pure counterpart of '_catArray': validates the shapes with 'catDims',
-- allocates a result tensor of the computed size and concatenates into it.
-- Yields 'Left' with the validation message when the shapes are incompatible.
{-# NOINLINE catArray #-}
catArray :: NonEmpty Dynamic -> Word -> Either String Dynamic
catArray ts dv = either Left build (catDims ts dv)
 where
  -- Allocate the output once, fill it via the C call, then hand it back.
  build :: [Word] -> Either String Dynamic
  build ds = unsafePerformIO $ do
    let out = new' (someDimsVal ds)
    _catArray out ts dv
    pure (Right out)

-- | Concatenate all tensors in a given list of dynamic tensors along the given dimension.
--
-- NOTE: In C, if the dimension is not specified or if it is -1, it is the maximum
-- last dimension over all input tensors, except if all tensors are empty, then it is 1.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
--
-- Every input tensor pointer is acquired through 'managedTensor', so the
-- tensors are held for the whole managed scope that contains the C call. The
-- temporary pointer array is allocated with 'FM.withArrayLen', which releases
-- it once the call returns.
_catArray
  :: Dynamic   -- ^ result to mutate
  -> NonEmpty Dynamic -- ^ tensors to concatenate
  -> Word      -- ^ dimension to concatenate along.
  -> IO ()
_catArray res ds d = runManaged $ do
  s'  <- managedState
  r'  <- managedTensor res
  -- Hold every input tensor for the duration of the call instead of letting
  -- raw pointers escape their 'withForeignPtr' scope.
  ts' <- mapM managedTensor (NE.toList ds)
  liftIO $ FM.withArrayLen ts' $ \n arr ->
    Sig.c_catArray s' r' arr (fromIntegral n) (fromIntegral d)

-- | "Get the lower triangle of a tensor."
--
-- Writes into the first tensor the triangular part of the second tensor that lies under the Kth diagonal,
-- where k=0 is the main diagonal, k>0 is above the main diagonal, and k<0 is below the main diagonal.
-- All other elements are set to 0.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
_tril :: Dynamic -> Dynamic -> Integer -> IO ()
_tril t0 t1 i0 = withLift $ do
  state <- managedState
  dst   <- managedTensor t0
  src   <- managedTensor t1
  pure (Sig.c_tril state dst src (fromInteger i0))

-- | "Get the upper triangle of a tensor."
--
-- Writes into the first tensor the triangular part of the second tensor that lies above the Kth diagonal,
-- where k=0 is the main diagonal, k>0 is above the main diagonal, and k<0 is below the main diagonal.
-- All other elements are set to 0.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
_triu :: Dynamic -> Dynamic -> Integer -> IO ()
_triu t0 t1 i0 = withLift $ do
  state <- managedState
  dst   <- managedTensor t0
  src   <- managedTensor t1
  pure (Sig.c_triu state dst src (fromInteger i0))

-- | Concatenate two dynamic tensors along the specified dimension, treating the
-- first argument as the return tensor, to be mutated in-place.
--
-- NOTE: In C, if the dimension is not specified or if it is -1, it is the maximum
-- last dimension over all input tensors, except if all tensors are empty, then it is 1.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
_cat :: Dynamic -> Dynamic -> Dynamic
  -> Word  -- ^ dimension to concatenate along
  -> IO ()
_cat t0 t1 t2 i = with resources $ \(state, dst, lhs, rhs) ->
  Sig.c_cat state dst lhs rhs (fromIntegral i)
 where
  -- Torch state followed by the result and the two operands, in that order.
  resources =
    (,,,) <$> managedState
          <*> managedTensor t0
          <*> managedTensor t1
          <*> managedTensor t2

-- | Pure version of '_cat'.
--
-- The shapes of both tensors are validated with 'catDims' first; on failure
-- the validation message is returned in 'Left'. Otherwise a fresh result
-- tensor of the computed size is allocated and filled by '_cat'.
{-# NOINLINE cat #-}
cat :: Dynamic -> Dynamic -> Word -> Either String Dynamic
cat t0 t1 dv =
  case catDims (t0 :| [t1]) dv of
    Left msg -> Left msg
    Right ds -> unsafePerformIO $ do
      -- Bind the result once so the same tensor is both mutated and returned.
      let r = new' (someDimsVal ds)
      _cat r t0 t1 dv
      pure (Right r)

-- | Compute the shape of the tensor obtained by concatenating the given
-- tensors along dimension @dv@.
--
-- Returns 'Left' with a message when the tensors do not all have the same
-- number of dimensions, when @dv@ is not a valid dimension of the tensors, or
-- when any dimension other than @dv@ differs between tensors. Otherwise
-- returns the first tensor's shape with the entry at @dv@ replaced by the sum
-- of every tensor's extent along @dv@ (including the first tensor's own).
catDims :: NonEmpty Dynamic -> Word -> Either String [Word]
catDims ts dv
  | any ((length s /=) . length) ss = Left "Dimensions must all be same length."
  -- All shapes have equal rank at this point, so checking the first suffices.
  -- The @ix < 0@ test covers a 'Word' too large to fit in an 'Int'.
  | ix < 0 || ix >= length s = Left "Cat dimension must exist on tensors."
  | otherwise =
      -- Seed the fold with the first tensor's extent so it is counted too.
      case foldlM go (s V.! ix) ss of
        Nothing -> Left $
          "Dimensionality error: all dimensions must match except in the cat-dimensions. " ++
          "Dimensions include: " ++ intercalate ", " (show <$> s:ss) ++ "."
        Just cd -> Right (V.toList $ s V.// [(ix, cd)])
 where
  ix :: Int
  ix = fromIntegral dv

  s  :: Vector Word
  ss :: [Vector Word]
  (s:|ss) = fmap (V.fromList . shape) ts

  -- Accumulate the extent along the cat dimension, failing as soon as a
  -- tensor disagrees with the first one on any other dimension.
  go :: Word -> Vector Word -> Maybe Word
  go acc nxt
    | V.and (V.imap agrees nxt) = Just (acc + nxt V.! ix)
    | otherwise                 = Nothing
   where
    agrees i n = i == ix || s V.! i == n

-- | Finds the subscript indices of all non-zero elements in the tensor and stores them in the given LongTensor.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
_nonzero :: IndexDynamic -> Dynamic -> IO ()
_nonzero ix t =
  with ((,) <$> managedState <*> managedTensor t) $ \(state, tensor) ->
    -- The state supplied by 'Ix.withDynamicState' is ignored in favour of the
    -- one obtained from 'managedState'.
    Ix.withDynamicState ix $ \_ ixPtr ->
      Sig.c_nonzero state ixPtr tensor

-- | Trace of a matrix x, i.e. the sum of its diagonal elements. This is equal
-- to the sum of the eigenvalues of x.
ttrace :: Dynamic -> HsAccReal
ttrace t = unsafePerformIO $ with traced (pure . c2hsAccReal)
 where
  -- The raw C accumulator value, converted to 'HsAccReal' by the caller above.
  traced =
    managedState    >>= \state ->
    managedTensor t >>= \tensor ->
    liftIO (Sig.c_trace state tensor)
{-# NOINLINE ttrace #-}

-- | Turns the tensor, inplace, into an @n × m@ identity matrix: ones on the diagonal and zeros elsewhere.
eye_
  :: Dynamic  -- ^ tensor to mutate inplace
  -> Integer  -- ^ @n@ dimension in an @n × m@ matrix
  -> Integer  -- ^ @m@ dimension in an @n × m@ matrix
  -> IO ()
eye_ t0 l0 l1 = runManaged $
  managedState     >>= \state ->
  managedTensor t0 >>= \tensor ->
  liftIO (Sig.c_eye state tensor (fromIntegral l0) (fromIntegral l1))

-- | Direct call to the C @arange@ (a @range@ with special consideration for floating precision types),
-- mutating the given tensor.
_arange :: Dynamic -> HsAccReal -> HsAccReal -> HsAccReal -> IO ()
_arange t0 a0 a1 a2 =
  with ((,) <$> managedState <*> managedTensor t0) $ \(state, tensor) ->
    Sig.c_arange state tensor (hs2cAccReal a0) (hs2cAccReal a1) (hs2cAccReal a2)

-- | Pure wrapper over '_arange': runs the C @arange@ (a @range@ with special consideration for floating
-- precision types) against the 'empty' tensor and returns that tensor.
arange :: HsAccReal -> HsAccReal -> HsAccReal -> Dynamic
arange a0 a1 a2 = unsafePerformIO $
  let out = empty
  in  _arange out a0 a1 a2 >> pure out
{-# NOINLINE arange #-}

-- | Fill a Tensor, inplace, with values running from @min@ to @max@ in increments of @step@. The tensor
-- takes a shape of size @floor((max - min) / step) + 1@.
range_
  :: Dynamic    -- ^ tensor to mutate
  -> HsAccReal  -- ^ @min@ value
  -> HsAccReal  -- ^ @max@ value
  -> HsAccReal  -- ^ @step@ size
  -> IO ()
range_ t0 a0 a1 a2 = runManaged $
  managedState     >>= \state ->
  managedTensor t0 >>= \tensor ->
  liftIO (Sig.c_range state tensor (hs2cAccReal a0) (hs2cAccReal a1) (hs2cAccReal a2))

-- | Pure version of 'range_': applies 'range_' through 'withInplace' for the
-- given dimensions and returns the resulting tensor.
range
  :: Dims (d::[Nat])
  -> HsAccReal
  -> HsAccReal
  -> HsAccReal
  -> Dynamic
range d a b c = unsafePerformIO (withInplace fillRange d)
 where
  -- In-place action handed to 'withInplace'.
  fillRange r = range_ r a b c
{-# NOINLINE range #-}

-- | Build a 'Dynamic' tensor of the given dimensions with every element set
-- to the given value.
--
-- Using 'unsafePerformIO' here is acceptable because constant is pure and
-- thread-safe.
constant :: Dims (d :: [Nat]) -> HsReal -> Dynamic
constant d v = unsafePerformIO $ do
  let r = new d
  fill_ r v
  pure r
{-# NOINLINE constant #-}

-- | Direct call to the C-FFI of @diag@: the first tensor argument is mutated
-- using the data from the remaining arguments.
_diag :: Dynamic -> Dynamic -> Int -> IO ()
_diag t0 t1 i0 =
  with2DynamicState t0 t1 $ \state dst src ->
    Sig.c_diag state dst src k
 where
  -- Diagonal offset converted to the C integer type.
  k = fromIntegral i0

-- | Replaces the tensor, inplace, with its own k-th diagonal, where k=0 is
-- the main diagonal, k>0 is above the main diagonal, and k<0 is below the
-- main diagonal. The same tensor is passed to '_diag' as both result and
-- source.
diag_ :: Dynamic -> Int -> IO ()
diag_ t = _diag t t

-- | Pure variant of '_diag': yields the k-th diagonal of the input tensor,
-- where k=0 is the main diagonal, k>0 is above the main diagonal, and k<0 is
-- below the main diagonal.
diag :: Dynamic -> Int -> Dynamic
diag t d = unsafePerformIO $ do
  -- The result starts out with the input's dimensions and is handed to
  -- '_diag' as the tensor to mutate.
  let r = new' (getSomeDims t)
  _diag r t d
  pure r
{-# NOINLINE diag #-}

-- | Returns a diagonal matrix with diagonal elements constructed from the
-- input tensor.
--
-- Uses offset @k = 0@ so the elements land on the main diagonal (per 'diag',
-- @k > 0@ would place them above it).
diag1d :: Dynamic -> Dynamic
diag1d t = diag t 0

-- | Shared helper behind 'onesLike' and 'zerosLike': allocates two tensors of
-- the requested dimensions, runs the given in-place action with the first as
-- the target and the second as the shape reference, and returns the target.
_tenLike
  :: (Dynamic -> Dynamic -> IO ())
  -> Dims (d::[Nat]) -> IO Dynamic
_tenLike _fn d =
  -- Two separate allocations: the action receives distinct target and
  -- reference tensors.
  let target    = new d
      reference = new d
  in  _fn target reference >> pure target
{-# WARNING _tenLike "this should not be exported outside of hasktorch" #-}

-- | Pure version of 'onesLike_', built on '_tenLike' for the given dimensions.
onesLike :: Dims (d::[Nat]) -> Dynamic
onesLike d = unsafePerformIO (_tenLike onesLike_ d)
{-# NOINLINE onesLike #-}

-- | Pure version of 'zerosLike_', built on '_tenLike' for the given dimensions.
zerosLike :: Dims (d::[Nat]) -> Dynamic
zerosLike d = unsafePerformIO (_tenLike zerosLike_ d)
{-# NOINLINE zerosLike #-}


-- class CPUTensorMath t where
--   match    :: t -> t -> t -> IO (HsReal t)
--   kthvalue :: t -> IndexDynamic t -> t -> Integer -> Int -> IO Int
--   randperm :: t -> Generator t -> Integer -> IO ()