{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse #-}
module Torch.Indef.Static.Tensor.TopK where
import Numeric.Dimensions
import System.IO.Unsafe
import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import qualified Torch.Indef.Dynamic.Tensor.TopK as Dynamic
import Torch.Indef.Index
-- | Statically-typed wrapper around 'Dynamic._topk', exposed as a pure
-- function by running the underlying 'IO' action with
-- 'unsafeDupablePerformIO'.
--
-- The two outputs are obtained from 'new' and 'newIx', passed (as their
-- dynamic counterparts) to 'Dynamic._topk' together with the input tensor,
-- and returned as a pair after that call has run.
--
-- The result shape @d'@ and the index length @n@ occur only in the result
-- type, so they are chosen by the caller; nothing in this wrapper checks
-- them against the other arguments.
--
-- NOTE(review): the final argument is named @sorted@ in the body but is
-- typed @Maybe KeepDim@. Its meaning is decided entirely by
-- 'Dynamic._topk' -- confirm against that definition.
topk
  :: forall d' d n
  .  (All Dimensions '[d, d'], KnownDim n)
  => Tensor d       -- ^ input tensor
  -> Integer        -- ^ forwarded unchanged (presumably @k@, the number of entries to select -- confirm)
  -> Word           -- ^ forwarded unchanged (presumably the dimension to operate along -- confirm)
  -> TopKOrder      -- ^ forwarded unchanged
  -> Maybe KeepDim  -- ^ forwarded unchanged (see the review note above)
  -> (Tensor d', IndexTensor '[n])
topk t k d o sorted = unsafeDupablePerformIO $ do
  let
    -- Output for the selected values; the shape comes from the caller's @d'@.
    values :: Tensor d'
    values = new

    -- Output for the matching indices; the length comes from the caller's @n@.
    indices :: IndexTensor '[n]
    indices = newIx
  -- 'Dynamic._topk' receives both outputs as its first argument; its result
  -- is discarded and the same two tensors are returned below.
  Dynamic._topk (asDynamic values, longAsDynamic indices) (asDynamic t) k d o sorted
  pure (values, indices)
-- Presumably kept out of line (together with -fno-cse at the top of the file)
-- because of the unsafe IO above -- confirm before removing.
{-# NOINLINE topk #-}