{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase         #-}
{-# LANGUAGE ViewPatterns       #-}
{-# OPTIONS_HADDOCK not-home    #-}
module Data.Set.NonEmpty.Internal (
    NESet(..)
  , nonEmptySet
  , withNonEmpty
  , toSet
  , singleton
  , fromList
  , toList
  , size
  , union
  , unions
  , foldr
  , foldl
  , foldr'
  , foldl'
  , MergeNESet(..)
  , merge
  , valid
  , insertMinSet
  , insertMaxSet
  , disjointSet
  , powerSetSet
  , disjointUnionSet
  , cartesianProductSet
  ) where
import           Control.DeepSeq
import           Data.Data
import           Data.Function
import           Data.Functor.Classes
import           Data.List.NonEmpty                   (NonEmpty(..))
import           Data.Semigroup
import           Data.Semigroup.Foldable              (Foldable1)
import           Data.Set.Internal                    (Set(..))
import           Data.Typeable                        (Typeable)
import           Prelude hiding                       (foldr, foldr1, foldl, foldl1)
import           Text.Read
import qualified Data.Foldable                        as F
import qualified Data.Semigroup.Foldable              as F1
import qualified Data.Set                             as S
import qualified Data.Set.Internal                    as S
#if !MIN_VERSION_containers(0,5,11)
import           Utils.Containers.Internal.StrictPair
#endif
-- | A set that always contains at least one element.
--
-- One element is stored outside the underlying tree in 'nesV0'; 'valid'
-- checks that it is strictly smaller than every element of 'nesSet'.
data NESet a = NESet
    { nesV0  :: !a        -- ^ the element kept outside the tree (the minimum)
    , nesSet :: !(Set a)  -- ^ all remaining elements
    }
  deriving Typeable
-- | Two sets are equal when they hold the same elements. The sizes of the
-- underlying trees are compared first as a cheap early exit.
instance Eq a => Eq (NESet a) where
    n1 == n2 = sameSize && toList n1 == toList n2
      where
        sameSize = S.size (nesSet n1) == S.size (nesSet n2)
-- | Lexicographic ordering on the element lists produced by 'toList'.
instance Ord a => Ord (NESet a) where
    compare m n = compare (toList m) (toList n)
    m <  n      = toList m <  toList n
    m >  n      = toList m >  toList n
    m <= n      = toList m <= toList n
    m >= n      = toList m >= toList n
-- | Renders as @fromList (1 :| [2,3])@, wrapped in parentheses above
-- application precedence. The format is accepted by the 'Read' instance.
instance Show a => Show (NESet a) where
    showsPrec d n = showParen (d > 10) rendered
      where
        rendered = showString "fromList ("
                 . shows (toList n)
                 . showString ")"
-- | Parses @fromList <non-empty list>@, the format emitted by 'Show', and
-- rebuilds the set with 'fromList'.
instance (Read a, Ord a) => Read (NESet a) where
    readPrec = parens . prec 10 $ do
        Ident "fromList" <- lexP
        fmap fromList (parens (prec 10 readPrec))
    readListPrec = readListPrecDefault
-- | Lifted equality: sets of different size are unequal; otherwise the
-- element lists are compared pointwise with the supplied predicate.
instance Eq1 NESet where
    liftEq eq m n
        | size m /= size n = False
        | otherwise        = liftEq eq (toList m) (toList n)
-- | Lifted comparison of the element lists produced by 'toList'.
instance Ord1 NESet where
    liftCompare cmp = liftCompare cmp `on` toList
-- | Lifted rendering as the unary application @fromList <list>@.
instance Show1 NESet where
    liftShowsPrec sp sl d =
        showsUnaryWith (liftShowsPrec sp sl) "fromList" d . toList
-- | Fully evaluates the stored element and the underlying tree.
instance NFData a => NFData (NESet a) where
    rnf n = rnf (nesV0 n) `seq` rnf (nesSet n)
-- | Abstract 'Data' instance: an 'NESet' is presented as the application of
-- a virtual @fromList@ constructor to its element list, so generic
-- traversals cannot break the internal invariant.
instance (Data a, Ord a) => Data (NESet a) where
  gfoldl f z set = z fromList `f` toList set
  toConstr _     = fromListConstr
  gunfold k z c  = case constrIndex c of
    1 -> k (z fromList)
    _ -> error "gunfold"
  dataTypeOf _   = setDataType
  -- Eta-expanded on purpose: the method argument has the rank-2 type
  -- (forall d. Data d => c (t d)), and with simplified subsumption
  -- (GHC >= 9.0) the point-free form dataCast1 = gcast1 no longer
  -- typechecks. Passing the argument explicitly instantiates it at the
  -- element type, which is fine on every GHC version.
  dataCast1 f    = gcast1 f
-- | The single virtual constructor reported by the 'Data' instance.
fromListConstr :: Constr
fromListConstr = mkConstr setDataType conName fieldNames Prefix
  where
    conName    = "fromList"
    fieldNames = []

-- | The 'DataType' reported by the 'Data' instance.
setDataType :: DataType
setDataType = mkDataType typeName constrs
  where
    typeName = "Data.Set.NonEmpty.Internal.NESet"
    constrs  = [fromListConstr]
-- | Converts a 'Set' to an 'NESet', or 'Nothing' if the set is empty.
-- The smallest element becomes 'nesV0'.
nonEmptySet :: Set a -> Maybe (NESet a)
nonEmptySet s = case S.minView s of
    Nothing        -> Nothing
    Just (x, rest) -> Just (NESet x rest)
{-# INLINE nonEmptySet #-}
-- | Eliminator for a possibly-empty 'Set': applies the continuation to the
-- 'NESet' view of a non-empty set, otherwise returns the default.
withNonEmpty
    :: r                  -- ^ result when the 'Set' is empty
    -> (NESet a -> r)     -- ^ continuation for a non-empty 'Set'
    -> Set a
    -> r
withNonEmpty def f s = case nonEmptySet s of
    Nothing -> def
    Just n  -> f n
{-# INLINE withNonEmpty #-}
-- | Forgets the non-emptiness guarantee. 'nesV0' is put back into the tree
-- with 'insertMinSet', relying on it being smaller than every element of
-- 'nesSet'.
toSet :: NESet a -> Set a
toSet n = insertMinSet (nesV0 n) (nesSet n)
{-# INLINE toSet #-}
-- | A set containing exactly one element.
singleton :: a -> NESet a
singleton x = NESet { nesV0 = x, nesSet = S.empty }
{-# INLINE singleton #-}
-- | Builds a set from a non-empty list, collapsing duplicates. The tail is
-- turned into a 'Set' first and then unioned with the head.
fromList :: Ord a => NonEmpty a -> NESet a
fromList (x :| xs) = withNonEmpty single (<> single) (S.fromList xs)
  where
    single = singleton x
{-# INLINE fromList #-}
-- | The elements as a 'NonEmpty' list: 'nesV0' first, then the contents of
-- 'nesSet' in 'S.toList' order.
toList :: NESet a -> NonEmpty a
toList n = case n of
    NESet v0 rest -> v0 :| S.toList rest
{-# INLINE toList #-}
-- | Number of elements: the tree size plus one for 'nesV0'.
size :: NESet a -> Int
size n = S.size (nesSet n) + 1
{-# INLINE size #-}
-- | Lazy right fold: 'nesV0' is combined last, outside the fold of the tree.
foldr :: (a -> b -> b) -> b -> NESet a -> b
foldr f z = \case
    NESet v rest -> f v (S.foldr f z rest)
{-# INLINE foldr #-}
-- | Right fold whose fold over the tree is forced before 'nesV0' is
-- combined with it.
foldr' :: (a -> b -> b) -> b -> NESet a -> b
foldr' f z (NESet v rest) = f v $! S.foldr' f z rest
{-# INLINE foldr' #-}
-- | Right fold without a starting value. The largest element seeds the fold
-- of the rest of the tree, and 'nesV0' is combined last.
foldr1 :: (a -> a -> a) -> NESet a -> a
foldr1 f (NESet v rest) = case S.maxView rest of
    Nothing          -> v
    Just (top, init') -> f v (S.foldr f top init')
{-# INLINE foldr1 #-}
-- | Lazy left fold: 'nesV0' is combined with the seed first.
foldl :: (a -> b -> a) -> a -> NESet b -> a
foldl f z = \case
    NESet v rest -> S.foldl f (f z v) rest
{-# INLINE foldl #-}
-- | Strict left fold; the accumulator after 'nesV0' is forced before
-- folding the tree.
foldl' :: (a -> b -> a) -> a -> NESet b -> a
foldl' f z (NESet v rest) =
    let !acc = f z v
    in  S.foldl' f acc rest
{-# INLINE foldl' #-}
-- | Left fold without a starting value, seeded with 'nesV0'.
foldl1 :: (a -> a -> a) -> NESet a -> a
foldl1 f = \case
    NESet v rest -> S.foldl f v rest
{-# INLINE foldl1 #-}
-- | Left-biased union. The smaller of the two stored elements becomes the
-- new 'nesV0'; on a tie the left operand's element is kept.
union
    :: Ord a
    => NESet a
    -> NESet a
    -> NESet a
union n1@(NESet x1 s1) n2@(NESet x2 s2) =
    case compare x1 x2 of
      -- Same minimum: keep the left copy and merge the trees.
      EQ -> NESet x1 (S.union s1 s2)
      -- x1 is below x2, and hence below everything in n2.
      LT -> NESet x1 (S.union s1 (toSet n2))
      -- x2 is below x1, and hence below everything in n1.
      GT -> NESet x2 (S.union (toSet n1) s2)
{-# INLINE union #-}
-- | Union of a non-empty collection of sets, folded strictly from the left
-- with 'union'.
unions
    :: (Foldable1 f, Ord a)
    => f (NESet a)
    -> NESet a
unions fs = case F1.toNonEmpty fs of
    s0 :| rest -> F.foldl' union s0 rest
{-# INLINE unions #-}
-- | Sets form a semigroup under 'union'.
instance Ord a => Semigroup (NESet a) where
    n1 <> n2 = union n1 n2
    {-# INLINE (<>) #-}
    sconcat ns = unions ns
    {-# INLINE sconcat #-}
-- | Folds visit 'nesV0' first and then the elements of 'nesSet'.
--
-- The bindings of the form @method = method@ below are not recursive: the
-- right-hand side resolves to the top-level function defined in this
-- module (Prelude foldr, foldr1, foldl and foldl1 are hidden on import).
--
-- NOTE(review): newer base versions (4.20+) also export foldl' from
-- Prelude, which would make the unqualified foldl' ambiguous; consider
-- adding it to the Prelude hiding list.
instance Foldable NESet where
#if MIN_VERSION_base(4,11,0)
    fold      (NESet x s) = x <> F.fold s
    {-# INLINE fold #-}
    foldMap f (NESet x s) = f x <> foldMap f s
    {-# INLINE foldMap #-}
#else
    -- Older base: Monoid does not have Semigroup as a superclass, so use mappend.
    fold      (NESet x s) = x `mappend` F.fold s
    {-# INLINE fold #-}
    foldMap f (NESet x s) = f x `mappend` foldMap f s
    {-# INLINE foldMap #-}
#endif
    foldr   = foldr
    {-# INLINE foldr #-}
    foldr'  = foldr'
    {-# INLINE foldr' #-}
    foldr1  = foldr1
    {-# INLINE foldr1 #-}
    foldl   = foldl
    {-# INLINE foldl #-}
    foldl'  = foldl'
    {-# INLINE foldl' #-}
    foldl1  = foldl1
    {-# INLINE foldl1 #-}
    -- An NESet always holds at least one element.
    null _  = False
    {-# INLINE null #-}
    length  = size
    {-# INLINE length #-}
    -- Only Eq is available, so the search in the tree cannot use the
    -- ordering and may visit every element.
    elem x (NESet x0 s) = F.elem x s
                       || x == x0
    {-# INLINE elem #-}
    -- nesV0 is the smallest element.
    minimum (NESet x _) = x
    {-# INLINE minimum #-}
    -- The largest element is the maximum of the tree, or nesV0 if the tree is empty.
    maximum (NESet x s) = maybe x fst . S.maxView $ s
    {-# INLINE maximum #-}
    -- Goes through the module-level toList (a NonEmpty), then to a plain list.
    toList  = F.toList . toList
    {-# INLINE toList #-}
-- | Non-empty folds, starting from 'nesV0' and continuing through the
-- elements of 'nesSet'.
--
-- The tail is folded into a 'Maybe' accumulator with 'S.foldr' instead of
-- going through @Data.Semigroup.Option@, which was removed from base 4.16.
-- 'Maybe' is not used as a Monoid here, so this also works on old base
-- where @Monoid (Maybe a)@ required @Monoid a@. Because '<>' is
-- associative, the result is the same as any other bracketing of
-- @x <> y1 <> ... <> yn@.
instance Foldable1 NESet where
    fold1 (NESet x s) = maybe x (x <>)
                      . S.foldr (\y -> Just . maybe y (y <>)) Nothing
                      $ s
    {-# INLINE fold1 #-}

    foldMap1 f (NESet x s) = maybe (f x) (f x <>)
                           . S.foldr (\y -> Just . maybe (f y) (f y <>)) Nothing
                           $ s
    {-# INLINE foldMap1 #-}
    toNonEmpty = toList
    {-# INLINE toNonEmpty #-}
-- | Wrapper whose 'Semigroup' combines with 'merge' instead of 'union'.
-- No comparisons are made; the result is only 'valid' when every element
-- of the left operand is smaller than every element of the right.
newtype MergeNESet a = MergeNESet { getMergeNESet :: NESet a }

instance Semigroup (MergeNESet a) where
    l <> r = MergeNESet (getMergeNESet l `merge` getMergeNESet r)
    {-# INLINE (<>) #-}
-- | Concatenates two sets without comparing elements. The first set's
-- 'nesV0' stays the stored element, and its tree is glued to the whole of
-- the second set with 'S.merge'.
--
-- Assumes, without checking, that every element of the first set is
-- strictly smaller than every element of the second.
merge :: NESet a -> NESet a -> NESet a
merge n1 n2 = case n1 of
    NESet v rest -> NESet v (S.merge rest (toSet n2))
-- | Internal invariant check: the tree is a valid 'Set', and 'nesV0' is
-- strictly smaller than its smallest element (if any).
valid :: Ord a => NESet a -> Bool
valid (NESet x s) = S.valid s && headBelowTree
  where
    headBelowTree = case S.minView s of
      Nothing     -> True
      Just (m, _) -> x < m
-- | Inserts an element that is assumed (not checked) to be smaller than
-- everything in the set. Walks the left spine and rebalances each node on
-- the way back up with 'balanceL'.
insertMinSet :: a -> Set a -> Set a
insertMinSet x = go
  where
    go Tip           = S.singleton x
    go (Bin _ y l r) = balanceL y (go l) r
{-# INLINABLE insertMinSet #-}
-- | Inserts an element that is assumed (not checked) to be larger than
-- everything in the set. Walks the right spine and rebalances each node on
-- the way back up with 'balanceR'.
insertMaxSet :: a -> Set a -> Set a
insertMaxSet x = go
  where
    go Tip           = S.singleton x
    go (Bin _ y l r) = balanceR y l (go r)
{-# INLINABLE insertMaxSet #-}
-- | True when the two sets have no element in common. Uses 'S.disjoint'
-- where containers provides it, and an emptiness test on the intersection
-- otherwise.
disjointSet :: Ord a => Set a -> Set a -> Bool
#if MIN_VERSION_containers(0,5,11)
disjointSet = S.disjoint
#else
disjointSet xs ys = S.null (S.intersection xs ys)
#endif
{-# INLINE disjointSet #-}
-- | The set of all subsets of a 'Set', including the empty set.
--
-- Uses 'S.powerSet' on containers 0.5.11 and later. For older containers a
-- local fallback is compiled, together with the helpers minViewSure,
-- maxViewSure and glue.
powerSetSet :: Set a -> Set (Set a)
#if MIN_VERSION_containers(0,5,11)
powerSetSet = S.powerSet
{-# INLINE powerSetSet #-}
#else
-- The strict right fold handles elements from largest to smallest, so pxs
-- holds the non-empty subsets of the elements larger than x. Every subset
-- built in step (the singleton of x, and x added to each member of pxs) has
-- x as its minimum and therefore sorts before all of pxs. That is why the
-- two halves can be joined with glue, and why insertMinSet / mapMonotonic
-- keep the tree ordered. The empty set is added last as the overall minimum.
powerSetSet xs0 = insertMinSet S.empty (S.foldr' step' Tip xs0) where
  step' x pxs = insertMinSet (S.singleton x) (insertMinSet x `S.mapMonotonic` pxs) `glue` pxs
{-# INLINABLE powerSetSet #-}
-- | Removes the leftmost element from the tree with root x and subtrees l
-- and r. Returns that element together with the rebalanced remainder.
minViewSure :: a -> Set a -> Set a -> StrictPair a (Set a)
minViewSure = go
  where
    go x Tip r = x :*: r
    go x (Bin _ xl ll lr) r =
      case go xl ll lr of
        xm :*: l' -> xm :*: balanceR x l' r
-- | Removes the rightmost element from the tree with root x and subtrees l
-- and r. Returns that element together with the rebalanced remainder.
maxViewSure :: a -> Set a -> Set a -> StrictPair a (Set a)
maxViewSure = go
  where
    go x l Tip = x :*: l
    go x l (Bin _ xr rl rr) =
      case go xr rl rr of
        xm :*: r' -> xm :*: balanceL x l r'
-- | Joins two trees, assuming every element of the first is smaller than
-- every element of the second. The new root is taken from the larger tree:
-- its maximum if the left tree is bigger, otherwise the right tree minimum.
glue :: Set a -> Set a -> Set a
glue Tip r = r
glue l Tip = l
glue l@(Bin sl xl ll lr) r@(Bin sr xr rl rr)
  | sl > sr = let !(m :*: l') = maxViewSure xl ll lr in balanceR m l' r
  | otherwise = let !(m :*: r') = minViewSure xr rl rr in balanceL m l r'
#endif
-- | Tags the elements of the first set with 'Left' and those of the second
-- with 'Right'. The fallback concatenates the two mapped trees with
-- 'S.merge'; this is safe because every 'Left' value sorts before every
-- 'Right' value.
disjointUnionSet :: Set a -> Set b -> Set (Either a b)
#if MIN_VERSION_containers(0,5,11)
disjointUnionSet = S.disjointUnion
#else
disjointUnionSet xs ys = S.mapMonotonic Left xs `S.merge` S.mapMonotonic Right ys
#endif
{-# INLINE disjointUnionSet #-}
-- | All pairs (a, b) with a drawn from the first set and b from the second.
--
-- Uses 'S.cartesianProduct' on containers 0.5.11 and later. The fallback
-- pairs each a with every element of bs, then concatenates the resulting
-- blocks with 'S.merge' through the MergeSet monoid. foldMap over a 'Set'
-- visits elements in ascending order, so the blocks for smaller a come
-- first and the concatenation stays ordered.
cartesianProductSet :: Set a -> Set b -> Set (a, b)
#if MIN_VERSION_containers(0,5,11)
cartesianProductSet = S.cartesianProduct
#else
cartesianProductSet as bs =
  getMergeSet $ foldMap (\a -> MergeSet $ S.mapMonotonic ((,) a) bs) as
-- | Monoid wrapper concatenating trees with 'S.merge'. This is only lawful
-- for operands whose elements are all ordered left-before-right.
newtype MergeSet a = MergeSet { getMergeSet :: Set a }
instance Semigroup (MergeSet a) where
    MergeSet xs <> MergeSet ys = MergeSet (S.merge xs ys)
instance Monoid (MergeSet a) where
    mempty = MergeSet S.empty
    mappend = (<>)
#endif
{-# INLINE cartesianProductSet #-}
-- | Rebuilds a node from root @x@ and subtrees @l@ and @r@, applying at
-- most one single or double rotation when @r@ is too heavy relative to
-- @l@. "Too heavy" is decided by 'delta'; 'ratio' picks between a single
-- and a double rotation.
--
-- Presumably expects the node to have been balanced before a single
-- insertion into @r@ or deletion from @l@ (as in containers' own helper).
-- The 'error' branch is only reached if that expectation is violated.
balanceR :: a -> Set a -> Set a -> Set a
balanceR x l r = case l of
    Tip -> case r of
      Tip -> Bin 1 x Tip Tip
      Bin _ _ Tip Tip -> Bin 2 x Tip r
      Bin _ rx Tip rr@Bin{} -> Bin 3 rx (Bin 1 x Tip Tip) rr
      Bin _ rx (Bin _ rlx _ _) Tip -> Bin 3 rlx (Bin 1 x Tip Tip) (Bin 1 rx Tip Tip)
      Bin rs rx rl@(Bin rls rlx rll rlr) rr@(Bin rrs _ _ _)
        | rls < ratio*rrs -> Bin (1+rs) rx (Bin (1+rls) x Tip rl) rr
        | otherwise -> Bin (1+rs) rlx (Bin (1+S.size rll) x Tip rll) (Bin (1+rrs+S.size rlr) rx rlr rr)
    Bin ls _ _ _ -> case r of
      Tip -> Bin (1+ls) x l Tip
      Bin rs rx rl rr
        | rs > delta*ls -> case (rl, rr) of
            (Bin rls rlx rll rlr, Bin rrs _ _ _)
              -- single rotation to the left
              | rls < ratio*rrs -> Bin (1+ls+rs) rx (Bin (1+ls+rls) x l rl) rr
              -- double rotation: rl's root becomes the new root
              | otherwise -> Bin (1+ls+rs) rlx (Bin (1+ls+S.size rll) x l rll) (Bin (1+rrs+S.size rlr) rx rlr rr)
            (_, _) -> error "Failure in Data.Set.NonEmpty.Internal.balanceR"
        -- This guard belongs to the Bin rs rx rl rr alternative (sibling of
        -- the rs > delta*ls guard), not to the inner case: already balanced.
        | otherwise -> Bin (1+ls+rs) x l r
{-# NOINLINE balanceR #-}
-- | Mirror image of 'balanceR'. Rebuilds a node from root @x@ and subtrees
-- @l@ and @r@, applying at most one single or double rotation when @l@ is
-- too heavy relative to @r@. "Too heavy" is decided by 'delta'; 'ratio'
-- picks between a single and a double rotation.
--
-- Presumably expects the node to have been balanced before a single
-- insertion into @l@ or deletion from @r@ (the use in 'insertMinSet' is an
-- insertion into @l@). The 'error' branch is only reached if that
-- expectation is violated.
balanceL :: a -> Set a -> Set a -> Set a
balanceL x l r = case r of
    -- r is empty: the small shapes of l are rebuilt explicitly.
    Tip -> case l of
      Tip -> Bin 1 x Tip Tip
      Bin _ _ Tip Tip -> Bin 2 x l Tip
      Bin _ lx Tip (Bin _ lrx _ _) -> Bin 3 lrx (Bin 1 lx Tip Tip) (Bin 1 x Tip Tip)
      Bin _ lx ll@Bin{} Tip -> Bin 3 lx ll (Bin 1 x Tip Tip)
      Bin ls lx ll@(Bin lls _ _ _) lr@(Bin lrs lrx lrl lrr)
        | lrs < ratio*lls -> Bin (1+ls) lx ll (Bin (1+lrs) x lr Tip)
        | otherwise -> Bin (1+ls) lrx (Bin (1+lls+S.size lrl) lx ll lrl) (Bin (1+S.size lrr) x lrr Tip)
    -- r is non-empty: rotate only when l outweighs r by more than delta.
    Bin rs _ _ _ -> case l of
             Tip -> Bin (1+rs) x Tip r
             Bin ls lx ll lr
                | ls > delta*rs  -> case (ll, lr) of
                     (Bin lls _ _ _, Bin lrs lrx lrl lrr)
                       -- single rotation to the right
                       | lrs < ratio*lls -> Bin (1+ls+rs) lx ll (Bin (1+rs+lrs) x lr r)
                       -- double rotation: lr's root becomes the new root
                       | otherwise -> Bin (1+ls+rs) lrx (Bin (1+lls+S.size lrl) lx ll lrl) (Bin (1+rs+S.size lrr) x lrr r)
                     (_, _) -> error "Failure in Data.Set.NonEmpty.Internal.balanceL"
                | otherwise -> Bin (1+ls+rs) x l r
{-# NOINLINE balanceL #-}
-- | Balance parameters used by 'balanceL' and 'balanceR'. A node is
-- rebalanced when one subtree holds more than @delta@ times as many
-- elements as the other. @ratio@ compares the inner and outer grandchildren
-- of the heavy side to choose a single rotation (inner < ratio * outer) or
-- a double rotation.
delta,ratio :: Int
delta = 3
ratio = 2