-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Tensor.Index
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Index operations for a dyanmic tensor.
-------------------------------------------------------------------------------
module Torch.Indef.Dynamic.Tensor.Index
  ( _indexCopy
  , _indexAdd
  , _indexFill
  , _indexSelect
  , _take
  , _put
  ) where

import Foreign
import Foreign.Ptr
import Torch.Sig.Types
import Control.Monad.Managed
import qualified Torch.Sig.Types          as Sig
import qualified Torch.Sig.Types.Global   as Sig
import qualified Torch.Sig.Tensor.Index   as Sig

import Torch.Indef.Types

-- | Copies the elements of tensor into the original tensor by selecting the indices in the order given in index. The shape of tensor must exactly match the elements indexed or an error will be thrown.
_indexCopy :: Dynamic -> Int -> IndexDynamic -> Dynamic -> IO ()
_indexCopy r i ix t = withLift $ Sig.c_indexCopy
  <$> managedState
  <*> managedTensor r
  <*> pure (fromIntegral i)
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t

-- | Accumulate the elements of tensor into the original tensor by adding to the indices in the order given in index. The shape of tensor must exactly match the elements indexed or an error will be thrown.
_indexAdd :: Dynamic -> Int -> IndexDynamic -> Dynamic -> IO ()
_indexAdd r i ix t = withLift $ Sig.c_indexAdd
  <$> managedState
  <*> managedTensor r
  <*> pure (fromIntegral i)
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t

-- | Fills the elements of the original Tensor with value val by selecting the indices in the order given in index.
_indexFill :: Dynamic -> Int -> IndexDynamic -> HsReal -> IO ()
_indexFill r i ix v = withLift $ Sig.c_indexFill
  <$> managedState
  <*> managedTensor r
  <*> pure (fromIntegral i)
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> pure (Sig.hs2cReal v)

-- | Selects the elements of the original Tensor by the index.
_indexSelect :: Dynamic -> Dynamic -> Int -> IndexDynamic -> IO ()
_indexSelect r t i ix = withLift $ Sig.c_indexSelect
  <$> managedState
  <*> managedTensor r
  <*> managedTensor t
  <*> pure (fromIntegral i)
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))

-- | TODO
_take :: Dynamic -> Dynamic -> IndexDynamic -> IO ()
_take r t ix = withLift $ Sig.c_take
  <$> managedState
  <*> managedTensor r
  <*> managedTensor t
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))

-- | TODO
_put :: Dynamic -> IndexDynamic -> Dynamic -> Int -> IO ()
_put r ix t i = withLift $ Sig.c_put
  <$> managedState
  <*> managedTensor r
  <*> managed (withForeignPtr (snd $ Sig.longDynamicState ix))
  <*> managedTensor t
  <*> pure (fromIntegral i)

-- class GPUTensorIndex Dynamic where
--   _indexCopy_long :: t -> Int -> IndexDynamic t -> t -> IO ()
--   _indexAdd_long :: t -> Int -> IndexDynamic t -> t -> IO ()
--   _indexFill_long :: t -> Int -> IndexDynamic t -> Word -> IO ()
--   _indexSelect_long :: t -> t -> Int -> IndexDynamic t -> IO ()
--   _calculateAdvancedIndexingOffsets :: IndexDynamic t -> t -> Integer -> [IndexTensor t] -> IO ()