{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Type-level list utilities.
module Language.Symantic.Typing.List where

import GHC.Exts (Constraint)

import Language.Symantic.Typing.Peano

-- * Type family 'Index'
-- | Position, as a Peano number built from 'Zero' and 'Succ', of the
-- first occurrence of the type @x@ in the list @xs@.
-- This is useful to work around @OverlappingInstances@.
--
-- Equations are tried top to bottom: the first matches when the head of
-- the list is @x@ itself, so the second only fires once the head is known
-- to differ from @x@. There is no equation for @'[]@, so if @x@ does not
-- occur in @xs@ the family does not reduce.
type family Index xs x where
	Index (y     ': ys) y = Zero
	Index (other ': ys) y = Succ (Index ys y)

-- * Type family @(++)@
-- | Type-level list append: the elements of @xs@ followed by those of
-- @ys@. Recursion is on the first list only.
--
-- Adding the shortcut equation @xs ++ '[] = xs@ would be harmful: it is
-- not compatible with the cons equation below, so GHC could then reduce
-- @(x ': xs) ++ ys@ only once @ys@ is known to differ from @'[]@. That would
-- block reduction when @ys@ is an abstract type variable.
type family (++) (xs::[k]) (ys::[k]) :: [k] where
	'[]       ++ ys = ys
	(x ': xs) ++ ys = x ': (xs ++ ys)
infixr 5 ++

-- * Type family 'Concat'
-- | Flatten a type-level list of lists by appending its inner lists from
-- left to right with '++'.
type family Concat (xs::[[k]]) :: [k] where
	Concat '[]            = '[]
	Concat (inner ': rest) = inner ++ Concat rest

-- * Type family 'Concat_Constraints'
-- | Combine a type-level list of constraints into a single constraint.
-- The result is a right-nested chain of pairs ending in @()@; for example
-- @Concat_Constraints '[A, B]@ reduces to @(A, (B, ()))@.
type family Concat_Constraints (cs::[Constraint]) :: Constraint where
	Concat_Constraints '[]           = ()
	Concat_Constraints (hd ': rest) = (hd, Concat_Constraints rest)

-- * Type family 'DeleteAll'
-- | Remove every occurrence of @x@ from @xs@, keeping the remaining
-- elements in their original order.
--
-- Equations are tried top to bottom. The second matches only when the
-- head is @x@ itself. The third fires only once the head is known to
-- differ from @x@.
type family DeleteAll (x::k) (xs::[k]) :: [k] where
	DeleteAll x '[]             = '[]
	DeleteAll x (x     ': rest) = DeleteAll x rest
	DeleteAll x (other ': rest) = other ': DeleteAll x rest

-- * Type family 'Head'
-- | First element of a non-empty type-level list. There is no equation
-- for @'[]@, so @Head '[]@ does not reduce.
type family Head (xs::[k]) :: k where
	Head (first ': _rest) = first

-- * Type family 'Tail'
-- | All elements of a non-empty type-level list except the first. There
-- is no equation for @'[]@, so @Tail '[]@ does not reduce.
type family Tail (xs::[k]) :: [k] where
	Tail (_first ': rest) = rest

{-
-- * Type family 'Map'
type family Map (f::a -> b) (cs::[a]) :: [b] where
	Map f '[] = '[]
	Map f (c ': cs) = f c ': Map f cs
-}

-- * Type family 'Nub'
-- | Deduplicate a type-level list, keeping the first occurrence of each
-- element. The head is kept, every later copy of it is removed with
-- 'DeleteAll', and the family then recurses on what is left.
type family Nub (xs::[k]) :: [k] where
	Nub '[]         = '[]
	Nub (y ': rest) = y ': Nub (DeleteAll y rest)

{-
-- * Type family 'L'
type family L (xs::[k]) :: Nat where
	L '[] = 'Z
	L (x ': xs) = 'S (L xs)

-- ** Class 'Inj_L'
class Inj_L (as::[k]) where
	inj_L :: SNat (L as)
instance Inj_L '[] where
	inj_L = SNatZ
instance Inj_L as => Inj_L (a ': as) where
	inj_L = SNatS (inj_L @_ @as)
-}

-- * Type 'Len'
-- | Value-level witness of the length of the type-level list @xs@. A
-- @Len xs@ is built from exactly one 'LenS' per element of @xs@, capped by
-- 'LenZ'.
data Len (xs::[k]) where
	-- | Length of the empty list.
	LenZ :: Len '[]
	-- | One more than the length of the tail.
	LenS :: Len as -> Len (a ': as)

-- | Displays a 'Len' as the plain number it counts; for example,
-- @LenS (LenS LenZ)@ is shown as @2@.
instance Show (Len vs) where
	showsPrec _ len = showsPrec 10 (intLen len)

-- | Length of an appended list. One 'LenS' is peeled off the first length
-- per step, and the second length is used once the first is exhausted;
-- this follows the equations of '++'.
addLen :: Len a -> Len b -> Len (a ++ b)
addLen lenA lenB = case lenA of
	LenZ       -> lenB
	LenS lenA' -> LenS (addLen lenA' lenB)

-- | Length of @a ++ t ': b@, computed from the lengths of @a@ and of
-- @a ++ b@. It accounts for one extra element @t@ inserted right after the
-- prefix @a@.
--
-- @t@ and @b@ occur only under the '++' family in this signature, so they
-- cannot be inferred from the arguments. They are passed with type
-- applications, as the recursive call below does.
shiftLen ::
 forall t b a.
 Len a ->
 Len (a ++      b) ->
 Len (a ++ t ': b)
shiftLen LenZ          lenAB   = LenS lenAB
shiftLen (LenS prefix) (LenS lenRest) =
	LenS (shiftLen @t @b prefix lenRest)

-- | Convert a 'Len' to the corresponding 'Int' by counting its 'LenS'
-- constructors.
--
-- The count is accumulated tail-recursively. The accumulator is forced
-- with 'seq' at every step, so that a chain of @(1 +)@ thunks does not
-- build up for long lists when GHC's strictness analysis does not kick in.
intLen :: Len xs -> Int
intLen = go 0
	where
	go :: Int -> Len ys -> Int
	go acc LenZ     = acc
	go acc (LenS l) = let acc' = acc + 1 in acc' `seq` go acc' l

-- ** Class 'Inj_Len'
-- | Type-level lists whose 'Len' can be produced without any argument. The
-- witness is built by recursion on the structure of the list.
class Inj_Len (xs::[k]) where
	inj_Len :: Len xs
-- | The empty list has length 'LenZ'.
instance Inj_Len '[] where
	inj_Len = LenZ
-- | A cons cell adds one 'LenS' on top of the length of its tail.
instance Inj_Len rest => Inj_Len (x ': rest) where
	inj_Len = LenS inj_Len