{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Compare
( ltValue, ltValueT, ltValueT_
, leValue, leValueT, leValueT_
, gtValue, gtValueT, gtValueT_
, geValue, geValueT, geValueT_
, neValue, neValueT, neValueT_
, eqValue, eqValueT, eqValueT_
) where
import Foreign hiding (with, new)
import Foreign.Ptr
import Numeric.Dimensions
import System.IO.Unsafe
import Control.Monad.Managed
import Torch.Indef.Types
import Torch.Indef.Mask
import Torch.Indef.Dynamic.Tensor
import qualified Torch.Sig.Tensor.Math.Compare as Sig
-- Raw, mutating comparison primitives.  @_xxValueT r t v@ compares every
-- element of @t@ with the scalar @v@ and stores the outcome in @r@ (which may
-- be @t@ itself, as the in-place wrappers further down do).  The pure
-- @xxValueT@ and in-place @xxValueT_@ functions are thin wrappers over these.
_ltValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_ltValueT = compareValueTOp Sig.c_ltValueT

_leValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_leValueT = compareValueTOp Sig.c_leValueT

_gtValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_gtValueT = compareValueTOp Sig.c_gtValueT

_geValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_geValueT = compareValueTOp Sig.c_geValueT

_neValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_neValueT = compareValueTOp Sig.c_neValueT

_eqValueT :: Dynamic -> Dynamic -> HsReal -> IO ()
_eqValueT = compareValueTOp Sig.c_eqValueT
-- | Lift a raw C @*ValueT@ routine to work on 'Dynamic' tensors.
--
-- The state pointer, the first tensor's pointer and the second tensor's
-- pointer are acquired (in that order) as managed resources. The C routine is
-- then handed those pointers together with the scalar converted via
-- 'hs2cReal'. The callers in this module pass the output tensor first and the
-- input tensor second.
compareValueTOp
  :: (Ptr CState -> Ptr CTensor -> Ptr CTensor -> CReal -> IO ())
  -> Dynamic -> Dynamic -> HsReal -> IO ()
compareValueTOp fn a b v = withLift $ do
  statePtr <- managedState
  outPtr   <- managedTensor a
  srcPtr   <- managedTensor b
  pure $ fn statePtr outPtr srcPtr (hs2cReal v)
-- | Run a raw C @*Value@ comparison against a scalar, writing into a freshly
-- built mask whose dimensions are taken from the input tensor, and return that
-- mask.
--
-- The pointers are only used inside the 'Managed' scope. The mask value itself
-- is returned after the scope closes.
compareTensorOp
  :: (Ptr CState -> Ptr CByteTensor -> Ptr CTensor -> CReal -> IO ())
  -> Dynamic -> HsReal -> MaskDynamic
compareTensorOp op t0 v = unsafeDupablePerformIO (with acquire pure)
  where
    acquire = do
      statePtr <- managedState
      srcPtr   <- managedTensor t0
      let result = newMaskDyn' (getSomeDims t0)
      resPtr   <- managed (withMask result)
      liftIO (op statePtr resPtr srcPtr (hs2cReal v))
      pure result
{-# NOINLINE compareTensorOp #-}
-- | Element-wise less-than against a scalar (via @c_ltValue@); returns a new
-- mask with the same dimensions as the input tensor.
ltValue :: Dynamic -> HsReal -> MaskDynamic
ltValue = compareTensorOp Sig.c_ltValue

-- | Element-wise less-or-equal against a scalar (via @c_leValue@); returns a
-- new mask with the same dimensions as the input tensor.
leValue :: Dynamic -> HsReal -> MaskDynamic
leValue = compareTensorOp Sig.c_leValue

-- | Element-wise greater-than against a scalar (via @c_gtValue@); returns a
-- new mask with the same dimensions as the input tensor.
gtValue :: Dynamic -> HsReal -> MaskDynamic
gtValue = compareTensorOp Sig.c_gtValue

-- | Element-wise greater-or-equal against a scalar (via @c_geValue@); returns
-- a new mask with the same dimensions as the input tensor.
geValue :: Dynamic -> HsReal -> MaskDynamic
geValue = compareTensorOp Sig.c_geValue

-- | Element-wise not-equal against a scalar (via @c_neValue@); returns a new
-- mask with the same dimensions as the input tensor.
neValue :: Dynamic -> HsReal -> MaskDynamic
neValue = compareTensorOp Sig.c_neValue

-- | Element-wise equality against a scalar (via @c_eqValue@); returns a new
-- mask with the same dimensions as the input tensor.
eqValue :: Dynamic -> HsReal -> MaskDynamic
eqValue = compareTensorOp Sig.c_eqValue
-- | Pure element-wise comparisons of a tensor against a scalar that return
-- the result as a 'Dynamic' rather than a 'MaskDynamic' (compare 'ltValue'
-- and friends).
--
-- Each function runs the matching raw primitive (e.g. @_ltValueT@) with
-- 'empty' as the output tensor and the argument as the input, then returns
-- that output tensor.
--
-- NOTE(review): correctness relies on 'empty' producing a fresh tensor on
-- every call. If 'empty' is a NOINLINE top-level value built with
-- unsafePerformIO, all of these functions would write into, and return, one
-- shared tensor. Confirm against the definition of 'empty'.
-- NOTE(review): this also presumably relies on the C routine resizing the
-- (empty) output to match the input. Confirm.
ltValueT, leValueT, gtValueT, geValueT, neValueT, eqValueT
:: Dynamic -> HsReal -> Dynamic
ltValueT a b = unsafeDupablePerformIO $ let r = empty in _ltValueT r a b >> pure r
leValueT a b = unsafeDupablePerformIO $ let r = empty in _leValueT r a b >> pure r
gtValueT a b = unsafeDupablePerformIO $ let r = empty in _gtValueT r a b >> pure r
geValueT a b = unsafeDupablePerformIO $ let r = empty in _geValueT r a b >> pure r
neValueT a b = unsafeDupablePerformIO $ let r = empty in _neValueT r a b >> pure r
eqValueT a b = unsafeDupablePerformIO $ let r = empty in _eqValueT r a b >> pure r
-- NOINLINE (presumably together with -fno-cse at the top of the module) is
-- meant to keep GHC from inlining or sharing these unsafeDupablePerformIO
-- bodies across call sites.
{-# NOINLINE ltValueT #-}
{-# NOINLINE leValueT #-}
{-# NOINLINE gtValueT #-}
{-# NOINLINE geValueT #-}
{-# NOINLINE neValueT #-}
{-# NOINLINE eqValueT #-}
-- | In-place element-wise comparisons against a scalar: the tensor is used as
-- both the input and the output of the underlying raw primitive, so its
-- contents are overwritten with the comparison result.
ltValueT_, leValueT_, gtValueT_, geValueT_, neValueT_, eqValueT_
  :: Dynamic -> HsReal -> IO ()
ltValueT_ = compareValueTInPlace _ltValueT
leValueT_ = compareValueTInPlace _leValueT
gtValueT_ = compareValueTInPlace _gtValueT
geValueT_ = compareValueTInPlace _geValueT
neValueT_ = compareValueTInPlace _neValueT
eqValueT_ = compareValueTInPlace _eqValueT

-- | Feed the same tensor as both the output and the input of a raw
-- @_xxValueT r t v@ primitive.
compareValueTInPlace
  :: (Dynamic -> Dynamic -> HsReal -> IO ())
  -> Dynamic -> HsReal -> IO ()
compareValueTInPlace op t v = op t t v