-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Math.Pairwise
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
-- -fno-cse is the GHC-recommended companion to NOINLINE for code that runs
-- effects under unsafe*PerformIO (see the pure wrappers below): it stops GHC
-- from merging two calls that merely look identical into one shared result.
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Dynamic.Tensor.Math.Pairwise
  ( equal
  , add, add_, add_scaled_
  , sub, sub_, sub_scaled_
  , mul, mul_
  , Torch.Indef.Dynamic.Tensor.Math.Pairwise.div, div_
  , lshift_
  , rshift_
  , fmod_
  , remainder_
  , bitand_
  , bitor_
  , bitxor_
  ) where

import Torch.Indef.Dynamic.Tensor
import Torch.Indef.Types
import System.IO.Unsafe
import Control.Monad.Managed (with)
import Control.Monad.IO.Class (liftIO)
import qualified Torch.Sig.Tensor.Math.Pairwise as Sig

-- | Call Torch's C-level @equal@ function.
--
-- Returns 'True' exactly when @Sig.c_equal@ reports @1@ for the two tensors.
equal :: Dynamic -> Dynamic -> Bool
equal r t = unsafeDupablePerformIO . fmap (== 1) . withLift $
  Sig.c_equal
    <$> managedState
    <*> managedTensor r
    <*> managedTensor t

-- | add a scalar to a tensor, inplace.
--
-- The tensor is passed to the C-style worker as both result and source.
add_ :: Dynamic -> HsReal -> IO ()
add_ t v = _add t t v

-- | add a scalar to a tensor, pure.
--
-- The sum is written into a tensor allocated with @new' (getSomeDims t)@
-- (presumably a fresh tensor with @t@'s dimensions) rather than into @t@.
add :: Dynamic -> HsReal -> Dynamic
add t v = unsafeDupablePerformIO $ do
  let r = new' (getSomeDims t)
  _add r t v
  pure r
-- NOINLINE keeps the allocation performed under unsafeDupablePerformIO from
-- being inlined, and thereby shared or duplicated, at call sites.
{-# NOINLINE add #-}

-- | subtract a scalar from a tensor, inplace.
--
-- The tensor is passed to the C-style worker as both result and source.
sub_ :: Dynamic -> HsReal -> IO ()
sub_ t v = _sub t t v

-- | subtract a scalar from a tensor, pure.
--
-- The difference is written into a tensor allocated with
-- @new' (getSomeDims t)@ (presumably a fresh tensor with @t@'s dimensions)
-- rather than into @t@.
sub :: Dynamic -> HsReal -> Dynamic
sub t v = unsafeDupablePerformIO $ do
  let r = new' (getSomeDims t)
  _sub r t v
  pure r
-- NOINLINE keeps the allocation performed under unsafeDupablePerformIO from
-- being inlined, and thereby shared or duplicated, at call sites.
{-# NOINLINE sub #-}

-- | add a scalar, which has been scaled, to a tensor, inplace.
add_scaled_
  :: Dynamic -- ^ tensor to update in place (it is not the thing being scaled)
  -> HsReal  -- ^ value to add
  -> HsReal  -- ^ amount to scale the value by
  -> IO ()
add_scaled_ t v0 v1 = _add_scaled t t v0 v1

-- | subtract a scalar, which has been scaled, from a tensor, inplace.
sub_scaled_
  :: Dynamic -- ^ tensor to update in place (it is not the thing being scaled)
  -> HsReal  -- ^ value to subtract
  -> HsReal  -- ^ amount to scale the value by
  -> IO ()
sub_scaled_ t v0 v1 = _sub_scaled t t v0 v1

-- | multiply a tensor by a scalar value, inplace.
mul_ :: Dynamic -> HsReal -> IO ()
mul_ t v = _mul t t v

-- | multiply a tensor by a scalar value, pure.
--
-- The product is written into a tensor allocated with
-- @new' (getSomeDims t)@ (presumably a fresh tensor with @t@'s dimensions)
-- rather than into @t@.
mul :: Dynamic -> HsReal -> Dynamic
mul t v = unsafeDupablePerformIO $ do
  let r = new' (getSomeDims t)
  _mul r t v
  pure r
-- NOINLINE keeps the allocation performed under unsafeDupablePerformIO from
-- being inlined, and thereby shared or duplicated, at call sites.
{-# NOINLINE mul #-}

-- | divide a tensor by a scalar value, inplace.
div_ :: Dynamic -> HsReal -> IO ()
div_ t v = _div t t v

-- | divide a tensor by a scalar value, pure.
--
-- The quotient is written into a tensor allocated with
-- @new' (getSomeDims t)@ (presumably a fresh tensor with @t@'s dimensions)
-- rather than into @t@. Exported qualified because the name clashes with
-- 'Prelude.div'.
div :: Dynamic -> HsReal -> Dynamic
div t v = unsafeDupablePerformIO $ do
  let r = new' (getSomeDims t)
  _div r t v
  pure r
-- NOINLINE keeps the allocation performed under unsafeDupablePerformIO from
-- being inlined, and thereby shared or duplicated, at call sites.
{-# NOINLINE div #-}

-- | Left shift all elements in the tensor by the given value, inplace.
lshift_ :: Dynamic -> HsReal -> IO ()
lshift_ t v = _lshift t t v

-- | Right shift all elements in the tensor by the given value, inplace.
rshift_ :: Dynamic -> HsReal -> IO ()
rshift_ t v = _rshift t t v

-- | Compute the remainder of division (rounded towards zero) of all elements
-- in the tensor by a given value, inplace.
fmod_ :: Dynamic -> HsReal -> IO ()
fmod_ t v = _fmod t t v

-- | Computes the remainder of division (rounded to nearest) of all elements in
-- the tensor by the given value, inplace.
--
-- NOTE(review): confirm the rounding mode against the C implementation of
-- @c_remainder@; this description is not established by this module.
remainder_ :: Dynamic -> HsReal -> IO ()
remainder_ t v = _remainder t t v

-- | Performs the bitwise operation inplace on all elements in the tensor.
bitand_, bitor_, bitxor_ :: Dynamic -> HsReal -> IO ()
-- Each wrapper hands the same tensor to its worker as both result and source;
-- the scalar is passed straight through.
bitand_ tensor = _bitand tensor tensor
bitor_  tensor = _bitor  tensor tensor
bitxor_ tensor = _bitxor tensor tensor

-- The rest of this module holds the C-style workers behind the API above.
-- Every worker takes the result tensor first, then the source tensor, then the
-- scalar argument(s). It acquires the Torch state and the two tensor handles
-- (managedState / managedTensor, in that order), converts each scalar with
-- 'hs2cReal', and runs the matching @Sig.c_*@ call through 'withLift'.

_add :: Dynamic -> Dynamic -> HsReal -> IO ()
_add res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_add st resP srcP (hs2cReal x))

_sub :: Dynamic -> Dynamic -> HsReal -> IO ()
_sub res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_sub st resP srcP (hs2cReal x))

-- | Worker for the scaled variants: @x@ is the value, @scale@ what it is
-- scaled by.
_add_scaled :: Dynamic -> Dynamic -> HsReal -> HsReal -> IO ()
_add_scaled res src x scale = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_add_scaled st resP srcP (hs2cReal x) (hs2cReal scale))

-- | Worker for the scaled variants: @x@ is the value, @scale@ what it is
-- scaled by.
_sub_scaled :: Dynamic -> Dynamic -> HsReal -> HsReal -> IO ()
_sub_scaled res src x scale = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_sub_scaled st resP srcP (hs2cReal x) (hs2cReal scale))

_mul :: Dynamic -> Dynamic -> HsReal -> IO ()
_mul res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_mul st resP srcP (hs2cReal x))

_div :: Dynamic -> Dynamic -> HsReal -> IO ()
_div res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_div st resP srcP (hs2cReal x))

_lshift :: Dynamic -> Dynamic -> HsReal -> IO ()
_lshift res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_lshift st resP srcP (hs2cReal x))

_rshift :: Dynamic -> Dynamic -> HsReal -> IO ()
_rshift res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_rshift st resP srcP (hs2cReal x))

_fmod :: Dynamic -> Dynamic -> HsReal -> IO ()
_fmod res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_fmod st resP srcP (hs2cReal x))

_remainder :: Dynamic -> Dynamic -> HsReal -> IO ()
_remainder res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_remainder st resP srcP (hs2cReal x))

_bitand :: Dynamic -> Dynamic -> HsReal -> IO ()
_bitand res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_bitand st resP srcP (hs2cReal x))

_bitor :: Dynamic -> Dynamic -> HsReal -> IO ()
_bitor res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_bitor st resP srcP (hs2cReal x))

_bitxor :: Dynamic -> Dynamic -> HsReal -> IO ()
_bitxor res src x = withLift $ do
  st   <- managedState
  resP <- managedTensor res
  srcP <- managedTensor src
  pure (Sig.c_bitxor st resP srcP (hs2cReal x))