-----------------------------------------------------------------------------
-- |
-- Module      :  Data.OrdList
-- Copyright   :  (c) Leon P Smith 2009
-- License     :  BSD3
--
-- Maintainer  :  leon at melding-monads dot com
-- Stability   :  experimental
-- Portability :  portable
--
-- This module implements bag and set operations on ordered lists.
-- Except for variations of the  'sort' and 'isSorted' functions,
-- every function assumes that any list arguments are sorted lists.
-- Assuming this precondition is met,  every resulting list is also
-- sorted.
--
-- Note that these functions handle multisets, and are left-biased.
-- Thus, even assuming the arguments are sorted,  'isect' may not
-- return the same results as 'Data.List.intersect',  due to multiplicity.
--
-----------------------------------------------------------------------------

module  Data.OrdList
     (  member, memberBy, has, hasBy
     ,  isSorted, isSortedBy
     ,  insertBag, insertBagBy
     ,  insertSet, insertSetBy
     ,  isect, isectBy
     ,  union, unionBy
     ,  minus, minusBy
     ,  xunion, xunionBy
     ,  merge, mergeBy
     ,  subset, subsetBy
     ,  sort, sortBy
     ,  sortOn, sortOn'
     ,  nubSort, nubSortBy
     ,  nubSortOn, nubSortOn'
     ,  nub, nubBy
     )  where

import Data.List(sort,sortBy)
-- |  Tests whether a list is in non-descending order,  i.e. whether every
-- element is @<=@ the one that follows it.  Shorthand for @isSortedBy (<=)@.
isSorted :: (Ord a) => [a] -> Bool
isSorted xs = isSortedBy (<=) xs

-- |  Checks that the given relation holds between every element and its
-- immediate successor.  Lists with fewer than two elements have no adjacent
-- pairs and therefore always pass.
isSortedBy :: (a -> a -> Bool) -> [a] -> Bool
isSortedBy lte xs = and (zipWith lte xs (drop 1 xs))

-- |  Membership test for a sorted list:  'True' exactly when the element
-- occurs in the list.  Elements are compared with 'compare'.
member :: (Ord a) => a -> [a] -> Bool
member x xs = memberBy compare x xs

-- |  Like 'member',  but uses the supplied comparison function.  The list
-- is scanned from the front and the scan is abandoned as soon as an element
-- greater than the target is seen,  which is sound only for sorted input.
memberBy :: (a -> a -> Ordering) -> a -> [a] -> Bool
memberBy cmp x = foldr step False
  where
    -- 'rest' is the lazily computed answer for the remainder of the list
    step y rest = case cmp x y of
                    LT -> False
                    EQ -> True
                    GT -> rest

-- |  'member' with its arguments flipped:  the sorted list comes first and
-- the element to look for second.
has :: (Ord a) => [a] -> a -> Bool
has = hasBy compare

-- |  'memberBy' with its last two arguments flipped:  the sorted list comes
-- before the element to look for.
hasBy :: (a -> a -> Ordering) -> [a] -> a -> Bool
hasBy cmp = flip (memberBy cmp)

-- |  Adds an element to a sorted list.  The element is always inserted,
-- even when an equal element is already present (bag semantics).
insertBag :: (Ord a) => a -> [a] -> [a]
insertBag x xs = insertBagBy compare x xs

-- |  Like 'insertBag',  but ordered by the supplied comparison function.
-- The list must be sorted ascending with respect to that function.  The new
-- element is placed in front of the first element that is not smaller than
-- it,  so it precedes any existing elements it compares 'EQ' to.
--
-- > insertBagBy compare 3 [1,2,4,5] == [1,2,3,4,5]
-- > insertBagBy compare 2 [1,2,3]   == [1,2,2,3]
insertBagBy :: (a -> a -> Ordering) -> a -> [a] -> [a]
insertBagBy cmp = loop
  where
    loop x [] = [x]
    loop x (y:ys)
      = case cmp x y of
         -- y is still smaller than x:  keep it and continue searching
         GT -> y:loop x ys
         -- first element that is not smaller than x:  insert here
         _  -> x:y:ys

-- |  Adds an element to a sorted list unless an equal element is already
-- present,  in which case the list is returned unchanged (set semantics).
insertSet :: (Ord a) => a -> [a] -> [a]
insertSet x xs = insertSetBy compare x xs

-- |  Like 'insertSet',  but ordered by the supplied comparison function.
-- The list must be sorted ascending with respect to that function.  If an
-- element comparing 'EQ' to the new one is found,  the existing element is
-- kept and the new one is discarded.
--
-- > insertSetBy compare 3 [1,2,4] == [1,2,3,4]
-- > insertSetBy compare 2 [1,2,3] == [1,2,3]
insertSetBy :: (a -> a -> Ordering) -> a -> [a] -> [a]
insertSetBy cmp = loop
  where
    loop x [] = [x]
    loop x (y:ys) = case cmp x y of
            -- x belongs in front of the first larger element
            LT -> x:y:ys
            -- already present:  leave the list untouched
            EQ -> y:ys
            -- y is still smaller than x:  keep it and continue searching
            GT -> y:loop x ys

-- |  Multiset intersection of two sorted lists:  an element is produced
-- once for every time it can be paired up between the two arguments.
--
-- > isect [1,3,5] [2,4,6] == []
-- > isect [2,4,6,8] [3,6,9] == [6]
-- > isect [1,2,2,2] [1,1,1,2,2] == [1,2,2]
isect :: (Ord a) => [a] -> [a] -> [a]
isect xs ys = isectBy compare xs ys

-- |  Like 'isect',  but uses the supplied comparison function.  When two
-- elements compare 'EQ' the one from the first list is the one returned.
isectBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
isectBy cmp = go
  where
    go as@(a:as') bs@(b:bs') = case cmp a b of
        LT -> go as' bs        -- a has no partner on the right;  drop it
        EQ -> a : go as' bs'   -- matched pair;  emit the left element
        GT -> go as bs'        -- b has no partner on the left;  drop it
    -- as soon as either side is exhausted nothing more can be shared
    go _ _ = []

-- |  Multiset union of two sorted lists:  each element appears as many
-- times as in whichever argument contains more copies of it.
--
-- > union [1,3,5] [2,4,6] == [1..6]
-- > union [2,4,6,8] [3,6,9] == [2,3,4,6,8,9]
-- > union [1,2,2,2] [1,1,1,2,2] == [1,1,1,2,2,2]
union :: (Ord a) => [a] -> [a] -> [a]
union xs ys = unionBy compare xs ys

-- |  Like 'union',  but uses the supplied comparison function.  When two
-- elements compare 'EQ' only the one from the first list is kept.
unionBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
unionBy cmp = go
  where
    go left right = case (left, right) of
      ([], _)      -> right
      (_, [])      -> left
      (l:ls, r:rs) -> case cmp l r of
        LT -> l : go ls right
        EQ -> l : go ls rs      -- one copy stands for the matched pair
        GT -> r : go left rs



-- |  Multiset difference of two sorted lists:  every element of the second
-- list cancels at most one equal element of the first.
--
-- > minus [1,3,5] [2,4,6] == [1,3,5]
-- > minus [2,4,6,8] [3,6,9] == [2,4,8]
-- > minus [1,2,2,2] [1,1,1,2,2] == [2]
minus :: (Ord a) => [a] -> [a] -> [a]
minus xs ys = minusBy compare xs ys

-- |  Like 'minus',  but uses the supplied comparison function.
minusBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
minusBy cmp = go
  where
    go [] _  = []
    go as [] = as
    go as@(a:as') bs@(b:bs') = case cmp a b of
        LT -> a : go as' bs   -- nothing on the right can cancel a
        EQ -> go as' bs'      -- b cancels a
        GT -> go as bs'       -- b cancels nothing;  skip it

-- |  Exclusive union (symmetric difference) of two sorted lists:  equal
-- elements from the two arguments cancel pairwise and whatever is left
-- over on either side is returned.
--
-- > xunion [1,3,5] [2,4,6] == [1..6]
-- > xunion [2,4,6,8] [3,6,9] == [2,3,4,8,9]
-- > xunion [1,2,2,2] [1,1,1,2,2] == [1,1,2]
xunion :: (Ord a) => [a] -> [a] -> [a]
xunion xs ys = xunionBy compare xs ys

-- |  Like 'xunion',  but uses the supplied comparison function.
xunionBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
xunionBy cmp = go
  where
    go left right = case (left, right) of
      ([], _)      -> right
      (_, [])      -> left
      (l:ls, r:rs) -> case cmp l r of
        LT -> l : go ls right
        EQ -> go ls rs          -- matched pair cancels out
        GT -> r : go left rs

{-
genSectBy cmp p = loop
  where
     loop [] ys | p False True = ys
                | otherwise    = []
     loop xs [] | p True False = xs
                | otherwise    = []
     loop (x:xs) (y:ys)
       = case cmp x y of
          LT | p True False -> x : loop xs (y:ys)
             | otherwise    ->     loop xs (y:ys)
          EQ | p True True  -> x : loop xs ys
             | otherwise    ->     loop xs ys
          GT | p False True -> y : loop (x:xs) ys
             | otherwise    ->     loop (x:xs) ys
-}

-- |  Interleaves two sorted lists into a single sorted list,  keeping every
-- element of both arguments.
--
-- > merge [1,3,5] [2,4,6] == [1,2,3,4,5,6]
-- > merge [2,4,6,8] [3,6,9] == [2,3,4,6,6,8,9]
-- > merge [1,2,2,2] [1,1,1,2,2] == [1,1,1,1,2,2,2,2,2]
merge :: (Ord a) => [a] -> [a] -> [a]
merge xs ys = mergeBy compare xs ys

-- |  Like 'merge',  but uses the supplied comparison function.  When two
-- elements compare 'EQ' the one from the first list is emitted first.
mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
mergeBy cmp = go
  where
    go [] bs = bs
    go as [] = as
    go as@(a:as') bs@(b:bs')
      | cmp a b == GT = b : go as bs'
      | otherwise     = a : go as' bs

-- |  Returns 'True' if every element of the first sorted list can be
-- matched with a distinct equal element of the second,  so multiplicities
-- are respected.
subset :: (Ord a) => [a] -> [a] -> Bool
subset xs ys = subsetBy compare xs ys

-- |  Like 'subset',  but uses the supplied comparison function.
subsetBy :: (a -> a -> Ordering) -> [a] -> [a] -> Bool
subsetBy cmp = go
  where
    go small big = case (small, big) of
      ([], _)      -> True     -- nothing left to account for
      (_, [])      -> False    -- leftovers with no possible partner
      (s:ss, b:bs) -> case cmp s b of
        LT -> False            -- s would have had to appear before b
        EQ -> go ss bs
        GT -> go small bs

{-
sort :: Ord a => [a] -> [a]
sort = sortBy compare

sortBy :: (a -> a -> Ordering) -> [a] -> [a]
sortBy cmp = loop . map (\x -> [x])
  where
    loop []   = []
    loop [xs] = xs
    loop xss  = loop (merge_pairs xss)

    merge_pairs []          = []
    merge_pairs [xs]        = [xs]
    merge_pairs (xs:ys:xss) = mergeBy cmp xs ys : merge_pairs xss
-}

-- |  Sorts a list by a key computed from each element,  using
-- decorate-sort-undecorate (the \"Schwartzian transform\"):  every element
-- is paired with its key,  the pairs are sorted on the key,  and the keys
-- are then stripped off again.
sortOn :: Ord b => (a -> b) -> [a] -> [a]
sortOn f xs = map snd (sortBy byKey (map decorate xs))
  where
    decorate x = (f x, x)
    byKey p q  = compare (fst p) (fst q)

-- |  Sorts a list by a key computed from each element,  re-applying the key
-- function at every comparison instead of caching its results.  Better
-- suited than 'sortOn' to cheap key functions such as projections.
sortOn' :: Ord b => (a -> b) -> [a] -> [a]
sortOn' f = sortBy byProjection
  where
    byProjection a b = f a `compare` f b


-- |  Sorts a list and removes duplicate elements in the same pass.  Gives
-- the same result as @nub . sort@,  but somewhat more efficiently.
nubSort :: Ord a => [a] -> [a]
nubSort xs = nubSortBy compare xs

-- |  Like 'nubSort',  but uses the supplied comparison function.  Works as
-- a bottom-up merge sort in which runs are combined with 'unionBy' rather
-- than a plain merge,  so elements that compare 'EQ' collapse into one.
nubSortBy :: (a -> a -> Ordering) -> [a] -> [a]
nubSortBy cmp = combineAll . map (: [])
  where
    -- keep halving the number of runs until at most one is left
    combineAll []    = []
    combineAll [run] = run
    combineAll runs  = combineAll (pairUp runs)

    -- union neighbouring runs;  an odd run at the end is passed through
    pairUp (a:b:rest) = unionBy cmp a b : pairUp rest
    pairUp short      = short

-- |  Sorts a list by a key computed from each element and keeps only one
-- element per distinct key.  Each key is computed once and carried along
-- with its element (decorate-sort-undecorate).
nubSortOn :: Ord b => (a -> b) -> [a] -> [a]
nubSortOn f xs = map snd (nubSortOn' fst (map decorate xs))
  where
    decorate x = (f x, x)

-- |  Like 'nubSortOn',  but re-applies the key function at every comparison
-- instead of caching its results.
nubSortOn' :: Ord b => (a -> b) -> [a] -> [a]
nubSortOn' f = nubSortBy byProjection
  where
    byProjection a b = f a `compare` f b


-- |  Removes duplicates from a sorted list.  Equivalent to @Data.List.nub@
-- on ordered lists,  except faster.  On unordered lists it additionally
-- removes every element that is not greater than the most recently kept
-- element.
--
-- > nub [1,1,2,2,3] == [1,2,3]
-- > nub [2,0,1,3,3] == [2,3]
nub :: (Ord a) => [a] -> [a]
-- '>=' rather than '>':  an element equal to the last kept one is a
-- duplicate and must be dropped as well.
nub = nubBy (>=)

-- |  Generalised 'nub'.  The first element is always kept.  Each later
-- element @y@ is dropped when the predicate,  applied to the most recently
-- kept element and @y@,  returns 'True';  otherwise @y@ is kept and becomes
-- the element that subsequent ones are tested against.
nubBy :: (a -> a -> Bool) -> [a] -> [a]
nubBy _ []     = []
nubBy p (x:xs) = x : loop x xs
  where
    loop _ [] = []
    loop kept (y:ys)
       | p kept y  = loop kept ys
       | otherwise = y : loop y ys