-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Static.Tensor.Math.Reduce
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.Math.Reduce
  ( minall
  , maxall
  , medianall
  , sumall
  , prodall

  , Torch.Indef.Static.Tensor.Math.Reduce.min
  , Torch.Indef.Static.Tensor.Math.Reduce.max
  , median

  , minIndex1d    , min1d
  , maxIndex1d    , max1d, max2d0, max2d1
  , medianIndex1d , median1d

  , Torch.Indef.Static.Tensor.Math.Reduce.sum, rowsum, colsum
  , _prod
  ) where

import Numeric.Dimensions
import Data.Coerce
import System.IO.Unsafe
import Data.Singletons.Prelude.List hiding (All, type (++), Length)
import Data.Singletons.Prelude.Ord
import GHC.TypeLits


import Data.Maybe (fromJust)
import Torch.Indef.Index
import Torch.Indef.Static.Tensor
import Torch.Indef.Types
import qualified Torch.Indef.Dynamic.Tensor.Math.Reduce as Dynamic


-- | Static call to 'Dynamic.minall'
minall :: Tensor d -> HsReal
minall t = Dynamic.minall (asDynamic t)

-- | Static call to 'Dynamic.maxall'
maxall :: Tensor d -> HsReal
maxall t = Dynamic.maxall (asDynamic t)

-- | Static call to 'Dynamic.medianall'
medianall :: Tensor d -> HsReal
medianall t = Dynamic.medianall (asDynamic t)

-- | Static call to 'Dynamic.sumall'
sumall :: Tensor d -> HsAccReal
sumall t = Dynamic.sumall (asDynamic t)

-- | Static call to 'Dynamic.prodall'
prodall :: Tensor d -> HsAccReal
prodall t = Dynamic.prodall (asDynamic t)

-- | Static call to 'Dynamic.max'
max
  :: forall d n ix rs ls
  .  All Dimensions '[d, rs ++ '[1] ++ ls]
  => All KnownNat '[n, ix]
  => All KnownDim '[n, ix]
  => (Length d > ix) ~ True
  => '(rs, n:+ls) ~ (SplitAt ix d)
  => Tensor d
  -> Dim ix
  -> KeepDim
  -> (Tensor (rs ++ '[1] ++ ls), Maybe (IndexTensor (rs ++ '[1] ++ ls)))
max = withKeepDim Dynamic._max

-- | Static call to 'Dynamic.min'
min
  :: forall d n ix rs ls
  .  All Dimensions '[d, rs ++ '[1] ++ ls]
  => All KnownNat '[n, ix]
  => All KnownDim '[n, ix]
  => (Length d > ix) ~ True
  => '(rs, n:+ls) ~ (SplitAt ix d)
  => Tensor d
  -> Dim ix
  -> KeepDim
  -> (Tensor (rs ++ '[1] ++ ls), Maybe (IndexTensor (rs ++ '[1] ++ ls)))
min = withKeepDim Dynamic._min

-- | Static call to 'Dynamic.median'
median
  :: forall d n ix rs ls
  .  All Dimensions '[d, rs ++ '[1] ++ ls]
  => All KnownNat '[n, ix]
  => All KnownDim '[n, ix]
  => (Length d > ix) ~ True
  => '(rs, n:+ls) ~ (SplitAt ix d)
  => Tensor d
  -> Dim ix
  -> KeepDim
  -> (Tensor (rs ++ '[1] ++ ls), Maybe (IndexTensor (rs ++ '[1] ++ ls)))
median = withKeepDim Dynamic._median

-- | Convenience method for 'max' over vectors
max1d :: (KnownNat n, KnownDim n) => Tensor '[n] -> KeepDim -> (Tensor '[1], Maybe (IndexTensor '[1]))
max1d t = Torch.Indef.Static.Tensor.Math.Reduce.max t (dim :: Dim 0)

-- | Convenience method for 'max' over matricies
max2d0 :: (KnownDim m, KnownNat n, KnownDim n) => Tensor '[n, m] -> KeepDim -> (Tensor '[1, m], Maybe (IndexTensor '[1, m]))
max2d0 t = Torch.Indef.Static.Tensor.Math.Reduce.max t (dim :: Dim 0)

-- | Convenience method for 'max' over matricies
max2d1 :: (KnownDim n, KnownDim m, KnownNat m) => Tensor '[n, m] -> KeepDim -> (Tensor '[n, 1], Maybe (IndexTensor '[n, 1]))
max2d1 t = Torch.Indef.Static.Tensor.Math.Reduce.max t (dim :: Dim 1)

-- | Convenience method for 'min'
min1d :: (KnownNat n, KnownDim n) => Tensor '[n] -> KeepDim -> (Tensor '[1], Maybe (IndexTensor '[1]))
min1d t = Torch.Indef.Static.Tensor.Math.Reduce.min t (dim :: Dim 0)

-- | Convenience method for 'median'
median1d :: (KnownNat n, KnownDim n) => Tensor '[n] -> KeepDim -> (Tensor '[1], Maybe (IndexTensor '[1]))
median1d t = median t (dim :: Dim 0)

-- | Convenience method for 'max'
maxIndex1d t = fromJust . snd $ max1d t keep

-- | Convenience method for 'min'
minIndex1d t = fromJust . snd $ min1d t keep

-- | Convenience method for 'median'
medianIndex1d t = fromJust . snd $ median1d t keep

-- | Static call to 'Dynamic._prod'
_prod :: Tensor d -> Tensor d -> Word -> Maybe KeepDim -> IO ()
_prod r t = Dynamic._prod (asDynamic r) (asDynamic t)

-------------------------------------------------------------------------------

withKeepDim
  :: forall d n ix rs ls
  .  All Dimensions '[d, rs ++ '[1] ++ ls]
  => All KnownNat '[n, ix]
  => All KnownDim '[n, ix]
  => (Length d > ix) ~ True
  => '(rs, n:+ls) ~ (SplitAt ix d)
  => ((Dynamic, IndexDynamic) -> Dynamic -> Word -> Maybe KeepDim -> IO ())
  -> Tensor d
  -> Dim ix
  -> KeepDim
  -> (Tensor (rs ++ '[1] ++ ls), Maybe (IndexTensor (rs ++ '[1] ++ ls)))
withKeepDim _fn t d k = unsafePerformIO $ do
  let ret = new
  let ix :: IndexTensor (rs ++ '[1] ++ ls) = newIx
  _fn (asDynamic ret, longAsDynamic ix) (asDynamic t) (dimVal d) (Just k)
  pure (ret, if coerce k then Just ix else Nothing)
{-# NOINLINE withKeepDim #-}


-- | Static call to 'Dynamic.sum'
sum :: Dimensions d' => Tensor d -> Word -> KeepDim -> Tensor d'
sum t d k = unsafePerformIO $ do
  let r = new
  Dynamic._sum (asDynamic r) (asDynamic t) d (Just k)
  pure r
{-# NOINLINE sum #-}

-- | convenience function for 'sum'
rowsum :: (All KnownDim '[r,c]) => Tensor '[r, c] -> (Tensor '[1, c])
rowsum t = Torch.Indef.Static.Tensor.Math.Reduce.sum t 0 keep

-- | convenience function for 'sum'
colsum :: (All KnownDim '[r,c]) => Tensor '[r, c] -> (Tensor '[r, 1])
colsum t = Torch.Indef.Static.Tensor.Math.Reduce.sum t 0 keep