-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math.Reduce
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Reduce
  ( minall
  , maxall
  , medianall
  , sumall
  , prodall
  , _max
  , _min
  , _median
  , _sum
  , _prod
  , Torch.Indef.Dynamic.Tensor.Math.Reduce.max
  , Torch.Indef.Dynamic.Tensor.Math.Reduce.min
  , median
  ) where

import Control.Monad.Managed
import Foreign (withForeignPtr)
import System.IO.Unsafe
import Numeric.Dimensions

import Torch.Indef.Types

import Torch.Indef.Dynamic.Tensor
import qualified Torch.Indef.Index as Ix
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Tensor.Math.Reduce as Sig

-- | get the minima of a tensor's elements
minall :: Dynamic -> HsReal
minall t = unsafeDupablePerformIO . flip with (pure . c2hsReal) . (liftIO =<<) $ Sig.c_minall
  <$> managedState
  <*> managedTensor t
{-# NOINLINE minall #-}

-- | get the maxima of a tensor's elements
maxall :: Dynamic -> HsReal
maxall t = unsafeDupablePerformIO . flip with (pure . c2hsReal) . (liftIO =<<) $ Sig.c_maxall
  <$> managedState
  <*> managedTensor t
{-# NOINLINE maxall #-}

-- | get the median value of a tensor's elements
medianall :: Dynamic -> HsReal
medianall t = unsafeDupablePerformIO . flip with (pure . c2hsReal) . (liftIO =<<) $ Sig.c_medianall
  <$> managedState
  <*> managedTensor t
{-# NOINLINE medianall #-}

-- | get the sum of a tensor's elements
sumall :: Dynamic -> HsAccReal
sumall t = unsafeDupablePerformIO . flip with (pure . c2hsAccReal) . (liftIO =<<) $ Sig.c_sumall
  <$> managedState
  <*> managedTensor t
{-# NOINLINE sumall #-}

-- | get the product of a tensor's elements
prodall :: Dynamic -> HsAccReal
prodall t = unsafeDupablePerformIO . flip with (pure . c2hsAccReal) . (liftIO =<<) $ Sig.c_prodall
  <$> managedState
  <*> managedTensor t
{-# NOINLINE prodall #-}

-- | get the maximal value in the specified dimension and a corresponding index tensor of the maximum value's index.
--
-- Inplace and C-Style mutation
_max
  :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim -> IO ()
_max (t0, ix) t1 i0 i1 = withLift $ Sig.c_max
  <$> managedState
  <*> managedTensor t0
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t1
  <*> pure (fromIntegral i0)
  <*> pure (fromKeepDim i1)
{-# NOINLINE _max #-}

-- | get the minimal value in the specified dimension and a corresponding index tensor of the minimum value's index.
--
-- Inplace and C-Style mutation
_min :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim -> IO ()
_min (t0, ix) t1 i0 i1  = withLift $ Sig.c_min
  <$> managedState
  <*> managedTensor t0
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t1
  <*> pure (fromIntegral i0)
  <*> pure (fromKeepDim i1)
{-# NOINLINE _min #-}

-- | get the median value in the specified dimension and a corresponding index tensor of the median value's index.
--
-- Inplace and C-Style mutation
_median :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim -> IO ()
_median (t0, ix) t1 i0 i1 = withLift $ Sig.c_median
  <$> managedState
  <*> managedTensor t0
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t1
  <*> pure (fromIntegral i0)
  <*> pure (fromKeepDim i1)
{-# NOINLINE _median #-}

-- | sum the tensor in the specified dimension.
--
-- Inplace and C-Style mutation
_sum :: Dynamic -> Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim -> IO ()
_sum t0 t1 i0 i1 = withLift $ Sig.c_sum
  <$> managedState
  <*> managedTensor t0
  <*> managedTensor t1
  <*> pure (fromIntegral i0)
  <*> pure (fromKeepDim i1)
{-# NOINLINE _sum #-}

-- | take the product of the tensor in the specified dimension.
--
-- Inplace and C-Style mutation
_prod
  :: Dynamic -> Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim -> IO ()
_prod t0 t1 i0 i1 = withLift $ Sig.c_prod
  <$> managedState
  <*> managedTensor t0
  <*> managedTensor t1
  <*> pure (fromIntegral i0)
  <*> pure (fromKeepDim i1)
{-# NOINLINE _prod #-}


withKeepDim
  :: ((Dynamic, IndexDynamic) -> Dynamic -> Word -> Maybe KeepDim -> IO ())
  -> Dynamic -> Word -> Maybe KeepDim -> (Dynamic, Maybe (IndexDynamic))
withKeepDim _fn t d k = unsafeDupablePerformIO $ do
  _fn (ret, ix) t d k
  pure (ret, maybe Nothing (\(KeepDim b) -> if b then Just ix else Nothing) k)
  where
    tdim = getSomeDims t
    (i:_) = shape t
    ret :: Dynamic = new' tdim
    ix = Ix.newIxDyn [i]
{-# NOINLINE withKeepDim #-}

-- | get the maximum value in the specified dimension and return an optional corresponding index tensor of the maximum value's index.
max
  :: Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
max = withKeepDim _max

-- | get the minimum value in the specified dimension and return an optional corresponding index tensor of the minimum value's index.
min
  :: Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
min = withKeepDim _min

-- | get the median value in the specified dimension and return an optional corresponding index tensor of the median value's index.
median
  :: Dynamic
  -> Word -- ^ dimension to operate over
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
median = withKeepDim _median


-- * not in THC.BYte
-- c_renorm :: Ptr CState -> t -> t -> HsReal t -> CInt -> HsReal t -> IO ()
-- c_std :: Ptr CState -> t -> Ptr CTensor -> CInt -> CInt -> CInt -> IO ()
-- c_stdall :: Ptr CState -> Ptr CTensor -> CInt -> HsReal t
-- c_var :: Ptr CState -> Ptr CTensor -> Ptr CTensor -> CInt -> CInt -> CInt -> IO ()
-- c_varall :: Ptr CState -> Ptr CTensor -> CInt -> HsReal t
-- c_dist :: Ptr CState -> Ptr CTensor -> Ptr CTensor -> HsReal t -> HsReal t

-- * not in TH.Byte
-- c_norm :: Ptr CState -> Ptr CTensor -> Ptr CTensor -> HsReal t -> CInt -> CInt -> IO ()
-- c_normall :: Ptr CState -> Ptr CTensor -> HsReal t -> HsReal t
-- c_mean :: Ptr CState -> Ptr CTensor -> Ptr CTensor -> CInt -> CInt -> IO ()
-- c_meanall :: Ptr CState -> Ptr CTensor -> HsReal t