-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Static.Tensor.Index
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.Index where

import Numeric.Dimensions
import GHC.TypeLits
import Control.Exception.Safe
import System.IO.Unsafe

import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import qualified Torch.Indef.Index as Ix
import qualified Torch.Indef.Dynamic.Tensor as Dynamic
import qualified Torch.Indef.Dynamic.Tensor.Index as Dynamic

-- | Static call to 'Dynamic._indexCopy'.
--
-- Both static tensors are unwrapped with 'asDynamic' and the index tensor
-- with 'longAsDynamic'. The 'Int' argument is forwarded unchanged.
_indexCopy :: Tensor d -> Int -> IndexTensor '[n] -> Tensor d' -> IO ()
_indexCopy target axis picks source =
  Dynamic._indexCopy dynTarget axis (longAsDynamic picks) dynSource
  where
    dynTarget = asDynamic target
    dynSource = asDynamic source

-- | Static call to 'Dynamic._indexAdd'.
--
-- Converts the static tensors with 'asDynamic' and the index tensor with
-- 'longAsDynamic'. All other arguments pass straight through.
_indexAdd :: Tensor d -> Int -> IndexTensor '[n] -> Tensor d' -> IO ()
_indexAdd target axis picks source =
  let dynIx = longAsDynamic picks
  in  Dynamic._indexAdd (asDynamic target) axis dynIx (asDynamic source)

-- | Static call to 'Dynamic._indexFill'.
--
-- The trailing 'HsReal' fill value is passed through by partial
-- application.
_indexFill :: Tensor d -> Int -> IndexTensor '[n] -> HsReal -> IO ()
_indexFill target axis picks =
  Dynamic._indexFill (asDynamic target) axis (longAsDynamic picks)

-- | Static call to 'Dynamic._indexSelect'.
--
-- The first tensor receives the result and the second is the source. The
-- 'Int' selects the dimension to index along; 'getRow' and 'getColumn'
-- pass 0 and 1 respectively.
_indexSelect :: Tensor d -> Tensor d' -> Int -> IndexTensor '[n] -> IO ()
_indexSelect out src axis picks =
  Dynamic._indexSelect dynOut dynSrc axis dynIx
  where
    dynOut = asDynamic out
    dynSrc = asDynamic src
    dynIx  = longAsDynamic picks

-- | Static call to 'Dynamic._take'.
--
-- Unwraps both static tensors and the index tensor before delegating.
_take :: Tensor d -> Tensor d' -> IndexTensor '[n] -> IO ()
_take out src picks =
  Dynamic._take (asDynamic out) (asDynamic src) $ longAsDynamic picks

-- | Static call to 'Dynamic._put'.
--
-- Unwraps both static tensors and the index tensor. The trailing 'Int' is
-- forwarded unchanged.
_put :: Tensor d -> IndexTensor '[n] -> Tensor d' -> Int -> IO ()
_put target picks source flag =
  Dynamic._put dynTarget dynIx dynSource flag
  where
    dynTarget = asDynamic target
    dynIx     = longAsDynamic picks
    dynSource = asDynamic source


-- | Retrieve a single row from a matrix.
--
-- Returns 'Nothing' when the row index is out of range. The index is handed
-- to the backend index-select unchanged, so the valid range is @0 .. n - 1@
-- (assumes the backend uses 0-based indices — TODO confirm). An index equal
-- to @n@ is therefore rejected rather than passed through.
--
-- The selection runs under 'unsafeDupablePerformIO' into a freshly
-- allocated result tensor. The NOINLINE pragma below, together with the
-- module-wide @-fno-cse@, keeps GHC from sharing or duplicating that
-- allocation between calls.
--
-- FIXME: take an 'Idx' instead of a runtime-checked 'Word' so the bound is
-- checked statically.
getRow
  :: forall t n m . (All KnownDim '[n, m], KnownNat m)
  => Tensor '[n, m] -> Word -> Maybe (Tensor '[1, m])
getRow t r
  -- '>=' rather than '>': r == n is one past the last row.
  | r >= dimVal (dim :: Dim n) = Nothing
  | otherwise = unsafeDupablePerformIO $ do
      -- fresh 1 x m buffer that receives the selected row
      let res = Dynamic.new (dims :: Dims '[1, m])
      -- single-element index tensor holding the requested row number
      let ixs = Ix.indexDyn [ fromIntegral r ]
      -- select along dimension 0 (rows) of the source into 'res'
      Dynamic._indexSelect res (asDynamic t) 0 ixs
      pure . Just $ asStatic res
{-# NOINLINE getRow #-}

-- | Retrieve a single column from a matrix.
--
-- Returns 'Nothing' when the column index is out of range. The index is
-- handed to the backend index-select unchanged, so the valid range is
-- @0 .. m - 1@ (assumes the backend uses 0-based indices — TODO confirm).
-- An index equal to @m@ is therefore rejected rather than passed through.
--
-- The selection runs under 'unsafeDupablePerformIO' into a freshly
-- allocated result tensor. The NOINLINE pragma below, together with the
-- module-wide @-fno-cse@, keeps GHC from sharing or duplicating that
-- allocation between calls.
--
-- FIXME: take an 'Idx' instead of a runtime-checked 'Word' so the bound is
-- checked statically.
getColumn
  :: forall t n m . (All KnownDim '[n, m], KnownNat n)
  => Tensor '[n, m] -> Word -> Maybe (Tensor '[n, 1])
getColumn t r
  -- '>=' rather than '>': r == m is one past the last column.
  | r >= dimVal (dim :: Dim m) = Nothing
  | otherwise = unsafeDupablePerformIO $ do
      -- fresh n x 1 buffer that receives the selected column
      let res = Dynamic.new (dims :: Dims '[n, 1])
      -- single-element index tensor holding the requested column number
      let ixs = Ix.indexDyn [ fromIntegral r ]
      -- select along dimension 1 (columns) of the source into 'res'
      Dynamic._indexSelect res (asDynamic t) 1 ixs
      pure . Just $ asStatic res
{-# NOINLINE getColumn #-}