-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Mask
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Redundant version of @Torch.Indef.{Dynamic/Static}.Tensor@ for Byte tensors.
--
-- This comes with the same fixme as 'Torch.Indef.Index':
--
-- FIXME: in the future, there could be a smaller subset of Torch which could
-- be compiled to keep the code DRY. Alternatively, if backpack one day
-- supports recursive indefinites, we could use this feature to possibly remove
-- this package and 'Torch.Indef.Index'.
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
module Torch.Indef.Mask
  ( newMask
  , newMaskDyn
  , newMaskDyn'
  , withMask
  , allOf
  ) where

import Foreign
import Foreign.Ptr
import Data.Proxy
import Data.List
import Control.Monad
import System.IO.Unsafe

import Numeric.Dimensions
import Torch.Sig.Types.Global
import Torch.Indef.Internal
import Control.Monad.Managed as X

import Torch.Sig.State as Sig
import qualified Torch.Types.TH as TH
import qualified Torch.Sig.Mask.Tensor as MaskSig
import qualified Torch.Sig.Mask.MathReduce as MaskSig
import qualified Torch.Sig.Mask.TensorFree as MaskSig

-- | Construct a statically-shaped mask tensor for any type-level dimension
-- list that has a 'Dimensions' instance.
--
-- The shape is reified from the type, handed to 'newMaskDyn', and the result
-- is converted with 'byteAsStatic'.
newMask :: forall d . Dimensions d => MaskTensor d
newMask = byteAsStatic (newMaskDyn shape)
  where
    -- Term-level witness of the type-level shape @d@.
    shape :: Dims d
    shape = dims

-- | Allocate a fresh dynamic mask tensor whose shape is given by a 'Dims'
-- value.
--
-- Ranks one through four map onto the matching @c_newWithSize{1..4}d@
-- allocator. An empty dimension list is allocated as a one-dimensional tensor
-- of size one. Any higher rank is unsupported and calls 'error'.
--
-- The raw pointer is wrapped in a 'ForeignPtr' whose finalizer is
-- @MaskSig.p_free@, with the torch state pointer as its environment.
newMaskDyn :: Dims (d::[Nat]) -> MaskDynamic
newMaskDyn d = unsafeDupablePerformIO (withForeignPtr Sig.torchstate build)
  where
    -- Allocate the tensor, attach its finalizer, and wrap the result.
    build st = do
      raw        <- allocate st (map fromIntegral (listDims d))
      managedPtr <- newForeignPtrEnv MaskSig.p_free st raw
      pure (byteDynamic Sig.torchstate managedPtr)

    -- Select the fixed-arity allocator that matches the requested rank.
    allocate st sizes = case sizes of
      []           -> MaskSig.c_newWithSize1d st 1
      [a]          -> MaskSig.c_newWithSize1d st a
      [a, b]       -> MaskSig.c_newWithSize2d st a b
      [a, b, c]    -> MaskSig.c_newWithSize3d st a b c
      [a, b, c, e] -> MaskSig.c_newWithSize4d st a b c e
      _ -> error "FIXME: can't build masks of this size yet"

-- | Variant of 'newMaskDyn' that accepts its dimensions wrapped in 'SomeDims'.
newMaskDyn' :: SomeDims -> MaskDynamic
newMaskDyn' shape = case shape of
  SomeDims d -> newMaskDyn d

-- | Run an 'IO' action with the raw C pointer that backs a dynamic mask
-- tensor.
--
-- The pointer is supplied through 'withForeignPtr' on the tensor component of
-- 'byteDynamicState'; the state component is not used here.
withMask :: MaskDynamic -> (Ptr CMaskTensor -> IO x) -> IO x
withMask mask action = withForeignPtr tensorPtr action
  where
    (_, tensorPtr) = byteDynamicState mask

-- | Types that can be interrogated as Byte mask tensors.
class IsMask mask where
  -- | Check that every value held by the Byte tensor is true.
  allOf :: mask -> Bool
  -- anyOf :: mask -> Bool

instance IsMask MaskDynamic where
  -- A mask is treated as all-true when the sum over its elements equals the
  -- element count, where the count is the product of the per-dimension sizes.
  allOf mask = unsafePerformIO $ X.with pointers check
    where
      (statePtr, tensorPtr) = byteDynamicState mask

      -- Acquire the raw state pointer first, then the raw tensor pointer.
      pointers =
        (,) <$> managed (withForeignPtr statePtr)
            <*> managed (withForeignPtr tensorPtr)

      -- Query rank, then each dimension's size in order, then the total sum.
      check (st, tp) = do
        rank    <- MaskSig.c_nDimension st tp
        extents <- mapM (MaskSig.c_size st tp . fromIntegral) [0 .. rank - 1]
        total   <- MaskSig.c_sumall st tp
        pure (total == fromIntegral (product extents))

instance IsMask (MaskTensor d) where
  -- Convert with 'byteAsDynamic' and reuse the 'MaskDynamic' instance.
  allOf mask = allOf (byteAsDynamic mask)