{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- | General type utilities.
module Util.Type
  ( type (==)
  , If
  , type (++)
  , IsElem
  , type (/)
  , type (//)
  , Guard
  , FailWhen
  , FailUnless
  , failUnlessEvi
  , failWhenEvi
  , AllUnique
  , RequireAllUnique
  , ReifyList (..)
  , PatternMatch
  , PatternMatchL
  , KnownList (..)
  , KList (..)
  , RSplit
  , rsplit

  , reifyTypeEquality
  ) where

import Data.Vinyl.Core (Rec (..))
import Data.Vinyl.TypeLevel (type (++))
import qualified Data.Kind as Kind
import Data.Type.Bool (Not, type (&&), If)
import Unsafe.Coerce (unsafeCoerce)
import Data.Constraint ((:-)(..), Dict (..))
import Data.Type.Equality (type (==))
import GHC.TypeLits (Symbol, TypeError, ErrorMessage (..))

-- | Membership test for type-level lists: reduces to @'True@ as soon as the
-- element is found and to @'False@ once the list is exhausted.
type family IsElem (a :: k) (l :: [k]) :: Bool where
  IsElem _ '[] = 'False
  IsElem x (x ': _) = 'True
  IsElem x (_ ': rest) = IsElem x rest

-- | Remove all occurrences of the given element from a type-level list.
--
-- For example, @'[Int, Bool, Int] / Int@ reduces to @'[Bool]@.
--
-- An element that GHC cannot yet prove equal to, or apart from, the removed
-- one blocks further reduction, since the last equation may only fire once
-- the previous one is known not to match.
type family (l :: [k]) / (a :: k) :: [k] where
  '[] / _ = '[]
  (a ': xs) / a = xs / a
  (b ': xs) / a = b ': (xs / a)

-- | Difference between two type-level lists: every occurrence of each element
-- of the second list is removed from the first one, processing the second
-- list from left to right.
type family (l1 :: [k]) // (l2 :: [k]) :: [k] where
  acc // (y ': ys) = (acc / y) // ys
  acc // '[] = acc

-- | Wrap the value into @'Just@ when the condition holds; yield @'Nothing@
-- otherwise.
type family Guard (cond :: Bool) (a :: k) :: Maybe k where
  Guard 'True x = 'Just x
  Guard 'False _ = 'Nothing

-- | Raise the given custom type error unless the condition is @'True@.
-- A satisfied condition reduces to the empty constraint.
type family FailUnless (cond :: Bool) (msg :: ErrorMessage) :: Constraint where
  FailUnless 'False err = TypeError err
  FailUnless 'True _ = ()

-- | Raise the given custom type error when the condition is @'True@;
-- the negated form of 'FailUnless'.
type FailWhen (cond :: Bool) (msg :: ErrorMessage) = FailUnless (Not cond) msg

-- | Once @'FailUnless' cond msg@ is known to hold (no type error was raised),
-- the condition itself must have been @'True@.
--
-- GHC cannot derive this for an abstract @cond@, so the evidence is produced
-- with 'unsafeCoerce' from the trivially true @'True ~ 'True@. This relies on
-- the @'False@ equation of 'FailUnless' reducing to a 'TypeError', i.e. on the
-- constraint being unsatisfiable for any other condition.
failUnlessEvi :: forall cond msg. FailUnless cond msg :- (cond ~ 'True)
failUnlessEvi = Sub trustMe
  where
    trustMe :: Dict (cond ~ 'True)
    trustMe = unsafeCoerce (Dict @('True ~ 'True))

-- | Once @'FailWhen' cond msg@ is known to hold (no type error was raised),
-- the condition itself must have been @'False@.
--
-- The evidence is produced with 'unsafeCoerce' from the trivially true
-- @'False ~ 'False@. This relies on 'FailWhen' being unsatisfiable when
-- @cond@ is @'True@.
failWhenEvi :: forall cond msg. FailWhen cond msg :- (cond ~ 'False)
failWhenEvi = Sub trustMe
  where
    trustMe :: Dict (cond ~ 'False)
    trustMe = unsafeCoerce (Dict @('False ~ 'False))

-- | Reduces to @'True@ when no element occurs more than once in the
-- type-level list.
type family AllUnique (l :: [k]) :: Bool where
  AllUnique (y ': ys) = Not (IsElem y ys) && AllUnique ys
  AllUnique '[] = 'True

-- | Constraint that raises a type error naming the first duplicated element
-- of the list. @desc@ is inserted after "Duplicated" in the error message.
type RequireAllUnique desc items = RequireAllUnique' desc items items

-- | Worker for 'RequireAllUnique'. It walks the list element by element and
-- keeps the untouched original list so that the error message can show it
-- in full.
type family RequireAllUnique' (desc :: Symbol) (l :: [k]) (origL :: [k]) :: Constraint where
  RequireAllUnique' _ '[] _ = ()
  RequireAllUnique' desc (y ': ys) full =
    If (IsElem y ys)
       (TypeError
          (     ('Text "Duplicated " ':<>: 'Text desc ':<>: 'Text ":")
          ':$$: 'ShowType y
          ':$$: ('Text "Full list: " ':<>: 'ShowType full)
          )
       )
       (RequireAllUnique' desc ys full)

-- | Make sure the given type is evaluated.
-- Applicable only to types of kind 'Kind.Type'.
--
-- How it works: the catch-all equation may only fire once GHC knows that @a@
-- is apart from 'Int'. A type that is still an unsolved variable or an
-- unreduced type family application therefore keeps the constraint stuck.
-- The first right-hand side is deliberately @((), ())@ rather than @()@.
-- With identical right-hand sides the two equations would be compatible,
-- and GHC could reduce without inspecting @a@. Both results are trivially
-- satisfied constraints.
type family PatternMatch (a :: Kind.Type) :: Constraint where
  PatternMatch Int = ((), ())
  PatternMatch _ = ()

-- | List analogue of 'PatternMatch': make sure the outermost constructor of
-- the given type-level list is known.
--
-- The catch-all equation may only fire once GHC knows that @l@ is apart
-- from @'[]@. Either way, the constraint reduces only after the outermost
-- list constructor is determined. @((), ())@ keeps the two equations
-- incompatible, so GHC cannot skip this check. Both results are trivially
-- satisfied constraints.
type family PatternMatchL (l :: [k]) :: Constraint where
  PatternMatchL '[] = ((), ())
  PatternMatchL _ = ()

-- | Bring a type-level list to the term level. The given function, which may
-- rely on the constraint @c@, is applied to a 'Proxy' of every element, and
-- the results keep the order of the type-level list.
class ReifyList (c :: k -> Constraint) (l :: [k]) where
  reifyList :: (forall a. c a => Proxy a -> r) -> [r]

instance ReifyList c '[] where
  reifyList _ = []

instance (c x, ReifyList c xs) => ReifyList c (x ': xs) where
  reifyList demote =
    let here = demote (Proxy :: Proxy x)
        rest = reifyList @_ @c @xs demote
    in here : rest

-- | Turn a boolean type-level equality, @(a == b) ~ 'True@, into a genuine
-- type equality that the continuation can use.
--
-- The equality evidence is fabricated with 'unsafeCoerce' from the trivial
-- @a ~ a@. This relies on the assumption that '==' reduces to @'True@ only
-- for equal types.
reifyTypeEquality :: forall a b x. (a == b) ~ 'True => (a ~ b => x) -> x
reifyTypeEquality x =
  case evidence of
    Dict -> x
  where
    evidence :: Dict (a ~ b)
    evidence = unsafeCoerce (Dict @(a ~ a))

-- | Type-level lists whose spine is statically known. Similar to
-- @SingI []@, but the individual elements do not need @SingI@ instances.
class KnownList l where
  klist :: KList l

instance KnownList '[] where
  klist = KNil

instance KnownList xs => KnownList (x ': xs) where
  klist = KCons (Proxy :: Proxy x) (Proxy :: Proxy xs)

-- | Term-level witness for the spine of a type-level list; the 'KnownList'
-- counterpart of 'SList'.
data KList :: [k] -> Kind.Type where
  -- | Witness for the empty list.
  KNil :: KList '[]
  -- | Witness for a non-empty list. It carries proxies of the head and the
  -- tail and packs the tail's 'KnownList' instance.
  KCons :: KnownList xs => Proxy x -> Proxy xs -> KList (x ': xs)

-- | Constraint required by 'rsplit': only the spine of the prefix list has
-- to be known; the suffix list is not constrained at all.
type RSplit prefix suffix = KnownList prefix

-- | Split a record into two pieces: the fields belonging to the prefix @l@
-- and the remaining fields belonging to @r@.
--
-- The recursion follows the spine of @l@ obtained from 'KnownList'; the
-- suffix part is returned as-is.
rsplit
  :: forall k (l :: [k]) (r :: [k]) f.
      (RSplit l r)
  => Rec f (l ++ r) -> (Rec f l, Rec f r)
rsplit fields = case klist @l of
  KNil -> (RNil, fields)
  KCons{} -> case fields of
    field :& others ->
      let (prefixRest, suffix) = rsplit others
      in (field :& prefixRest, suffix)