{-# LANGUAGE CPP                  #-}
{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE ExplicitForAll       #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType         #-}
#endif


{-# LANGUAGE FlexibleContexts     #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Type.Lits
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- A mixture of "GHC.TypeLits" and "GHC.TypeNats", with @Nat@s represented as
-- @Natural@ at runtime, plus a few helper functions of our own.
--
--------------------------------------------------------------------------------

module Data.Type.Lits
  ( -- * Kinds
    TN.Nat, TL.Symbol
    -- * Linking type and value level
  , TN.KnownNat, TN.natVal, TN.natVal'
  , TL.KnownSymbol, TL.symbolVal, TL.symbolVal'
  , TN.SomeNat(..), TL.SomeSymbol(..)
  , TN.someNatVal, TL.someSymbolVal
  , TN.sameNat, TL.sameSymbol
    -- * Functions on type literals
  , type (+), type (*), type (^), type (-)
  , type TN.Div, type TN.Mod, type TN.Log2
  , TL.AppendSymbol, ShowNat
  , TN.CmpNat, TL.CmpSymbol, type (<=)
  , SOrdering (..), cmpNat, cmpSymbol
    -- * User-defined type errors
  , TL.TypeError
  , TL.ErrorMessage(..)
  ) where


import           Data.Kind     (Constraint, Type)
import qualified GHC.TypeLits  as TL
import           GHC.TypeNats  (type (*), type (+), type (-), type (^))
import qualified GHC.TypeNats  as TN
import           Unsafe.Coerce (unsafeCoerce)




-- | Render a type-level 'TN.Nat' as its decimal 'TL.Symbol'.
--
--   Single digits map directly to one-character symbols. Any other number
--   is split into everything but its last digit (@Div n 10@) and its last
--   digit (@Mod n 10@). Both parts are rendered recursively and glued
--   together with 'TL.AppendSymbol', so e.g. @ShowNat 42@ reduces to @"42"@.
type family ShowNat (n :: TN.Nat) :: TL.Symbol where
    ShowNat 0 = "0"
    ShowNat 1 = "1"
    ShowNat 2 = "2"
    ShowNat 3 = "3"
    ShowNat 4 = "4"
    ShowNat 5 = "5"
    ShowNat 6 = "6"
    ShowNat 7 = "7"
    ShowNat 8 = "8"
    ShowNat 9 = "9"
    -- Reached only for n >= 10: the equations above cover every single digit.
    ShowNat n = TL.AppendSymbol
                  (ShowNat (n `TN.Div` 10))
                  (ShowNat (n `TN.Mod` 10))


-- | A singleton mirror of 'Ordering'.
--
--   Each constructor fixes the type index to one promoted 'Ordering'
--   value ('LT', 'EQ' or 'GT'). Pattern-matching on a value of type
--   @SOrdering r@ therefore tells the type checker which of them @r@ is.
--   This is how the result of a type-level comparison is witnessed.
data SOrdering (r :: Ordering) where
    SLT :: SOrdering 'LT
    SEQ :: SOrdering 'EQ
    SGT :: SOrdering 'GT

-- | Compare two type-level 'TN.Nat's via their runtime values.
--
--   The result is a singleton witnessing @'TN.CmpNat' a b@. Pattern-match
--   on the returned 'SOrdering' to make the type checker learn which
--   'Ordering' the comparison reduces to.
cmpNat :: forall (a :: TN.Nat) (b :: TN.Nat) (proxy :: TN.Nat -> Type)
        . (TN.KnownNat a, TN.KnownNat b)
       => proxy a -> proxy b -> SOrdering (TN.CmpNat a b)
cmpNat pa pb = witness (compare (TN.natVal pa) (TN.natVal pb))
  where
    -- Lift a runtime 'Ordering' to the matching singleton.
    -- 'unsafeCoerce' only re-labels the phantom index of 'SOrdering'.
    -- This relies on the 'TN.KnownNat' values of @a@ and @b@ agreeing
    -- with GHC's built-in 'TN.CmpNat' on those same numbers.
    witness LT = unsafeCoerce SLT
    witness EQ = unsafeCoerce SEQ
    witness GT = unsafeCoerce SGT
{-# INLINE cmpNat #-}

-- | Compare two type-level 'TL.Symbol's via their runtime 'String' values.
--
--   The result is a singleton witnessing @'TL.CmpSymbol' a b@.
--   Pattern-match on the returned 'SOrdering' to make the type checker
--   learn which 'Ordering' the comparison reduces to.
cmpSymbol :: forall (a :: TL.Symbol) (b :: TL.Symbol) (proxy :: TL.Symbol -> Type)
           . (TL.KnownSymbol a, TL.KnownSymbol b)
          => proxy a -> proxy b -> SOrdering (TL.CmpSymbol a b)
cmpSymbol pa pb = witness (compare (TL.symbolVal pa) (TL.symbolVal pb))
  where
    -- Lift a runtime 'Ordering' to the matching singleton.
    -- 'unsafeCoerce' only re-labels the phantom index of 'SOrdering'.
    -- This relies on comparing the 'TL.KnownSymbol' strings agreeing with
    -- GHC's built-in 'TL.CmpSymbol' on the same symbols.
    witness LT = unsafeCoerce SLT
    witness EQ = unsafeCoerce SEQ
    witness GT = unsafeCoerce SGT
{-# INLINE cmpSymbol #-}

-- | Comparison of type-level naturals, as a constraint.
--
--   @a <= b@ expands to @LE a b ('TN.CmpNat' a b)@. That reduces to the
--   empty constraint @()@ when the comparison is 'LT' or 'EQ', and to a
--   custom 'TL.TypeError' when it is 'GT'. While @'TN.CmpNat' a b@ cannot
--   be reduced (e.g. @a@ or @b@ is still an unknown variable), no equation
--   of 'LE' matches and the constraint stays unsolved.
--
--   Declared @infix 4@, the same fixity as the @<=@ exported by
--   "GHC.TypeNats". Without it the operator would default to @infixl 9@,
--   making e.g. @n + 1 <= m@ parse as @n + (1 <= m)@, which is ill-kinded.
infix 4 <=
type (<=) (a :: TN.Nat) (b :: TN.Nat) = LE a b (TN.CmpNat a b)

-- | Implementation detail of @<=@: dispatch on @r@, the result of
--   @'TN.CmpNat' a b@. The @a@ and @b@ arguments are used only to build
--   the error message.
type family LE (a :: TN.Nat) (b :: TN.Nat) (r :: Ordering) :: Constraint where
    LE _ _ 'LT = ()
    LE _ _ 'EQ = ()
    LE a b 'GT = TL.TypeError
      ('TL.Text "Cannot deduce type-level Nat relation: "
          'TL.:<>: 'TL.ShowType a
          'TL.:<>: 'TL.Text " <= "
          'TL.:<>: 'TL.ShowType b
      )