-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math.Compare
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Compare a tensor with a scalar value
-------------------------------------------------------------------------------
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Compare
  ( ltValue, ltValueT, ltValueT_
  , leValue, leValueT, leValueT_
  , gtValue, gtValueT, gtValueT_
  , geValue, geValueT, geValueT_
  , neValue, neValueT, neValueT_
  , eqValue, eqValueT, eqValueT_
  ) where

import Foreign hiding (with, new)
import Foreign.Ptr
import Numeric.Dimensions
import System.IO.Unsafe
import Control.Monad.Managed

import Torch.Indef.Types
import Torch.Indef.Mask
import Torch.Indef.Dynamic.Tensor

import qualified Torch.Sig.Tensor.Math.Compare as Sig

_ltValueT, _leValueT, _gtValueT, _geValueT, _neValueT, _eqValueT
  :: Dynamic -> Dynamic -> HsReal -> IO ()
_ltValueT = compareValueTOp Sig.c_ltValueT
_leValueT = compareValueTOp Sig.c_leValueT
_gtValueT = compareValueTOp Sig.c_gtValueT
_geValueT = compareValueTOp Sig.c_geValueT
_neValueT = compareValueTOp Sig.c_neValueT
_eqValueT = compareValueTOp Sig.c_eqValueT

-- Adapt a raw comparison primitive (state pointer, two tensor pointers and a
-- C scalar) to operate on 'Dynamic' tensors and an 'HsReal'.
--
-- The state and both tensor pointers are obtained through 'managedState' and
-- 'managedTensor', the scalar is converted with 'hs2cReal', and the resulting
-- 'IO' action is run by 'withLift'.
compareValueTOp
  :: (Ptr CState -> Ptr CTensor -> Ptr CTensor -> CReal -> IO ())
  -> Dynamic -> Dynamic -> HsReal -> IO ()
compareValueTOp fn a b v = withLift $ do
  statePtr <- managedState
  aPtr     <- managedTensor a
  bPtr     <- managedTensor b
  pure (fn statePtr aPtr bPtr (hs2cReal v))

-- Run a raw mask-producing comparison primitive as a pure function.
--
-- A mask is created with 'newMaskDyn'' from the dimensions of the input tensor
-- ('getSomeDims'); the primitive is then called with the state pointer, the
-- mask pointer, the tensor pointer and the scalar converted by 'hs2cReal'.
-- The mask is the result.
--
-- The whole action is executed through 'unsafeDupablePerformIO', and the
-- binding is marked NOINLINE.
compareTensorOp
  :: (Ptr CState -> Ptr CByteTensor -> Ptr CTensor -> CReal -> IO ())
  -> Dynamic -> HsReal -> MaskDynamic
compareTensorOp op t0 v = unsafeDupablePerformIO (with fillMask pure)
  where
    fillMask = do
      statePtr <- managedState
      srcPtr   <- managedTensor t0
      let mask = newMaskDyn' (getSomeDims t0)
      maskPtr  <- managed (withMask mask)
      liftIO (op statePtr maskPtr srcPtr (hs2cReal v))
      pure mask
{-# NOINLINE compareTensorOp #-}

-- | Compare a tensor with a scalar and return a byte tensor (mask) of boolean
-- values describing how the tensor relates to that scalar: less-than,
-- less-or-equal, greater-than, greater-or-equal, not-equal and equal
-- respectively.
ltValue, leValue, gtValue, geValue, neValue, eqValue
  :: Dynamic -> HsReal -> MaskDynamic
ltValue t v = compareTensorOp Sig.c_ltValue t v
leValue t v = compareTensorOp Sig.c_leValue t v
gtValue t v = compareTensorOp Sig.c_gtValue t v
geValue t v = compareTensorOp Sig.c_geValue t v
neValue t v = compareTensorOp Sig.c_neValue t v
eqValue t v = compareTensorOp Sig.c_eqValue t v

-- | Compare a tensor with a scalar and return a tensor of numeric values
-- describing the relation: 0 stands for false, 1 stands for true.
--
-- Each function passes the tensor bound to 'empty' as the destination of the
-- matching internal @_xxValueT@ routine and then returns that same tensor.
--
-- NOTE(review): 'empty' is used here as a plain value of type 'Dynamic';
-- confirm that separate calls really receive distinct destination tensors
-- rather than one shared tensor.
ltValueT, leValueT, gtValueT, geValueT, neValueT, eqValueT
  :: Dynamic -> HsReal -> Dynamic
ltValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _ltValueT out t v
  pure out
{-# NOINLINE ltValueT #-}

leValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _leValueT out t v
  pure out
{-# NOINLINE leValueT #-}

gtValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _gtValueT out t v
  pure out
{-# NOINLINE gtValueT #-}

geValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _geValueT out t v
  pure out
{-# NOINLINE geValueT #-}

neValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _neValueT out t v
  pure out
{-# NOINLINE neValueT #-}

eqValueT t v = unsafeDupablePerformIO $ do
  let out = empty
  _eqValueT out t v
  pure out
{-# NOINLINE eqValueT #-}

-- | In-place variants: the given tensor is passed as both the destination and
-- the source of the comparison, so it is overwritten with its numeric relation
-- to the scalar (0 stands for false, 1 stands for true).
ltValueT_, leValueT_, gtValueT_, geValueT_, neValueT_, eqValueT_
  :: Dynamic -> HsReal -> IO ()
ltValueT_ t = _ltValueT t t
leValueT_ t = _leValueT t t
gtValueT_ t = _gtValueT t t
geValueT_ t = _geValueT t t
neValueT_ t = _neValueT t t
eqValueT_ t = _eqValueT t t