{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-- |
-- Module      :  Data.IntSet.Translatable
-- Copyright   :  (c) Jannis Harder 2011
-- License     :  MIT
-- Maintainer  :  Jannis Harder <jannis@harderweb.de>
--
-- An implementation of integer sets with a constant time 'translate'
-- operation, where 'translate' is defined to be
-- @'translate' x s = 'map' (+x) s@.
--
-- Since many function names (but not the type name) clash with
-- "Prelude" names, this module is usually imported @qualified@, e.g.
--
-- >  import Data.IntSet.Translatable (IntSet)
-- >  import qualified Data.IntSet.Translatable as IntSet
--
-- This implementation is based on /Finger-Trees/ storing differences
-- of consecutive entries of the ordered sequence of set elements.
-- With this representation, a translation of all elements can be
-- realized by changing only the leftmost element of the Finger-Tree
-- which is a constant time operation. Together with caching of the
-- accumulated differences most set operations can be implemented
-- efficiently too.
module Data.IntSet.Translatable (
  -- * Set type
    IntSet

  -- * Operators
  , (\\)

  -- * Query
  , null
  , size
  , member
  , notMember

  -- * Construction
  , empty
  , singleton
  , insert
  , delete

  -- * Combine
  , union
  , unions
  , difference
  , intersection

  -- * Filter
  , filter
  , partition
  , split
  , splitMember

  -- * Min\/Max
  , findMin
  , findMax
  , deleteMin
  , deleteMax
  , deleteFindMin
  , deleteFindMax
  , maxView
  , minView

  -- * Map
  , map
  , translate

  -- * Fold
  , fold

  -- * Conversion
  -- ** List
  , elems
  , toList
  , fromList

  -- ** Ordered list
  , toAscList
  , fromAscList
  , fromDistinctAscList

  ) where

import Prelude hiding (null, filter, map)

#if __GLASGOW_HASKELL__
import Text.Read
#endif

import Data.Monoid (Monoid(..))
import qualified Data.List as List
import Data.List (group, sort, foldl')
import Data.Maybe (fromMaybe)

import Control.Arrow ((***))
import Control.Monad (join)

import qualified Data.FingerTree as FingerTree
import Data.FingerTree (FingerTree, Measured, measure, (<|), (|>), (><),
                        ViewL(..), ViewR(..), viewl, viewr)

-- | One entry of the underlying finger tree.  The leftmost entry holds the
-- smallest set element itself; every later entry holds the distance to the
-- preceding element.
newtype Diff = Diff
  { getDiff :: Int
  } deriving (Eq)

-- | Cached measure of a run of 'Diff' entries.
data DiffSum = DiffSum
  { getSum  :: !Int  -- ^ Sum of the differences in the run.
  , getSize :: !Int  -- ^ Number of entries in the run.
  }

-- Measures combine component-wise: the accumulated differences and the
-- element counts of two adjacent runs are simply added.
--
-- Since GHC 8.4 (base 4.11) Semigroup is a superclass of Monoid, so newer
-- compilers need the combining operation in a Semigroup instance.  Older
-- compilers (and non-GHC compilers) keep the original mappend definition.
#if __GLASGOW_HASKELL__ >= 804
instance Semigroup DiffSum where
  a <> b = DiffSum { getSum = getSum a + getSum b
                   , getSize = getSize a + getSize b }

instance Monoid DiffSum where
  mempty = DiffSum 0 0
#else
instance Monoid DiffSum where
  mempty = DiffSum 0 0
  mappend a b = DiffSum { getSum = getSum a + getSum b
                        , getSize = getSize a + getSize b }
#endif

-- | A single entry contributes its own difference and counts as one element.
instance Measured DiffSum Diff where
  measure (Diff d) = DiffSum d 1

-- | A set of 'Int's, stored as a finger tree over the differences of
-- consecutive elements (in ascending order).
newtype IntSet = IntSet (FingerTree DiffSum Diff)
  deriving (Eq)

-- | Sets are ordered like their ascending element lists.
instance Ord IntSet where
    -- Both lists are produced lazily, so the comparison only unfolds the
    -- sets up to the first differing element.
    compare lhs rhs = toAscList lhs `compare` toAscList rhs

-- | Shown as @fromList [..]@, parenthesised when used as an argument.
instance Show IntSet where
  showsPrec d set =
    showParen (d > 10) (showString "fromList " . shows (toList set))

-- | Parses the textual form produced by the 'Show' instance: the identifier
-- @fromList@ followed by a list of integers, optionally parenthesised.
-- The parsed list is handed to 'fromList', so it does not have to be sorted
-- or free of duplicates.
instance Read IntSet where
#ifdef __GLASGOW_HASKELL__
  -- GHC: ReadPrec based parser, run at function application precedence (10).
  readPrec = parens $ prec 10 $ do
    Ident "fromList" <- lexP
    xs <- readPrec
    return (fromList xs)

  readListPrec = readListPrecDefault
#else
  -- Other compilers: the equivalent ReadS based parser.
  readsPrec p = readParen (p > 10) $ \ r -> do
    ("fromList",s) <- lex r
    (xs,t) <- reads s
    return (fromList xs,t)
#endif

-- Sets form a monoid under union with the empty set as identity.
--
-- Since GHC 8.4 (base 4.11) Semigroup is a superclass of Monoid, so newer
-- compilers need the combining operation in a Semigroup instance.  Older
-- compilers (and non-GHC compilers) keep the original mappend definition.
#if __GLASGOW_HASKELL__ >= 804
instance Semigroup IntSet where
    (<>) = union

instance Monoid IntSet where
    mempty  = empty
    mconcat = unions
#else
instance Monoid IntSet where
    mempty  = empty
    mappend = union
    mconcat = unions
#endif

-- | Operator form of 'difference'; see there for details (the complexity
-- has not been analysed yet).
(\\) :: IntSet -> IntSet -> IntSet
(\\) = difference

-- | /O(1)/. 'True' exactly when the set has no elements.
null :: IntSet -> Bool
null set = case set of
  IntSet tree -> FingerTree.null tree

-- | /O(1)/. Number of elements in the set, read from the cached measure of
-- the whole tree.
size :: IntSet -> Int
size (IntSet tree) = getSize (measure tree)

-- | /O(log(n))/. Does the set contain the given value?
member :: Int -> IntSet -> Bool
member k (IntSet s) =
  not (FingerTree.null upTo) && getSum (measure upTo) == k
  where
    -- 'upTo' holds all elements not greater than @k@; its accumulated sum is
    -- the largest of them.
    (upTo, _) = FingerTree.split ((> k) . getSum) s

-- | /O(log(n))/. Is the value absent from the set?
notMember :: Int -> IntSet -> Bool
notMember k set = not (member k set)

-- | /O(1)/. The set without any elements.
empty :: IntSet
empty = IntSet (FingerTree.empty)

-- | /O(1)/. The set containing exactly the given value.
singleton :: Int -> IntSet
singleton x = IntSet (FingerTree.singleton (Diff x))

-- | /O(log(n))/. Insert a value.  The set is returned unchanged when the
-- value is already present.
insert :: Int -> IntSet -> IntSet
insert k (IntSet s)
    -- No element is <= k: k becomes the new leftmost (absolute) entry and the
    -- old leftmost entry is re-expressed relative to it.
  | FingerTree.null ls = IntSet (Diff k <| translate' (negate k) rs)
    -- The largest element <= k is k itself.
  | gap == 0           = IntSet s
    -- Otherwise place k right after its predecessor and shorten the
    -- difference stored at the front of the remaining part.
  | otherwise          = IntSet (ls >< (Diff gap <| translate' (negate gap) rs))
  where
    (ls, rs) = FingerTree.split ((> k) . getSum) s
    gap      = k - getSum (measure ls)

-- | /O(log(n))/. Remove a value.  When the value is not an element, the
-- original set is returned.
delete :: Int -> IntSet -> IntSet
delete k (IntSet s)
  | getSum (measure ls) /= k = IntSet s
  | otherwise                = IntSet $ case viewr ls of
      -- Only reachable for k == 0 with no element <= 0.
      EmptyR      -> s
      -- Drop the entry for k and fold its difference into the entry that
      -- followed it.
      before :> _ -> before >< translate' (k - getSum (measure before)) rs
  where
    (ls, rs) = FingerTree.split ((> k) . getSum) s

-- | /O(m log(n\/m))/ where /m<=n/. The union of two sets. /O(log m)/ when
-- every element of one set is larger than every element of the other.
union :: IntSet -> IntSet -> IntSet
union (IntSet xs) (IntSet ys) = IntSet (merge xs ys)
  where
    -- Take the smallest element b of the second tree, emit everything of the
    -- first tree that is <= b, emit b unless it was already there, and
    -- continue with the roles of the two remainders swapped.  Both
    -- remainders are expressed relative to b.
    merge as bs = case viewl bs of
      EmptyL        -> as
      Diff b :< bs' -> ls >< next
        where
          (ls, rs) = FingerTree.split ((> b) . getSum) as
          gap      = b - getSum (measure ls)
          rest     = merge bs' (translate' (negate gap) rs)
          next
            | gap == 0 && not (FingerTree.null ls) = rest
            | otherwise                            = Diff gap <| rest

-- | The union of all sets in a list, computed as a strict left fold of
-- 'union' starting from 'empty'.
unions :: [IntSet] -> IntSet
unions = foldl' union empty


-- | Difference between two sets: the elements of the first set that are not
-- contained in the second.  The complexity has not been analysed yet.

-- This should be O(m log(n / m)) but it might be even better.
difference :: IntSet -> IntSet -> IntSet
difference (IntSet xs) (IntSet ys) = IntSet $ diffF xs ys
  where
    -- @diffF as bs@ removes the elements of @bs@ from @as@.  It looks at the
    -- smallest element @b@ of @bs@ and splits @as@ into the part @ls@ that is
    -- @<= b@ and the part @rs@ that is @> b@.
    diffF as bs = case viewl bs of
      EmptyL -> as
      Diff b :< bs' -> case viewr ls of
          -- Nothing in @as@ is <= b: drop b and continue with swapped roles.
          EmptyR -> diffR (translate' b bs') rs
          ls' :> Diff m
            -- b is the last element of ls: remove it.  The remainder is
            -- relative to b and must be shifted by the removed entry m to
            -- become relative to the end of ls'.
            | d == 0    -> ls' >< translate' m (diffR bs' rs)
            -- b is not an element: keep ls and rebase bs' onto the end of ls.
            | otherwise -> ls >< diffR (translate' d bs') rs
        where (ls, rs) = FingerTree.split (\v -> getSum v > b) as
              d = b - getSum (measure ls)

    -- @diffR as bs@ is the mirrored variant: it removes the elements of @as@
    -- from @bs@, looking at the smallest element @b@ of @bs@.
    diffR as bs = case viewl bs of
      EmptyL -> bs
      Diff b :< bs'
        -- b itself is to be removed; the result is rebased from b onto the
        -- base of bs.
        | d == 0 && not (FingerTree.null ls) -> translate' b $ diffF bs' rs
        -- b survives.  When ls is empty d equals b, so this single case also
        -- covers the situation where nothing to remove is <= b.
        | otherwise -> Diff b <| diffF bs' (translate' (-d) rs)
        where (ls, rs) = FingerTree.split (\v -> getSum v > b) as
              d = b - getSum (measure ls)

-- | The intersection of two sets.  The complexity has not been analysed yet.

-- This should be O(m log(n / m)) but is likely even better.
intersection :: IntSet -> IntSet -> IntSet
intersection (IntSet xs) (IntSet ys) = IntSet (common xs ys)
  where
    -- Look at the smallest element b of the second tree and at the largest
    -- element of the first tree that is <= b.  On a match b is emitted and
    -- both remainders stay relative to b; otherwise b is discarded, both
    -- remainders are rebased and the roles are swapped.
    common as bs = case viewl bs of
      EmptyL -> FingerTree.empty
      Diff b :< bs'
        | FingerTree.null ls -> common (translate' b bs') rs
        | reached == b       -> Diff b <| common bs' rs
        | otherwise          -> common (translate' b bs') (translate' reached rs)
        where
          (ls, rs) = FingerTree.split ((> b) . getSum) as
          reached  = getSum (measure ls)

-- | /O(n)/. Keep exactly the elements that satisfy the predicate.
filter :: (Int -> Bool) -> IntSet -> IntSet
filter p set = fromDistinctAscList [x | x <- toList set, p x]

-- | /O(n)/. Split the set by a predicate: the first component holds the
-- elements satisfying it, the second component all others.
-- See also 'split'.
partition :: (Int -> Bool) -> IntSet -> (IntSet, IntSet)
partition p set = (fromDistinctAscList accepted, fromDistinctAscList rejected)
  where
    (accepted, rejected) = List.partition p (toList set)

-- | /O(log(min(i,n-i)))/. @'split' x set@ is the pair @(set1,set2)@ in which
-- @set1@ holds the elements of @set@ smaller than @x@ and @set2@ those
-- greater than @x@; @x@ itself is in neither.
--
-- > split 3 (fromList [1..5]) == (fromList [1,2], fromList [4,5])
split :: Int -> IntSet -> (IntSet, IntSet)
split k = (\(below, _, above) -> (below, above)) . splitMember k


-- | /O(log(min(i,n-i)))/. Performs a 'split' but also returns whether the pivot
-- element was found in the original set.
splitMember :: Int -> IntSet -> (IntSet, Bool, IntSet)
splitMember k (IntSet s) =
  -- ls holds the elements <= k, rs the elements > k.  Dispatching on the
  -- right view of ls keeps every case total.
  case viewr ls of
    -- Nothing is <= k.  rs is the whole tree and already absolute.
    EmptyR -> (IntSet ls, False, IntSet rs)
    ls' :> _
      -- The largest element <= k is k itself: drop it from the lower part.
      | top == k  -> (IntSet ls', True, IntSet rs')
      | otherwise -> (IntSet ls, False, IntSet rs')
  where
    (ls, rs) = FingerTree.split ((> k) . getSum) s
    top      = getSum (measure ls)
    -- rs is relative to the last element of ls; make it absolute again.
    rs'      = translate' top rs

-- | /O(1)/. The smallest element of the set.  Calls 'error' on the empty set.
findMin :: IntSet -> Int
findMin set = case minView set of
  Just (least, _) -> least
  Nothing         -> error "findMin: empty set has no minimal element"

-- | /O(1)/. The largest element of the set.  Calls 'error' on the empty set.
findMax :: IntSet -> Int
findMax set = case maxView set of
  Just (greatest, _) -> greatest
  Nothing            -> error "findMax: empty set has no maximal element"

-- | /O(1)/. The set without its smallest element.  Calls 'error' on the
-- empty set.
deleteMin :: IntSet -> IntSet
deleteMin set = case minView set of
  Just (_, rest) -> rest
  Nothing        -> error "deleteMin: empty set has no minimal element"

-- | /O(1)/. The set without its largest element.  Calls 'error' on the
-- empty set.
deleteMax :: IntSet -> IntSet
deleteMax set = case maxView set of
  Just (_, rest) -> rest
  Nothing        -> error "deleteMax: empty set has no maximal element"

-- | /O(1)/. The smallest element together with the set stripped of it.
-- Calls 'error' on the empty set.
--
-- > deleteFindMin set = (findMin set, deleteMin set)
deleteFindMin :: IntSet -> (Int, IntSet)
deleteFindMin set = case minView set of
  Just result -> result
  Nothing     -> error "deleteFindMin: empty set has no minimal element"

-- | /O(1)/. The largest element together with the set stripped of it.
-- Calls 'error' on the empty set.
--
-- > deleteFindMax set = (findMax set, deleteMax set)
deleteFindMax :: IntSet -> (Int, IntSet)
deleteFindMax set = case maxView set of
  Just result -> result
  Nothing     -> error "deleteFindMax: empty set has no maximal element"

-- | /O(1)/. The largest element paired with the set lacking that element,
-- or 'Nothing' for the empty set.
maxView :: IntSet -> Maybe (Int, IntSet)
maxView (IntSet tree) = case viewr tree of
  EmptyR    -> Nothing
  -- The sum of all differences is the value of the last element.
  rest :> _ -> Just (greatest, IntSet rest)
  where
    greatest = getSum (measure tree)

-- | /O(1)/. The smallest element paired with the set lacking that element,
-- or 'Nothing' for the empty set.
minView :: IntSet -> Maybe (Int, IntSet)
minView (IntSet tree) = case viewl tree of
  EmptyL             -> Nothing
  -- The leftmost entry is the smallest element; shifting the remainder by
  -- it turns the new leftmost entry into an absolute value again.
  Diff least :< rest -> Just (least, IntSet (translate' least rest))

-- | /O(n*log(n))/.
-- @'map' f s@ is the set of all @f x@ for @x@ in @s@.
--
-- The result can have fewer elements than @s@ when @f@ sends two distinct
-- elements to the same value, i.e. @x \/= y && f x == f y@.
map :: (Int -> Int) -> IntSet -> IntSet
map f set = fromList [f x | x <- toList set]

-- | /O(1)/. Shift every element of the set by the same amount.
--
-- > translate x s == map (+x) s
translate :: Int -> IntSet -> IntSet
translate offset (IntSet tree) = IntSet (translate' offset tree)

-- | /O(n)/. Right fold over the elements of a set.  The order is documented
-- as unspecified; this implementation folds over 'toList'.
--
-- > sum set   == fold (+) 0 set
-- > elems set == fold (:) [] set
fold :: (Int -> b -> b) -> b -> IntSet -> b
fold f z set = foldr f z (toList set)


-- | /O(n)/. All elements of the set as a list (a synonym for 'toList').
elems :: IntSet -> [Int]
elems set = toList set

-- | /O(n)/. The elements of the set as a list; this is 'toAscList'.
toList :: IntSet -> [Int]
toList set = toAscList set

-- | /O(n*log(n))/. Build a set from an arbitrary list of integers; the list
-- is sorted first and duplicates are dropped by 'fromAscList'.
fromList :: [Int] -> IntSet
fromList xs = fromAscList (sort xs)

-- | /O(n)/. The elements of the set in ascending order.
toAscList :: IntSet -> [Int]
toAscList (IntSet tree) = go 0 tree
  where
    -- @acc@ is the value of the previously emitted element (0 initially);
    -- adding the next stored difference yields the next element.
    go acc t = case viewl t of
      EmptyL         -> []
      Diff d :< rest -> let x = d + acc in x : go x rest

-- | /O(n)/. Build a set from an ascending list that may contain repeated
-- elements.
-- /The precondition (input list is ascending) is not checked./
fromAscList :: [Int] -> IntSet
fromAscList = fromDistinctAscList . dedup
  where
    -- Collapse every run of equal neighbours into a single element.
    dedup (x : rest@(y : _))
      | x == y    = dedup rest
      | otherwise = x : dedup rest
    dedup short = short

-- | /O(n)/. Build a set from an ascending list of distinct elements.
-- /The precondition (input list is strictly ascending) is not checked./
fromDistinctAscList :: [Int] -> IntSet
fromDistinctAscList xs = IntSet $ foldl' step FingerTree.empty xs
  -- A strict left fold avoids building a chain of suspended appends for
  -- long input lists.
  where
    -- The accumulated sum of the tree built so far is the value of the
    -- previous element (0 for the empty tree), so each new entry stores the
    -- distance to its predecessor.
    step as x = as |> Diff (x - getSum (measure as))

-- Internal

-- | /O(1)/. Add an offset to the leftmost entry of a difference tree.  As all
-- later entries are stored relative to their predecessor, this shifts every
-- element represented by the tree.  The same operation is used to rebase a
-- tree fragment whose first entry is relative to some other value.
translate' :: Int -> FingerTree DiffSum Diff -> FingerTree DiffSum Diff
translate' 0 xs = xs
translate' d xs = case viewl xs of
  EmptyL         -> FingerTree.empty
  Diff x :< rest -> Diff (x + d) <| rest