{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                  ~ 2021.10.17
-- |
-- Module      :  Data.List.Extras.Argmax
-- Copyright   :  Copyright (c) 2007--2021 wren gayle romano
-- License     :  BSD3
-- Maintainer  :  wren@cpan.org
-- Stability   :  experimental
-- Portability :  Haskell98
--
-- This module provides variants of the 'maximum' and 'minimum'
-- functions which return the elements for which some function is
-- maximized or minimized.
----------------------------------------------------------------

module Data.List.Extras.Argmax
    (
    -- * Utility functions
      catchNull

    -- * Generic versions
    , argmaxBy, argmaxesBy, argmaxWithMaxBy, argmaxesWithMaxBy

    -- * Maximum variations
    , argmax,   argmaxes,   argmaxWithMax,   argmaxesWithMax

    -- * Minimum variations
    , argmin,   argmins,    argminWithMin,   argminsWithMin

    {- TODO: CPS and monadic variants; argmax2, argmax3,... -}
    {- TODO: make sure argmax et al are "good consumers" for fusion -}
    ) where
-- argmaxM       :: (Monad m, Ord b) => (a -> m b) -> [a] -> m (Maybe a)

import Data.List (foldl')

----------------------------------------------------------------
----------------------------------------------------------------

-- | Apply a list function safely, i.e. only when the list is non-empty.
--
-- @catchNull f xs@ is 'Nothing' for the empty list and @'Just' (f xs)@
-- otherwise; @f@ is never applied to an empty list. All other
-- functions in this module throw errors on empty lists, so use this
-- to make your own safe variants, e.g. @catchNull (argmax g)@.
catchNull :: ([a] -> b) -> ([a] -> Maybe b)
{-# INLINE catchNull #-}
-- We use the explicit lambda in order to improve inlining in ghc-7.
catchNull f = \xs ->
    case xs of
        []  -> Nothing
        _:_ -> Just (f xs)


-- | Throw the error reported by this module's functions when they are
-- given an empty list. @fun@ is the name of the failing function; the
-- message is @\"Data.List.Extras.Argmax.\<fun\>: empty list\"@.
--
-- A single shared, NOINLINE helper keeps the number of string
-- literals in the compiled library down.
emptyListError :: String -> a
{-# NOINLINE emptyListError #-}
emptyListError fun =
    error ("Data.List.Extras.Argmax." ++ fun ++ ": empty list")


-- | Apply a list function unsafely. For internal use.
--
-- @throwNull fun f@ splits a non-empty list into its head and tail and
-- passes both to @f@. On the empty list it throws the error from
-- 'emptyListError', naming the public function @fun@.
throwNull :: String -> (a -> [a] -> b) -> ([a] -> b)
{-# INLINE throwNull #-}
-- We use the explicit lambda in order to improve inlining in ghc-7.
throwNull fun f = \xs ->
    case xs of
        []      -> emptyListError fun
        x : xs' -> f x xs'

----------------------------------------------------------------
----------------------------------------------------------------
-- | Tail-recursive driver for the single-result functions.
--
-- @_argmaxWithMaxBy isBetterThan f x xs@ scans @x:xs@ left to right.
-- It keeps the best element seen so far together with its @f@-value.
-- A later element replaces the current best only when @isBetterThan@
-- returns 'True' for (new value, current value). With a strict
-- comparison such as @(>)@ or @(<)@, ties therefore keep the earliest
-- element.
_argmaxWithMaxBy :: (b -> b -> Bool) -> (a -> b) -> a -> [a] -> (a,b)
{-# INLINE _argmaxWithMaxBy #-}
-- We use the explicit lambda in order to improve inlining in ghc-7.
_argmaxWithMaxBy isBetterThan f =
    \x xs -> foldl' step (x, f x) xs
    where
      -- f is evaluated once per element; the shared fa is reused.
      step best@(_, fb) a =
          let fa = f a
          in  if fa `isBetterThan` fb then (a, fa) else best


-- | Tail-recursive driver for the all-results functions.
--
-- @_argmaxesWithMaxBy isBetterEqualThan f x xs@ scans @x:xs@ left to
-- right. It keeps every element whose @f@-value ties the best seen so
-- far, together with that value. For (new value, current best):
--
-- * 'GT' restarts the accumulator with the new element;
-- * 'EQ' prepends the new element, so the result lists tied elements
--   in reverse order of appearance;
-- * 'LT' discards the new element.
_argmaxesWithMaxBy :: (b -> b -> Ordering) -> (a -> b) -> a -> [a] -> ([a],b)
{-# INLINE _argmaxesWithMaxBy #-}
-- We use the explicit lambda in order to improve inlining in ghc-7.
_argmaxesWithMaxBy isBetterEqualThan f =
    \x xs -> foldl' step ([x], f x) xs
    where
      step acc@(bs, fb) a =
          let fa = f a
          in  case isBetterEqualThan fa fb of
                  GT -> ([a], fa)
                  EQ -> (a : bs, fb)
                  LT -> acc


-- | Like '_argmaxWithMaxBy', but return only the best element and
-- drop its @f@-value.
_argmaxBy :: (b -> b -> Bool) -> (a -> b) -> a -> [a] -> a
{-# INLINE _argmaxBy #-}
-- We use the point-free style in order to improve inlining in ghc-7.
_argmaxBy k f = (fst .) . _argmaxWithMaxBy k f


-- | Like '_argmaxesWithMaxBy', but return only the best elements and
-- drop their shared @f@-value.
_argmaxesBy :: (b -> b -> Ordering) -> (a -> b) -> a -> [a] -> [a]
{-# INLINE _argmaxesBy #-}
-- We use the point-free style in order to improve inlining in ghc-7.
_argmaxesBy k f = (fst .) . _argmaxesWithMaxBy k f

----------------------------------------------------------------
----------------------------------------------------------------

-- | Turn an 'Ordering'-valued comparison into a strict
-- \"greater than\" predicate: @bool ord a b@ holds exactly when
-- @ord a b == GT@.
bool :: (a -> a -> Ordering) -> (a -> a -> Bool)
bool ord = \a b -> ord a b == GT


-- | Return an element of the list which maximizes the function
-- according to a user-defined ordering. On ties, the earliest such
-- element is returned. Throws an error on the empty list.
argmaxBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> a
argmaxBy ord f = throwNull "argmaxBy" $ _argmaxBy (bool ord) f


-- | Return all elements of the list which maximize the function
-- according to a user-defined ordering (currently in reverse order
-- of appearance). Throws an error on the empty list.
argmaxesBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> [a]
argmaxesBy ord f = throwNull "argmaxesBy" $ _argmaxesBy ord f


-- | Return an element of the list which maximizes the function
-- according to a user-defined ordering, together with the value of
-- the function at that element. On ties, the earliest such element is
-- returned. Throws an error on the empty list.
argmaxWithMaxBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> (a, b)
argmaxWithMaxBy ord f =
    throwNull "argmaxWithMaxBy" $ _argmaxWithMaxBy (bool ord) f


-- | Return all elements of the list which maximize the function
-- according to a user-defined ordering (currently in reverse order
-- of appearance), together with the value of the function at those
-- elements. Throws an error on the empty list.
argmaxesWithMaxBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> ([a], b)
argmaxesWithMaxBy ord f =
    throwNull "argmaxesWithMaxBy" $ _argmaxesWithMaxBy ord f

----------------------------------------------------------------
-- SPECIALIZE pragmas for b in {Int, Integer, Float, Double} on the
-- four functions below nearly double the library size (about +21kB).
-- For a basic utility library that's a bit excessive, though if
-- we break the argmax stuff out from list-extras then we might go
-- through with it for performance's sake.

-- | Return an element of the list which maximizes the function.
-- On ties, the earliest such element is returned. Throws an error on
-- the empty list.
argmax :: (Ord b) => (a -> b) -> [a] -> a
argmax f = throwNull "argmax" $ _argmaxBy (>) f

-- | Return all elements of the list which maximize the function
-- (currently in reverse order of appearance). Throws an error on the
-- empty list.
argmaxes :: (Ord b) => (a -> b) -> [a] -> [a]
argmaxes f = throwNull "argmaxes" $ _argmaxesBy compare f


-- | Return an element of the list which maximizes the function,
-- together with the value of the function at that element. On ties,
-- the earliest such element is returned. Throws an error on the
-- empty list.
argmaxWithMax :: (Ord b) => (a -> b) -> [a] -> (a, b)
argmaxWithMax f = throwNull "argmaxWithMax" $ _argmaxWithMaxBy (>) f


-- | Return all elements of the list which maximize the function
-- (currently in reverse order of appearance), together with the
-- value of the function at those elements. Throws an error on the
-- empty list.
argmaxesWithMax :: (Ord b) => (a -> b) -> [a] -> ([a], b)
argmaxesWithMax f =
    throwNull "argmaxesWithMax" $ _argmaxesWithMaxBy compare f

----------------------------------------------------------------

-- | Return an element of the list which minimizes the function.
-- On ties, the earliest such element is returned. Throws an error on
-- the empty list.
argmin :: (Ord b) => (a -> b) -> [a] -> a
-- Report "argmin" (not "argmax") so the empty-list error names the
-- function the caller actually used.
argmin f = throwNull "argmin" $ _argmaxBy (<) f

-- | Return all elements of the list which minimize the function
-- (currently in reverse order of appearance). Throws an error on the
-- empty list.
argmins :: (Ord b) => (a -> b) -> [a] -> [a]
-- Flipping 'compare' makes smaller values rank as "greater", so the
-- maximizing driver finds the minimizers.
argmins f = throwNull "argmins" $ _argmaxesBy (flip compare) f


-- | Return an element of the list which minimizes the function,
-- together with the value of the function at that element. On ties,
-- the earliest such element is returned. Throws an error on the
-- empty list.
argminWithMin :: (Ord b) => (a -> b) -> [a] -> (a, b)
argminWithMin f = throwNull "argminWithMin" $ _argmaxWithMaxBy (<) f

-- | Return all elements of the list which minimize the function
-- (currently in reverse order of appearance), together with the
-- value of the function at those elements. Throws an error on the
-- empty list.
argminsWithMin :: (Ord b) => (a -> b) -> [a] -> ([a], b)
-- Flipping 'compare' makes smaller values rank as "greater", so the
-- maximizing driver finds the minimizers.
argminsWithMin f =
    throwNull "argminsWithMin" $ _argmaxesWithMaxBy (flip compare) f

----------------------------------------------------------------
----------------------------------------------------------- fin.