{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Reduce
( minall
, maxall
, medianall
, sumall
, prodall
, _max
, _min
, _median
, _sum
, _prod
, Torch.Indef.Dynamic.Tensor.Math.Reduce.max
, Torch.Indef.Dynamic.Tensor.Math.Reduce.min
, median
) where
import Control.Monad.Managed
import Foreign (withForeignPtr)
import System.IO.Unsafe
import Numeric.Dimensions
import Torch.Indef.Types
import Torch.Indef.Dynamic.Tensor
import qualified Torch.Indef.Index as Ix
import qualified Torch.Sig.Types.Global as Sig
import qualified Torch.Sig.Tensor.Math.Reduce as Sig
-- | Minimum over every element of the tensor, computed by the backend's
-- @c_minall@ and converted to a Haskell value with 'c2hsReal'.
--
-- The FFI call is run under 'unsafeDupablePerformIO' so the result can be
-- used purely; @NOINLINE@ (together with the module's @-fno-cse@) keeps GHC
-- from inlining or sharing that unsafe call.
minall :: Dynamic -> HsReal
minall t = unsafeDupablePerformIO (with reduced (pure . c2hsReal))
  where
    reduced = do
      st  <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_minall st ptr)
{-# NOINLINE minall #-}
-- | Maximum over every element of the tensor, computed by the backend's
-- @c_maxall@ and converted to a Haskell value with 'c2hsReal'.
--
-- The FFI call is run under 'unsafeDupablePerformIO' so the result can be
-- used purely; @NOINLINE@ (together with the module's @-fno-cse@) keeps GHC
-- from inlining or sharing that unsafe call.
maxall :: Dynamic -> HsReal
maxall t = unsafeDupablePerformIO (with reduced (pure . c2hsReal))
  where
    reduced = do
      st  <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_maxall st ptr)
{-# NOINLINE maxall #-}
-- | Median over every element of the tensor, computed by the backend's
-- @c_medianall@ and converted to a Haskell value with 'c2hsReal'.
--
-- The FFI call is run under 'unsafeDupablePerformIO' so the result can be
-- used purely; @NOINLINE@ (together with the module's @-fno-cse@) keeps GHC
-- from inlining or sharing that unsafe call.
medianall :: Dynamic -> HsReal
medianall t = unsafeDupablePerformIO (with reduced (pure . c2hsReal))
  where
    reduced = do
      st  <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_medianall st ptr)
{-# NOINLINE medianall #-}
-- | Sum of every element of the tensor, computed by the backend's @c_sumall@.
--
-- The result is returned in the accumulator type 'HsAccReal' (converted with
-- 'c2hsAccReal'), not the element type. The FFI call runs under
-- 'unsafeDupablePerformIO'. @NOINLINE@ (plus the module's @-fno-cse@) keeps
-- it from being inlined or shared.
sumall :: Dynamic -> HsAccReal
sumall t = unsafeDupablePerformIO (with reduced (pure . c2hsAccReal))
  where
    reduced = do
      st  <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_sumall st ptr)
{-# NOINLINE sumall #-}
-- | Product of every element of the tensor, computed by the backend's
-- @c_prodall@.
--
-- The result is returned in the accumulator type 'HsAccReal' (converted with
-- 'c2hsAccReal'), not the element type. The FFI call runs under
-- 'unsafeDupablePerformIO'. @NOINLINE@ (plus the module's @-fno-cse@) keeps
-- it from being inlined or shared.
prodall :: Dynamic -> HsAccReal
prodall t = unsafeDupablePerformIO (with reduced (pure . c2hsAccReal))
  where
    reduced = do
      st  <- managedState
      ptr <- managedTensor t
      liftIO (Sig.c_prodall st ptr)
{-# NOINLINE prodall #-}
-- | Call the backend's @c_max@ on the given tensors, mutating in place.
--
-- Arguments are passed to @c_max@ in this order: the state, the first tensor
-- of the pair, the index tensor's raw pointer, the input tensor, the
-- dimension (converted with 'fromIntegral'), and the keep-dim flag (via
-- 'fromKeepDim').
--
-- NOTE(review): the pair is presumably the (values, indices) destination.
-- That matches 'withKeepDim', which allocates both fresh and returns them.
_max
  :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word
  -> Maybe KeepDim -> IO ()
_max (res, ixRes) src dimI keep = withLift $ do
  st     <- managedState
  resPtr <- managedTensor res
  ixPtr  <- managed (withForeignPtr (snd $ Sig.longDynamicState ixRes))
  srcPtr <- managedTensor src
  pure $ Sig.c_max st resPtr ixPtr srcPtr (fromIntegral dimI) (fromKeepDim keep)
{-# NOINLINE _max #-}
-- | Call the backend's @c_min@ on the given tensors, mutating in place.
--
-- Arguments are passed to @c_min@ in this order: the state, the first tensor
-- of the pair, the index tensor's raw pointer, the input tensor, the
-- dimension (converted with 'fromIntegral'), and the keep-dim flag (via
-- 'fromKeepDim').
--
-- NOTE(review): the pair is presumably the (values, indices) destination.
-- That matches 'withKeepDim', which allocates both fresh and returns them.
_min :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word
  -> Maybe KeepDim -> IO ()
_min (res, ixRes) src dimI keep = withLift $ do
  st     <- managedState
  resPtr <- managedTensor res
  ixPtr  <- managed (withForeignPtr (snd $ Sig.longDynamicState ixRes))
  srcPtr <- managedTensor src
  pure $ Sig.c_min st resPtr ixPtr srcPtr (fromIntegral dimI) (fromKeepDim keep)
{-# NOINLINE _min #-}
-- | Call the backend's @c_median@ on the given tensors, mutating in place.
--
-- Arguments are passed to @c_median@ in this order: the state, the first
-- tensor of the pair, the index tensor's raw pointer, the input tensor, the
-- dimension (converted with 'fromIntegral'), and the keep-dim flag (via
-- 'fromKeepDim').
--
-- NOTE(review): the pair is presumably the (values, indices) destination.
-- That matches 'withKeepDim', which allocates both fresh and returns them.
_median :: (Dynamic, IndexDynamic) -> Dynamic
  -> Word
  -> Maybe KeepDim -> IO ()
_median (res, ixRes) src dimI keep = withLift $ do
  st     <- managedState
  resPtr <- managedTensor res
  ixPtr  <- managed (withForeignPtr (snd $ Sig.longDynamicState ixRes))
  srcPtr <- managedTensor src
  pure $ Sig.c_median st resPtr ixPtr srcPtr (fromIntegral dimI) (fromKeepDim keep)
{-# NOINLINE _median #-}
-- | Call the backend's @c_sum@ with two tensors, a dimension and a keep-dim
-- flag.
--
-- The tensors are passed in argument order after the state. The dimension is
-- converted with 'fromIntegral' and the flag with 'fromKeepDim'.
--
-- NOTE(review): by the usual TH convention the first tensor is the
-- destination and the second the input. Confirm against the backend.
_sum :: Dynamic -> Dynamic
  -> Word
  -> Maybe KeepDim -> IO ()
_sum dst src dimI keep = withLift $ do
  st     <- managedState
  dstPtr <- managedTensor dst
  srcPtr <- managedTensor src
  pure $ Sig.c_sum st dstPtr srcPtr (fromIntegral dimI) (fromKeepDim keep)
{-# NOINLINE _sum #-}
-- | Call the backend's @c_prod@ with two tensors, a dimension and a keep-dim
-- flag.
--
-- The tensors are passed in argument order after the state. The dimension is
-- converted with 'fromIntegral' and the flag with 'fromKeepDim'.
--
-- NOTE(review): by the usual TH convention the first tensor is the
-- destination and the second the input. Confirm against the backend.
_prod
  :: Dynamic -> Dynamic
  -> Word
  -> Maybe KeepDim -> IO ()
_prod dst src dimI keep = withLift $ do
  st     <- managedState
  dstPtr <- managedTensor dst
  srcPtr <- managedTensor src
  pure $ Sig.c_prod st dstPtr srcPtr (fromIntegral dimI) (fromKeepDim keep)
{-# NOINLINE _prod #-}
-- | Run one of the index-producing reductions ('_max', '_min', '_median')
-- against freshly allocated outputs and return those outputs purely.
--
-- Two outputs are allocated:
--
-- * a result tensor with the same dimensions as @t@ (via 'getSomeDims');
-- * an index tensor whose length is @t@'s leading dimension.
--
-- Both are handed to @_fn@, together with @t@, the dimension @d@ and the
-- keep-dim flag @k@.
--
-- NOTE(review): outputs are sized from @t@, not from the reduced shape.
-- Presumably the backend resizes them; confirm.
--
-- The index tensor is returned only when @k@ is @Just (KeepDim True)@.
-- Otherwise the second component is 'Nothing'.
--
-- Calls 'error' with a descriptive message if @t@ reports no dimensions.
withKeepDim
  :: ((Dynamic, IndexDynamic) -> Dynamic -> Word -> Maybe KeepDim -> IO ())
  -> Dynamic -> Word -> Maybe KeepDim -> (Dynamic, Maybe (IndexDynamic))
withKeepDim _fn t d k = unsafeDupablePerformIO $ do
  _fn (ret, ix) t d k
  pure (ret, keptIx)
  where
    tdim = getSomeDims t
    -- Leading dimension of the input. A 0-d tensor has none; fail with a
    -- clear message instead of an anonymous irrefutable-pattern error.
    i = case shape t of
      (n:_) -> n
      []    -> error "Torch.Indef.Dynamic.Tensor.Math.Reduce.withKeepDim: input tensor has no dimensions"
    ret :: Dynamic = new' tdim
    ix = Ix.newIxDyn [i]
    -- Only surface the index tensor when keep-dim was explicitly requested.
    keptIx = case k of
      Just (KeepDim True) -> Just ix
      _                   -> Nothing
{-# NOINLINE withKeepDim #-}
-- | Pure maximum of a tensor along a dimension, built from '_max' via
-- 'withKeepDim'.
--
-- Returns the freshly allocated result tensor. It also returns the index
-- tensor, but only when the keep-dim flag is @Just (KeepDim True)@.
max
  :: Dynamic
  -> Word
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
max t dim keep = withKeepDim _max t dim keep
-- | Pure minimum of a tensor along a dimension, built from '_min' via
-- 'withKeepDim'.
--
-- Returns the freshly allocated result tensor. It also returns the index
-- tensor, but only when the keep-dim flag is @Just (KeepDim True)@.
min
  :: Dynamic
  -> Word
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
min t dim keep = withKeepDim _min t dim keep
-- | Pure median of a tensor along a dimension, built from '_median' via
-- 'withKeepDim'.
--
-- Returns the freshly allocated result tensor. It also returns the index
-- tensor, but only when the keep-dim flag is @Just (KeepDim True)@.
median
  :: Dynamic
  -> Word
  -> Maybe KeepDim
  -> (Dynamic, Maybe (IndexDynamic))
median t dim keep = withKeepDim _median t dim keep