{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Utility.Ord where

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
          (Exp, Acc, Array, Elt, Slice, Shape, IsScalar, Scalar,
           (?), (<=*), )


-- | Pick whichever of two @(key, payload)@ pairs carries the smaller key.
--
-- Only the first components are compared (with '<=*'); the second component
-- is an arbitrary payload (e.g. an index) that travels along unchanged.
-- When the keys are equal the first argument wins. Always favouring the same
-- side on ties keeps the operator associative, which the reduction in
-- 'argminimum' relies on.
argmin ::
   (Elt a, Elt b, IsScalar a) =>
   Exp (a, b) -> Exp (a, b) -> Exp (a, b)
argmin x y =
   let keyX = A.fst x
       keyY = A.fst y
   in  (keyX <=* keyY) ? (x, y)

-- | Reduce an entire array of @(key, payload)@ pairs to the single pair whose
-- key is smallest, combining elements with 'argmin'.
--
-- 'argmin' keeps its first argument on equal keys. Assuming 'A.fold1All'
-- combines elements in the array's linear order, the earliest minimum is
-- returned.
--
-- NOTE(review): 'A.fold1All' takes no seed element, so the input is
-- presumably required to be non-empty; confirm against the Accelerate docs.
argminimum ::
   (Slice sh, Shape sh, Elt a, Elt b, IsScalar a) =>
   Acc (Array sh (a, b)) -> Acc (Scalar (a, b))
argminimum arr = A.fold1All argmin arr


-- | Pick whichever of two @(key, payload)@ pairs carries the larger key.
--
-- Only the first components are compared (with '<=*'); the second component
-- is an arbitrary payload (e.g. an index) that travels along unchanged.
--
-- When the keys are equal the /first/ argument wins, the same rule 'argmin'
-- uses. The previous form @fst x <=* fst y ? (y, x)@ returned the second
-- argument on ties, so 'argmaximum' reported the last maximum while
-- 'argminimum' reported the first minimum. Always favouring the same side on
-- ties keeps the operator associative, as required by 'A.fold1All'.
argmax ::
   (Elt a, Elt b, IsScalar a) =>
   Exp (a, b) -> Exp (a, b) -> Exp (a, b)
argmax x y  =  (A.fst y <=* A.fst x) ? (x, y)

-- | Reduce an entire array of @(key, payload)@ pairs to the single pair whose
-- key is largest, combining elements with 'argmax'.
--
-- NOTE(review): 'A.fold1All' takes no seed element, so the input is
-- presumably required to be non-empty; confirm against the Accelerate docs.
argmaximum ::
   (Slice sh, Shape sh, Elt a, Elt b, IsScalar a) =>
   Acc (Array sh (a, b)) -> Acc (Scalar (a, b))
argmaximum arr = A.fold1All argmax arr