-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Static.Tensor.TopK
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.TopK where

import Numeric.Dimensions
import System.IO.Unsafe

import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import qualified Torch.Indef.Dynamic.Tensor.TopK as Dynamic
import Torch.Indef.Index

-- | Static-dimension wrapper around 'Dynamic._topk'.
--
-- A fresh result tensor (via 'new') and a fresh index tensor (via 'newIx') are
-- created, handed to 'Dynamic._topk' together with the input tensor and the
-- remaining arguments, and then returned as a pair.
--
-- The output shape and the index length are not derived from the input here;
-- they are fixed by the result type the caller requests.
--
-- NOTE(review): an earlier comment described this as returning the @k@
-- smallest elements in unsorted order, yet the call also takes a 'TopKOrder';
-- confirm the exact semantics against 'Dynamic._topk'.
topk
  :: forall d' d n
  .  (All Dimensions '[d, d'], KnownDim n)
  => Tensor d       -- ^ input tensor
  -> Integer        -- ^ @k@, forwarded unchanged (presumably the number of elements to select)
  -> Word           -- ^ forwarded unchanged (presumably the dimension to select over)
  -> TopKOrder      -- ^ forwarded unchanged
  -> Maybe KeepDim  -- ^ forwarded unchanged; bound to the name @sorted@ below -- TODO confirm meaning
  -> (Tensor d', IndexTensor '[n])
topk t k d o sorted = unsafeDupablePerformIO $ do
  -- Output buffers are bound inside the IO action; the result of the call
  -- below is discarded, so they are presumably filled in place by it.
  let values  = new   :: Tensor d'
      indices = newIx :: IndexTensor '[n]
  Dynamic._topk (asDynamic values, longAsDynamic indices) (asDynamic t) k d o sorted
  pure (values, indices)
{-# NOINLINE topk #-}