-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Masked
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Operations using a mask tensor to filter which elements will be used.
-------------------------------------------------------------------------------
module Torch.Indef.Dynamic.Tensor.Masked where

import Foreign hiding (with)
import Control.Monad.Managed
import qualified Torch.Sig.Types as Sig
import qualified Torch.Sig.Types.Global  as Sig
import qualified Torch.Sig.Tensor.Masked as Sig

import Torch.Indef.Types

-- | fill a dynamic tensor with a value, filtered by a boolean mask tensor
-- | Fill a dynamic tensor, in place, with a single value at the positions
-- picked out by a boolean (byte) mask tensor.
--
-- The mask's C pointer is obtained from the second component of
-- 'Sig.byteDynamicState'. All pointers are held only for the duration of the
-- underlying @c_maskedFill@ call.
maskedFill_
  :: Dynamic     -- ^ tensor to mutate in place
  -> MaskDynamic -- ^ mask selecting which positions receive the value
  -> HsReal      -- ^ value written at every masked position
  -> IO ()
maskedFill_ d m v = withLift $ do
  statePtr  <- managedState
  tensorPtr <- managedTensor d
  maskPtr   <- managed (withForeignPtr (snd (Sig.byteDynamicState m)))
  pure (Sig.c_maskedFill statePtr tensorPtr maskPtr (hs2cReal v))

-- | copy a dynamic tensor with a value, filtered by a boolean mask tensor
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
-- | Copy elements from a source tensor into the return tensor at the
-- positions selected by a boolean (byte) mask tensor.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
--
-- The mask's C pointer is obtained from the second component of
-- 'Sig.byteDynamicState'. Arguments reach @c_maskedCopy@ in the order
-- (state, return tensor, mask, source).
--
-- NOTE(review): assumed to follow TH's @maskedCopy@ semantics. Under those
-- semantics, source elements are consumed in order and written to each masked
-- position of the return tensor. Confirm against the C backend.
_maskedCopy
  :: Dynamic     -- ^ return tensor to mutate, inplace
  -> MaskDynamic -- ^ mask choosing which positions of the return tensor are written
  -> Dynamic     -- ^ source tensor to copy from
  -> IO ()
_maskedCopy t m f = withLift $ Sig.c_maskedCopy
  <$> managedState
  <*> managedTensor t
  <*> managed (withForeignPtr (snd $ Sig.byteDynamicState m))
  <*> managedTensor f

-- | select a dynamic tensor with a value, filtered by a boolean mask tensor
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
-- | Select the elements of a source tensor at the positions marked by a
-- boolean (byte) mask tensor, writing them into the return tensor.
--
-- C-Style: In the classic Torch C-style, the first argument is treated as the return type and is mutated in-place.
--
-- The mask's C pointer is obtained from the second component of
-- 'Sig.byteDynamicState'. Arguments reach @c_maskedSelect@ in the order
-- (state, return tensor, source, mask).
--
-- NOTE(review): assumed to follow TH's @maskedSelect@ semantics. Under those
-- semantics, the return tensor ends up holding one element per set mask
-- entry. Confirm against the C backend.
_maskedSelect
  :: Dynamic     -- ^ return tensor to mutate, inplace
  -> Dynamic     -- ^ source tensor to select from
  -> MaskDynamic -- ^ mask choosing which source elements are selected
  -> IO ()
_maskedSelect t sel m = withLift $ Sig.c_maskedSelect
  <$> managedState
  <*> managedTensor t
  <*> managedTensor sel
  <*> managed (withForeignPtr (snd $ Sig.byteDynamicState m))

-- class GPUTensorMasked t where
--   maskedFillByte   :: t -> MaskDynamic t -> HsReal t -> io ()
--   maskedCopyByte   :: t -> MaskDynamic t -> t -> io ()
--   maskedSelectByte :: t -> t -> MaskDynamic t -> io ()