-- | An 'OSet' behaves much like a 'Set', with all the same asymptotics, but
-- also remembers the order that values were inserted.
module Data.Set.Ordered
	( OSet
	-- * Trivial sets
	, empty, singleton
	-- * Insertion
	-- | Conventionts:
	--
	-- * The open side of an angle bracket points to an 'OSet'
	--
	-- * The pipe appears on the side whose indices take precedence for keys that appear on both sides
	--
	-- * The left argument's indices are lower than the right argument's indices
	, (<|), (|<), (>|), (|>)
	, (<>|), (|<>)
	-- * Query
	, null, size, member, notMember
	-- * Deletion
	, delete, filter, (\\)
	-- * Indexing
	, Index, findIndex, elemAt
	-- * List conversions
	, fromList, toAscList
	) where

import Control.Monad (guard)
import Data.Foldable (Foldable(foldl', foldMap, toList))
import Data.Function (on)
import Data.Map (Map)
import Data.Map.Util (Index, Tag, maxTag, minTag, nextHigherTag, nextLowerTag, readsPrecList, showsPrecList)
import Prelude hiding (filter, lookup, null)
import qualified Data.Map as M

data OSet a = OSet !(Map a Tag) !(Map Tag a)

-- | Values appear in insertion order, not ascending order.
instance Foldable OSet where foldMap f (OSet _ vs) = foldMap f vs
instance         Eq   a  => Eq   (OSet a) where (==)    = (==)    `on` toList
instance         Ord  a  => Ord  (OSet a) where compare = compare `on` toList
instance         Show a  => Show (OSet a) where showsPrec = showsPrecList toList
instance (Ord a, Read a) => Read (OSet a) where readsPrec = readsPrecList fromList

infixr 5 <|, |<   -- copy :
infixl 5 >|, |>
infixr 6 <>|, |<> -- copy <>

(<|) , (|<)  :: Ord a =>      a -> OSet a -> OSet a
(>|) , (|>)  :: Ord a => OSet a ->      a -> OSet a
(<>|), (|<>) :: Ord a => OSet a -> OSet a -> OSet a

v <| o@(OSet ts vs)
	| v `member` o = o
	| otherwise    = OSet (M.insert v t ts) (M.insert t v vs) where
		t = nextLowerTag vs

v |< o = OSet (M.insert v t ts) (M.insert t v vs) where
	t = nextLowerTag vs
	OSet ts vs = delete v o

o@(OSet ts vs) |> v
	| v `member` o = o
	| otherwise    = OSet (M.insert v t ts) (M.insert t v vs) where
		t = nextHigherTag vs

o >| v = OSet (M.insert v t ts) (M.insert t v vs) where
	t = nextHigherTag vs
	OSet ts vs = delete v o

o <>| o' = unsafeMappend (o \\ o') o'
o |<> o' = unsafeMappend o (o' \\ o)

-- assumes that ts and ts' have disjoint keys
unsafeMappend (OSet ts vs) (OSet ts' vs')
	= OSet (M.union tsBumped tsBumped')
	       (M.union vsBumped vsBumped')
	where
	bump  = case maxTag vs  of
		Nothing -> 0
		Just k  -> -k-1
	bump' = case minTag vs' of
		Nothing -> 0
		Just k  -> -k
	tsBumped  = (bump +) <$> ts
	tsBumped' = (bump'+) <$> ts'
	vsBumped  = (bump +) `M.mapKeysMonotonic` vs
	vsBumped' = (bump'+) `M.mapKeysMonotonic` vs'

-- | Set difference: @r \\\\ s@ deletes all the values in @s@ from @r@. The
-- order of @r@ is unchanged.
(\\) :: Ord a => OSet a -> OSet a -> OSet a
o@(OSet ts vs) \\ o'@(OSet ts' vs') = if size o < size o'
	then filter (`notMember` o') o
	else foldr delete o vs'

empty :: OSet a
empty = OSet M.empty M.empty

member, notMember :: Ord a => a -> OSet a -> Bool
member    v (OSet ts _) = M.member    v ts
notMember v (OSet ts _) = M.notMember v ts

size :: OSet a -> Int
size (OSet ts _) = M.size ts

filter :: (a -> Bool) -> OSet a -> OSet a
filter f (OSet ts vs) = OSet (M.filterWithKey (\v t -> f v) ts)
                             (M.filterWithKey (\t v -> f v) vs)

delete :: Ord a => a -> OSet a -> OSet a
delete v o@(OSet ts vs) = case M.lookup v ts of
	Nothing -> o
	Just t  -> OSet (M.delete v ts) (M.delete t vs)

singleton :: a -> OSet a
singleton v = OSet (M.singleton v 0) (M.singleton 0 v)

-- | If a value occurs multiple times, only the first occurrence is used.
fromList :: Ord a => [a] -> OSet a
fromList = foldl' (|>) empty

null :: OSet a -> Bool
null (OSet ts _) = M.null ts

findIndex :: Ord a => a -> OSet a -> Maybe Index
findIndex v o@(OSet ts vs) = do
	t <- M.lookup v ts
	M.lookupIndex t vs

elemAt :: OSet a -> Index -> Maybe a
elemAt o@(OSet ts vs) i = do
	guard (0 <= i && i < M.size vs)
	return . snd $ M.elemAt i vs

-- | Returns values in ascending order. (Use 'toList' to return them in
-- insertion order.)
toAscList :: OSet a -> [a]
toAscList o@(OSet ts _) = fst <$> M.toAscList ts