-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.ScatterGather
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
module Torch.Indef.Dynamic.Tensor.ScatterGather where

import Control.Monad.Managed
import Torch.Indef.Types
import qualified Torch.Indef.Index as Ix
import qualified Torch.Sig.Tensor.ScatterGather as Sig

-- | Low-level gather: unwraps its arguments to raw pointers and hands them to
-- 'Sig.c_gather'.
--
-- The Lua Torch documentation describes the operation as follows:
--
-- Creates a new Tensor from the original tensor by gathering a number of
-- values from each "row", where the rows are along the dimension dim. The
-- values in a LongTensor, passed as index, specify which values to take from
-- each row. Specifically, the resulting Tensor, which will have the same size
-- as the index tensor, is given by
--
-- @
--   -- dim = 1
--   result[i][j][k]... = src[index[i][j][k]...][j][k]...
--
--   -- dim = 2
--   result[i][j][k]... = src[i][index[i][j][k]...][k]...
--
--   -- etc.
-- @
--
-- where src is the original Tensor.
--
-- The same number of values are selected from each row, and the same value
-- cannot be selected from a row more than once. The values in the index tensor
-- must not be larger than the length of the row, that is they must be between
-- 1 and src:size(dim) inclusive. It can be somewhat confusing to ensure that
-- the index tensor has the correct shape. (The picture that accompanies this
-- text in the Lua docs is not reproduced here.)
--
-- NOTE(review): the quoted text is about the Lua API, which returns a new
-- tensor and counts indices from 1. This binding returns @()@ and takes an
-- extra tensor argument instead; whether its indices and dimension are counted
-- from 0 or from 1 cannot be told from this module -- confirm against the C
-- implementation.
_gather
  :: Dynamic            -- ^ tensor passed to the C call ahead of the source (presumably the destination -- TODO confirm)
  -> Dynamic            -- ^ source tensor
  -> Word               -- ^ dimension to operate on
  -> IndexDynamic       -- ^ index tensor
  -> IO ()
_gather r src d ix =
  with2DynamicState r src $ \state resPtr srcPtr ->
    -- The state supplied by the index wrapper is ignored; the one obtained
    -- from the tensors is what gets passed to C.
    Ix.withDynamicState ix $ \_ ixPtr ->
      Sig.c_gather state resPtr srcPtr dim ixPtr
  where
    -- Convert the 'Word' dimension to whatever integral type the C signature expects.
    dim = fromIntegral d

-- | Low-level scatter: unwraps its arguments to raw pointers and hands them to
-- 'Sig.c_scatter'.
--
-- The Lua Torch documentation describes the operation as follows:
--
-- Writes all values from tensor src or the scalar val into self at the
-- specified indices. The indices are specified with respect to the given
-- dimension, dim, in the manner described in gather. Note that, as for gather,
-- the values of index must be between 1 and self:size(dim) inclusive and all
-- values in a row along the specified dimension must be unique.
--
-- NOTE(review): the quoted index range is the Lua (1-based) convention;
-- whether this binding counts from 0 or 1 is not visible in this module --
-- confirm. The scalar variant mentioned above is handled by a separate
-- C entry point, not by this function.
_scatter
  :: Dynamic            -- ^ tensor passed first to the C call (\"self\" in the Lua text, presumably the destination -- TODO confirm)
  -> Word               -- ^ dimension to operate on
  -> IndexDynamic       -- ^ index tensor
  -> Dynamic            -- ^ source tensor
  -> IO ()
_scatter r d ix src = with2DynamicState r src scatterInto
  where
    -- Callback run once both tensors are unwrapped. The state supplied by the
    -- index wrapper is ignored in favour of the tensors' one.
    scatterInto state dstPtr srcPtr =
      Ix.withDynamicState ix $ \_ ixPtr ->
        Sig.c_scatter state dstPtr (fromIntegral d) ixPtr srcPtr

-- | Low-level wrapper around 'Sig.c_scatterAdd'.
--
-- Unwraps both tensors and the index tensor to raw pointers, converts the
-- dimension with 'fromIntegral', and forwards everything to the C function in
-- the order: state, first tensor, dimension, index, source.
--
-- NOTE(review): judging by the name, this presumably accumulates the source
-- values into the first tensor at the indexed positions rather than
-- overwriting them as '_scatter' does -- confirm against the C implementation.
_scatterAdd
  :: Dynamic            -- ^ tensor passed first to the C call (presumably the destination -- TODO confirm)
  -> Word               -- ^ dimension to operate on
  -> IndexDynamic       -- ^ index tensor
  -> Dynamic            -- ^ source tensor
  -> IO ()
_scatterAdd r d ix src =
  with2DynamicState r src
    (\state dstPtr srcPtr ->
      -- The state supplied by the index wrapper is ignored.
      Ix.withDynamicState ix
        (\_ ixPtr -> Sig.c_scatterAdd state dstPtr (fromIntegral d) ixPtr srcPtr))

-- | Low-level wrapper around 'Sig.c_scatterFill'.
--
-- Acquires the state via 'managedState' and the tensor pointer via
-- 'managedTensor' (in that order), unwraps the index tensor, converts the
-- dimension with 'fromIntegral' and the scalar with 'hs2cReal', and forwards
-- everything to the C function.
--
-- NOTE(review): judging by the name, this is presumably the scalar form of
-- '_scatter', writing one value at every indexed position -- confirm against
-- the C implementation. Unlike its siblings, this function obtains the state
-- from 'managedState' rather than from 'with2DynamicState'; whether the two
-- always yield the same state is not visible in this module.
_scatterFill
  :: Dynamic            -- ^ tensor passed to the C call (presumably the destination -- TODO confirm)
  -> Word               -- ^ dimension to operate on
  -> IndexDynamic       -- ^ index tensor
  -> HsReal             -- ^ scalar value forwarded to the C call
  -> IO ()
_scatterFill r d ix v =
  runManaged $
    managedState >>= \state ->
      managedTensor r >>= \dstPtr ->
        liftIO (fillAt state dstPtr)
  where
    -- Performs the actual C call once the state and tensor pointers are held.
    -- The state supplied by the index wrapper is ignored.
    fillAt state dstPtr =
      Ix.withDynamicState ix $ \_ ixPtr ->
        Sig.c_scatterFill state dstPtr (fromIntegral d) ixPtr (hs2cReal v)