-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math.CompareT
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Compare two tensors
-------------------------------------------------------------------------------
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.CompareT
  ( ltTensor, ltTensorT, ltTensorT_
  , leTensor, leTensorT, leTensorT_
  , gtTensor, gtTensorT, gtTensorT_
  , geTensor, geTensorT, geTensorT_
  , neTensor, neTensorT, neTensorT_
  , eqTensor, eqTensorT, eqTensorT_
  ) where

import Foreign hiding (with, new)
import Foreign.Ptr
import System.IO.Unsafe
import Numeric.Dimensions

import qualified Torch.Sig.Tensor.Math.CompareT as Sig

import Torch.Indef.Mask
import Torch.Indef.Types
import Torch.Indef.Dynamic.Tensor

-- | Lift a raw C comparison taking three tensor pointers into an 'IO' action
-- over three 'Dynamic' tensors. The tensors are handed to the C function in
-- the order given, right after the state pointer. Callers in this module pass
-- the tensor that receives the result first.
compareTensorTOp
  :: (Ptr CState -> Ptr CTensor -> Ptr CTensor -> Ptr CTensor -> IO ())
  -> Dynamic
  -> Dynamic
  -> Dynamic
  -> IO ()
compareTensorTOp cfun dst lhs rhs =
  withLift $
    cfun
      <$> managedState
      <*> managedTensor dst
      <*> managedTensor lhs
      <*> managedTensor rhs

-- Internal three-tensor comparisons, one per relation; each simply wires the
-- matching C binding through 'compareTensorTOp'.
_ltTensorT, _leTensorT, _gtTensorT, _geTensorT, _neTensorT, _eqTensorT
  :: Dynamic -> Dynamic -> Dynamic -> IO ()
_ltTensorT = compareTensorTOp Sig.c_ltTensorT
_leTensorT = compareTensorTOp Sig.c_leTensorT
_gtTensorT = compareTensorTOp Sig.c_gtTensorT
_geTensorT = compareTensorTOp Sig.c_geTensorT
_neTensorT = compareTensorTOp Sig.c_neTensorT
_eqTensorT = compareTensorTOp Sig.c_eqTensorT

-- | Run a raw C comparison that writes into a byte tensor and return the
-- result as a pure 'MaskDynamic'. The mask is created from the dimensions of
-- the first tensor; the C function then receives the state, the mask and both
-- tensors, in that order.
compareTensorOp
  :: (Ptr CState -> Ptr CByteTensor -> Ptr CTensor -> Ptr CTensor -> IO ())
  -> Dynamic
  -> Dynamic
  -> MaskDynamic
compareTensorOp op t0 t1 = unsafeDupablePerformIO $ do
  let mask = newMaskDyn' (getSomeDims t0)
  with2DynamicState t0 t1 $ \statePtr lhsPtr rhsPtr ->
    withMask mask $ \maskPtr ->
      op statePtr maskPtr lhsPtr rhsPtr
  pure mask
{-# NOINLINE compareTensorOp #-}

-- | Return a byte tensor which contains boolean values indicating the relation between two tensors.
ltTensor, leTensor, gtTensor, geTensor, neTensor, eqTensor
  :: Dynamic -> Dynamic -> MaskDynamic
ltTensor = compareTensorOp Sig.c_ltTensor
leTensor = compareTensorOp Sig.c_leTensor
gtTensor = compareTensorOp Sig.c_gtTensor
geTensor = compareTensorOp Sig.c_geTensor
neTensor = compareTensorOp Sig.c_neTensor
eqTensor = compareTensorOp Sig.c_eqTensor

-- | Return a tensor which contains numeric values indicating the relation
-- between two tensors. 0 stands for false, 1 stands for true.
--
-- Each function binds 'empty' as the output, runs the matching three-tensor
-- comparison into it, and returns it.
--
-- NOTE(review): the output is bound with a pure @let@, so this relies on
-- 'empty' yielding a fresh tensor for every use. Its definition is not visible
-- in this module -- confirm.
ltTensorT, leTensorT, gtTensorT, geTensorT, neTensorT, eqTensorT
  :: Dynamic  -- ^ source tensor.
  -> Dynamic  -- ^ tensor to compare with.
  -> Dynamic  -- ^ new return tensor.
ltTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _ltTensorT out src other
  pure out
{-# NOINLINE ltTensorT #-}

leTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _leTensorT out src other
  pure out
{-# NOINLINE leTensorT #-}

gtTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _gtTensorT out src other
  pure out
{-# NOINLINE gtTensorT #-}

geTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _geTensorT out src other
  pure out
{-# NOINLINE geTensorT #-}

neTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _neTensorT out src other
  pure out
{-# NOINLINE neTensorT #-}

eqTensorT src other = unsafeDupablePerformIO $ do
  let out = empty
  _eqTensorT out src other
  pure out
{-# NOINLINE eqTensorT #-}

-- | Mutate a tensor in-place with its numeric relation to the second tensor of
-- the same size, where 0 stands for false and 1 stands for true.
--
-- The source tensor is passed both as the output and as the first operand of
-- the underlying three-tensor comparison.
ltTensorT_, leTensorT_, gtTensorT_, geTensorT_, neTensorT_, eqTensorT_
  :: Dynamic  -- ^ source tensor to mutate inplace.
  -> Dynamic  -- ^ tensor to compare with.
  -> IO ()
ltTensorT_ target other = _ltTensorT target target other
leTensorT_ target other = _leTensorT target target other
gtTensorT_ target other = _gtTensorT target target other
geTensorT_ target other = _geTensorT target target other
neTensorT_ target other = _neTensorT target target other
eqTensorT_ target other = _eqTensorT target target other