{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE OverlappingInstances   #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UndecidableInstances   #-}

module Data.OpenUnion.Internal where
import           Control.Applicative
import           Data.Monoid         hiding (All)
import           GHC.Exts            (Constraint)

-- | An open sum: a value whose type is one of the types in the type-level
-- list @xs@.  The constructors encode which position of the list is used.
data Union (xs :: [*]) where
  -- The value has the first type of the list.
  Single    :: x -> Union (x ': xs)
  -- The value has some type further down the list.
  Union     :: Union xs -> Union (x ': xs)
  -- The only constructor targeting @Union '[]@; its argument is again a
  -- @Union '[]@, so no finite, fully-defined value of that type can be
  -- built.  Eliminated by 'exhaust'.
  Exhausted :: Union '[] -> Union '[]

-- | Eliminator for the empty union.  A @Union '[]@ can only be an
-- 'Exhausted' wrapping another @Union '[]@, so this just keeps unwrapping
-- and never actually yields a result.
exhaust :: Union '[] -> a
exhaust u = case u of
  Exhausted inner -> exhaust inner
{-# INLINE exhaust #-}

-- | @All c xs@ requires the constraint @c@ for every type in the list @xs@.
-- It unfolds to a right-nested tuple of constraints terminated by @()@.
type family All (c :: * -> Constraint) (xs :: [*]) :: Constraint
type instance All c '[]       = ()
type instance All c (x ': xs) = (c x, All c xs)

-- | Shows the value stored in the union exactly as the member type's own
-- 'Show' instance would; no tag or constructor name is printed.
--
-- 'showsPrec' is forwarded (rather than defining only 'show') so that the
-- member's precedence handling survives nesting: a union holding @-1@
-- inside 'Just' renders as @Just (-1)@, not @Just -1@.  Top-level 'show'
-- output is unchanged, since 'show' defaults to @showsPrec 0@.
instance (All Show r) => Show (Union r) where
  showsPrec d (Single x)    = showsPrec d x
  showsPrec d (Union u)     = showsPrec d u
  -- 'Exhausted' only occurs at type @Union '[]@, which holds no real value;
  -- eliminating it with 'exhaust' replaces the old partial 'undefined'
  -- catch-all with an exhaustive match.
  showsPrec _ (Exhausted u) = exhaust u

-- | Case analysis on the head of a union.  @f ||> g@ applies @f@ when the
-- value has the first listed type, and passes any other value (as the
-- smaller union of the remaining types) to @g@.  The operator is
-- right-associative, so handlers chain as @f ||> g ||> h@.
(||>) :: (a -> r) -> (Union as -> r) -> (Union (a ': as) -> r)
(onHead ||> onRest) u = case u of
  Single x   -> onHead x
  Union rest -> onRest rest
infixr 2 ||>
{-# INLINE (||>) #-}

-- | A typed index witnessing that @x@ occurs in the list @xs@.
-- 'Zero' points at the head of the list; 'Succ' skips one element.
data Position (x :: *) (xs :: [*]) where
  Zero :: Position x (x ': xs)
  Succ :: Position x xs -> Position x (y ': xs)

-- | @Member a as@: the type @a@ occurs in the type-level list @as@, and
-- 'position' is the index of its first occurrence.
--
-- There is deliberately no functional dependency @as -> a@.  A list has
-- many members, and with such a dependency GHC would use the first
-- instance below to "improve" any wanted @Member x (t ': ts)@ to @x ~ t@.
-- Every member except the head would then be unreachable (for example,
-- @retractU :: Union '[Int, Bool] -> Maybe Bool@ would fail to typecheck).
class Member a as where
  position :: Position a as

-- The wanted type is the head of the list.  This head overlaps the instance
-- below; under OverlappingInstances the more specific one is chosen, so a
-- repeated type resolves to its first occurrence.
instance Member a (a ': as) where
  position = Zero

-- Otherwise skip the head and look in the tail.
instance Member a as => Member a (b ': as) where
  position = Succ position

-- | Inject a value into a union that lists its type.  The value is stored
-- at the position selected by the 'Member' instance.
liftU :: forall a as. Member a as => a -> Union as
liftU x = inject (position :: Position a as)
  where
    -- Walk the index, wrapping one 'Union' layer per skipped element.
    inject :: Position a bs -> Union bs
    inject p = case p of
      Zero   -> Single x
      Succ q -> Union (inject q)
{-# INLINE liftU #-}

-- | A traversal focused on the member of type @a@.  The function runs only
-- when the union currently holds a value at @a@'s position (as chosen by
-- 'Member'); otherwise the union is returned unchanged through 'pure'.
picked :: forall a as f. (Applicative f, Member a as) => (a -> f a) -> Union as -> f (Union as)
picked k = visit (position :: Position a as)
  where
    -- Follow the index and the union in lockstep.
    visit :: Position a bs -> Union bs -> f (Union bs)
    visit Zero u = case u of
      Single x -> Single <$> k x
      Union _  -> pure u
    visit (Succ p) u = case u of
      Single _   -> pure u
      Union rest -> Union <$> visit p rest
{-# INLINE picked #-}

-- | Extract the value if the union currently holds the member type @a@,
-- and return 'Nothing' otherwise.  Implemented by running 'picked' in the
-- @Const (First a)@ applicative.
retractU :: Member a as => Union as -> Maybe a
retractU u = getFirst (getConst (picked (\x -> Const (First (Just x))) u))
{-# INLINE retractU #-}

-- | Apply an endofunction to the value when the union holds the member type
-- @a@; values of any other member type pass through unchanged.  Implemented
-- by running 'picked' in the effect-free 'Id' applicative.
hoistU :: Member a as => (a -> a) -> Union as -> Union as
hoistU f u = getId (picked (\x -> Id (f x)) u)
{-# INLINE hoistU #-}

-- | A minimal identity functor, standing in for @Identity@.  'hoistU' uses
-- it to run 'picked' as a plain, effect-free update.
newtype Id a = Id { getId :: a }

instance Functor Id where
  fmap f = Id . f . getId

instance Applicative Id where
  pure = Id
  mf <*> mx = Id (getId mf (getId mx))

-- | @Include as bs@: every member type of @as@ is also a member of @bs@,
-- so any @Union as@ can be re-injected into a @Union bs@.
class Include as bs where
  reunion :: Union as -> Union bs

-- Non-empty source list: lift the head value into the target with 'liftU',
-- or recurse into the tail.
instance (Member a bs, Include as bs) => Include (a ': as) bs where
  reunion u = case u of
    Single x   -> liftU x
    Union rest -> reunion rest

-- Empty source list: there is no value to re-inject.
instance Include '[] bs where
  reunion = exhaust

-- NOTE(review): the type-operator names that belong in the three synonyms
-- and the three fixity declarations below are missing (possibly lost to a
-- character-encoding problem).  As written these lines do not parse;
-- restore the intended operator symbols -- TODO confirm against the original.
--
-- Shorthand for 'Member': the first type is an element of the list.
type a  as  = Member a as
-- Shorthand for 'Include': every member of the first list is in the second.
type as  bs = Include as bs
-- Mutual inclusion: each list's members are all members of the other.
type as  bs = (Include as bs, Include bs as)

-- Non-associative fixity 4, presumably for the three operators above.
infix 4 
infix 4 
infix 4