{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise       #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Clash.Promoted.Nat
  ( 
    
    SNat (..)
    
  , snatProxy
  , withSNat
    
  , snatToInteger, snatToNatural, snatToNum
    
  , natToInteger, natToNatural, natToNum
    
  , addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
    
  , subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
    
  , pow2SNat
    
  , SNatLE (..), compareSNat
    
    
  , UNat (..)
    
  , toUNat
    
  , fromUNat
    
  , addUNat, mulUNat, powUNat
    
  , predUNat, subUNat
    
    
  , BNat (..)
    
  , toBNat
    
  , fromBNat
    
  , showBNat
    
  , succBNat, addBNat, mulBNat, powBNat
    
  , predBNat, div2BNat, div2Sub1BNat, log2BNat
    
  , stripZeros
    
  , leToPlus
  , leToPlusKN
  )
where
import Data.Kind          (Type)
import GHC.Natural        (naturalFromInteger)
import GHC.Show           (appPrec)
import GHC.TypeLits       (KnownNat, Nat, type (+), type (-), type (*),
                           type (^), type (<=), natVal)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
#if MIN_VERSION_template_haskell(2,17,0)
import Language.Haskell.TH.Syntax (unsafeCodeCoerce)
#elif MIN_VERSION_template_haskell(2,16,0)
import Language.Haskell.TH.Syntax (unsafeTExpCoerce)
#endif
import Numeric.Natural    (Natural)
import Unsafe.Coerce      (unsafeCoerce)

import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)

import Clash.XException   (ShowX (..), showsPrecXWith)
-- | Singleton value for a type-level natural number @n@.
--
-- Pattern matching on the 'SNat' constructor brings a @KnownNat n@
-- constraint into scope, which is how the term-level value of @n@ is
-- recovered (see 'snatToInteger').
data SNat (n :: Nat) where
  -- | The only inhabitant; it carries the 'KnownNat' dictionary for @n@.
  SNat :: KnownNat n => SNat n
-- | Lifting an 'SNat' into Template Haskell produces the expression 'SNat'
-- annotated with the type @SNat@ applied to the literal value of @n@
-- (obtained through 'snatToInteger').
--
-- Since template-haskell 2.16 the minimal definition of 'Lift' is
-- 'liftTyped', which has no default. Without an explicit definition, using
-- an 'SNat' in a typed quotation would fail with a missing-method error.
-- It is derived here from the untyped 'lift'. This is sound because 'lift'
-- already annotates the expression with exactly the type @SNat n@.
instance Lift (SNat n) where
  lift s = sigE [| SNat |]
                (appT (conT ''SNat) (litT $ numTyLit (snatToInteger s)))
#if MIN_VERSION_template_haskell(2,17,0)
  liftTyped = unsafeCodeCoerce . lift
#elif MIN_VERSION_template_haskell(2,16,0)
  liftTyped = unsafeTExpCoerce . lift
#endif
-- | Build the 'SNat' singleton for the natural indexed by any proxy-like
-- argument. The argument itself is never inspected.
snatProxy :: KnownNat n => proxy n -> SNat n
snatProxy = const SNat
-- | Naturals up to and including 1024 are rendered compactly as the letter
-- @d@ followed by the number (for example @d5@). Larger naturals are
-- rendered as the string @SNat \@@ followed by the number, parenthesised
-- when the surrounding precedence exceeds that of function application.
instance Show (SNat n) where
  showsPrec prec s@SNat =
    let v = snatToInteger s
    in  if v > 1024
          then showParen (prec > appPrec) (showString "SNat @" . shows v)
          else showChar 'd' . shows v
-- | 'ShowX' rendering delegates to the 'Show' instance through
-- 'showsPrecXWith'.
instance ShowX (SNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
{-# INLINE withSNat #-}
-- | Apply a continuation to the 'SNat' singleton of a known natural.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat = ($ SNat)
-- | Value of the type-level natural @n@ as an 'Integer'. Select @n@ with a
-- type application, e.g. @natToInteger \@8@.
natToInteger :: forall n . KnownNat n => Integer
natToInteger = snatToInteger singleton
  where
    singleton = SNat :: SNat n
{-# INLINE natToInteger #-}
-- | Recover the value of a singleton natural as an 'Integer', using the
-- 'KnownNat' dictionary stored in the 'SNat' constructor.
snatToInteger :: SNat n -> Integer
snatToInteger s = case s of
  SNat -> natVal s
{-# INLINE snatToInteger #-}
-- | Value of the type-level natural @n@ as a 'Natural'. Select @n@ with a
-- type application.
natToNatural :: forall n . KnownNat n => Natural
natToNatural = snatToNatural (SNat :: SNat n)
{-# INLINE natToNatural #-}
-- | Value of a singleton natural as a 'Natural'. It goes through the
-- 'Integer' produced by 'snatToInteger'.
snatToNatural :: SNat n -> Natural
snatToNatural s = naturalFromInteger (snatToInteger s)
{-# INLINE snatToNatural #-}
-- | Value of the type-level natural @n@, converted to any 'Num' type with
-- 'fromInteger'. Select @n@ (and optionally @a@) with type applications.
natToNum :: forall n a . (Num a, KnownNat n) => a
natToNum = snatToNum (SNat :: SNat n)
{-# INLINE natToNum #-}
-- | Value of a singleton natural, converted to any 'Num' type with
-- 'fromInteger'.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum s = case s of
  SNat -> fromInteger (snatToInteger s)
{-# INLINE snatToNum #-}
-- | Unary (Peano) representation of a type-level natural. The value's
-- constructor spine mirrors the type index.
data UNat :: Nat -> Type where
  -- | Zero.
  UZero :: UNat 0
  -- | Successor: adds one to the index of the wrapped value.
  USucc :: UNat n -> UNat (n + 1)
-- | Rendered as the letter @u@ followed by the value of the index.
instance KnownNat n => Show (UNat n) where
  show u = "u" ++ show (natVal u)
-- | 'ShowX' rendering delegates to the 'Show' instance through
-- 'showsPrecXWith'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Convert a singleton natural into its unary representation.
--
-- The unary spine is built at the term level from the 'Integer' value of the
-- 'SNat'. 'unsafeCoerce' assigns the type index to each intermediate value,
-- because the type checker cannot relate a runtime 'Integer' to the type
-- index @n@.
toUNat :: forall n . SNat n -> UNat n
toUNat p@SNat = fromI @n (snatToInteger p)
  where
    -- Wraps 'UZero' in as many 'USucc's as the Integer argument says. Each
    -- step claims index m-1 for the recursive result. The coercions are
    -- sound only while the Integer argument equals m, which holds initially
    -- because it comes from 'snatToInteger' and is kept by decrementing
    -- both in lockstep. A negative argument would never reach 0 and would
    -- not terminate.
    fromI :: forall m . Integer -> UNat m
    fromI 0 = unsafeCoerce @(UNat 0) @(UNat m) UZero
    fromI n = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m) (USucc (fromI @(m-1) (n - 1)))
-- | Convert a unary natural back into a singleton by walking its spine and
-- adding one for every 'USucc'.
fromUNat :: UNat n -> SNat n
fromUNat u = case u of
  UZero   -> SNat @0
  USucc r -> fromUNat r `addSNat` SNat @1
-- | Add two unary naturals.
--
-- A 'UZero' on either side returns the other argument unchanged. Otherwise
-- one 'USucc' is moved from the left argument to the result and the function
-- recurses. The clause order matters: the second clause stops as soon as
-- the right argument is 'UZero', without walking the left spine.
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat UZero     y     = y
addUNat x         UZero = x
addUNat (USucc x) y     = USucc (addUNat x y)
-- | Multiply two unary naturals.
--
-- A 'UZero' on either side yields 'UZero'. Otherwise the function uses
-- (x + 1) * y = y + x * y, recursing on the left argument.
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat UZero      _     = UZero
mulUNat _          UZero = UZero
mulUNat (USucc x) y      = addUNat y (mulUNat x y)
-- | Exponentiation of unary naturals: x ^ 0 = 1 and x ^ (y + 1) = x * x ^ y.
-- It recurses on the exponent only.
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat _ UZero     = USucc UZero
powUNat x (USucc y) = mulUNat x (powUNat x y)
-- | Predecessor of a strictly positive unary natural: strips one 'USucc'.
--
-- The 'UZero' clause cannot be reached by well-typed callers, because the
-- index of the argument is n + 1. It is present only to make the match
-- exhaustive.
predUNat :: UNat (n+1) -> UNat n
predUNat (USucc x) = x
predUNat UZero     =
  error "predUNat: impossible: 0 minus 1, -1 is not a natural number"
-- | Subtract the second unary natural from the first, whose index is
-- written as m + n so that the result is guaranteed to be a natural.
--
-- One 'USucc' is peeled off both arguments until the subtrahend is 'UZero'.
-- The final clause (left 'UZero', right non-zero) contradicts the index
-- equation and is present only for exhaustiveness.
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat x         UZero     = x
subUNat (USucc x) (USucc y) = subUNat x y
subUNat UZero     _         = error "subUNat: impossible: 0 + (n + 1) ~ 0"
-- | Predecessor of a strictly positive singleton natural.
predSNat :: SNat (a+1) -> SNat (a)
predSNat s = case s of SNat -> SNat
{-# INLINE predSNat #-}
-- | Successor of a singleton natural.
succSNat :: SNat a -> SNat (a+1)
succSNat s = case s of SNat -> SNat
{-# INLINE succSNat #-}
-- | Sum of two singleton naturals.
addSNat :: SNat a -> SNat b -> SNat (a+b)
addSNat lhs rhs = case lhs of
  SNat -> case rhs of
    SNat -> SNat
{-# INLINE addSNat #-}
infixl 6 `addSNat`
-- | Difference of two singleton naturals. The minuend's index is written as
-- a + b, so the result a is always a natural.
subSNat :: SNat (a+b) -> SNat b -> SNat a
subSNat lhs rhs = case lhs of
  SNat -> case rhs of
    SNat -> SNat
{-# INLINE subSNat #-}
infixl 6 `subSNat`
-- | Product of two singleton naturals.
mulSNat :: SNat a -> SNat b -> SNat (a*b)
mulSNat lhs rhs = case lhs of
  SNat -> case rhs of
    SNat -> SNat
{-# INLINE mulSNat #-}
infixl 7 `mulSNat`
-- | First singleton natural raised to the power of the second.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat
-- it as a primitive; confirm before changing the pragma.
powSNat :: SNat a -> SNat b -> SNat (a^b)
powSNat base ex = case base of
  SNat -> case ex of
    SNat -> SNat
{-# NOINLINE powSNat #-}
infixr 8 `powSNat`
-- | Integer division of singleton naturals. The @1 <= b@ constraint rules
-- out a zero divisor.
divSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat num den = case num of
  SNat -> case den of
    SNat -> SNat
{-# INLINE divSNat #-}
infixl 7 `divSNat`
-- | Remainder of singleton natural division. The @1 <= b@ constraint rules
-- out a zero divisor.
modSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat num den = case num of
  SNat -> case den of
    SNat -> SNat
{-# INLINE modSNat #-}
infixl 7 `modSNat`
-- | Smaller of two singleton naturals.
minSNat :: SNat a -> SNat b -> SNat (Min a b)
minSNat lhs rhs = case lhs of
  SNat -> case rhs of
    SNat -> SNat
-- | Larger of two singleton naturals.
maxSNat :: SNat a -> SNat b -> SNat (Max a b)
maxSNat lhs rhs = case lhs of
  SNat -> case rhs of
    SNat -> SNat
-- | Floor logarithm ('FLog' from "GHC.TypeLits.Extra") of a positive
-- singleton natural in a base of at least 2.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat
-- it as a primitive; confirm before changing the pragma.
flogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base
             -> SNat x
             -> SNat (FLog base x)
flogBaseSNat sBase sX = case sBase of
  SNat -> case sX of
    SNat -> SNat
{-# NOINLINE flogBaseSNat #-}
-- | Ceiling logarithm ('CLog' from "GHC.TypeLits.Extra") of a positive
-- singleton natural in a base of at least 2.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat
-- it as a primitive; confirm before changing the pragma.
clogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base
             -> SNat x
             -> SNat (CLog base x)
clogBaseSNat sBase sX = case sBase of
  SNat -> case sX of
    SNat -> SNat
{-# NOINLINE clogBaseSNat #-}
-- | Exact logarithm ('Log' from "GHC.TypeLits.Extra"). It is only
-- available when the floor and ceiling logarithms coincide, i.e. when
-- @x@ is an exact power of @base@.
--
-- NOTE(review): kept NOINLINE, presumably so the Clash compiler can treat
-- it as a primitive; confirm before changing the pragma.
logBaseSNat :: (FLog base x ~ CLog base x)
            => SNat base
            -> SNat x
            -> SNat (Log base x)
logBaseSNat sBase sX = case sBase of
  SNat -> case sX of
    SNat -> SNat
{-# NOINLINE logBaseSNat #-}
-- | Two raised to the power of a singleton natural.
pow2SNat :: SNat a -> SNat (2^a)
pow2SNat s = case s of SNat -> SNat
{-# INLINE pow2SNat #-}
-- | Witness of how two naturals @a@ and @b@ are ordered. Matching on a
-- constructor brings the corresponding constraint into scope.
data SNatLE a b where
  -- | @a@ is less than or equal to @b@.
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | @a@ is strictly greater than @b@, expressed as @b + 1 <= a@.
  SNatGT :: forall a b . (b+1) <= a => SNatLE a b
-- | Compare two singleton naturals and return a type-level witness of the
-- outcome.
--
-- The comparison is done on the 'Integer' values from 'snatToInteger'. The
-- witness is built for a trivially true instance of the relation (0 and 0,
-- or 1 and 0) and then coerced to the requested indices. This is sound only
-- because the runtime values are exactly the type-level naturals @a@ and @b@.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
compareSNat a b
  | snatToInteger a > snatToInteger b = unsafeCoerce (SNatGT @1 @0)
  | otherwise                         = unsafeCoerce (SNatLE @0 @0)
-- | Binary representation of a type-level natural.
--
-- The outermost constructor holds the least significant bit: 'B0' doubles
-- the wrapped value and 'B1' doubles it and adds one. 'BT' terminates the
-- spine and stands for zero. Leading zero bits (a 'B0' directly around 'BT')
-- are representable; 'stripZeros' removes them.
data BNat :: Nat -> Type where
  -- | Terminator; the value 0.
  BT :: BNat 0
  -- | Append a 0 bit: @2 * n@.
  B0 :: BNat n -> BNat (2*n)
  -- | Append a 1 bit: @2 * n + 1@.
  B1 :: BNat n -> BNat ((2*n) + 1)
-- | Rendered as the letter @b@ followed by the decimal value of the index.
instance KnownNat n => Show (BNat n) where
  show bn = "b" ++ show (natVal bn)
-- | 'ShowX' rendering delegates to the 'Show' instance through
-- 'showsPrecXWith'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Render the constructor spine as a binary literal with the @0b@ prefix,
-- most significant bit first. Leading zeros in the spine are kept, and 'BT'
-- on its own renders as just the prefix.
showBNat :: BNat n -> String
showBNat = walk ""
  where
    -- The accumulator collects bits from least to most significant, so
    -- each newly visited (more significant) bit is prepended.
    walk :: String -> BNat m -> String
    walk acc bn = case bn of
      BT   -> "0b" ++ acc
      B0 r -> walk ('0' : acc) r
      B1 r -> walk ('1' : acc) r
-- | Convert a singleton natural into its binary representation.
--
-- The bits are produced at the term level by repeated 'divMod' on the
-- 'Integer' value of the 'SNat'. 'unsafeCoerce' assigns the type index to
-- each intermediate value, because the type checker cannot relate the
-- runtime 'Integer' to the type index.
toBNat :: SNat n -> BNat n
toBNat s@SNat = toBNat' (snatToInteger s)
  where
    -- Each step claims, for the recursive result, the index (m-1)/2 for an
    -- odd value or m/2 for an even one. This matches the quotient that
    -- divMod computes as long as the Integer argument equals m (true
    -- initially, since it comes from snatToInteger). The result is
    -- normalised: the recursion stops at 0, so no leading zeros are made.
    toBNat' :: forall m . Integer -> BNat m
    toBNat' 0 = unsafeCoerce BT
    toBNat' n = case n `divMod` 2 of
      (n',1) -> unsafeCoerce (B1 (toBNat' @(Div (m-1) 2) n'))
      (n',_) -> unsafeCoerce (B0 (toBNat' @(Div m 2) n'))
-- | Convert a binary natural back into a singleton. Each 'B0' doubles the
-- value of the wrapped bits, and each 'B1' doubles it and adds one.
fromBNat :: BNat n -> SNat n
fromBNat bn = case bn of
  BT   -> SNat @0
  B0 r -> SNat @2 `mulSNat` fromBNat r
  B1 r -> (SNat @2 `mulSNat` fromBNat r) `addSNat` SNat @1
-- | Add two binary naturals bit by bit, starting from the least significant
-- bit.
--
-- When both low bits are 1, the carry is propagated with 'succBNat'. 'BT'
-- acts as zero: once either spine runs out, the rest of the other spine is
-- returned unchanged.
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat (B0 a) (B0 b) = B0 (addBNat a b)
addBNat (B0 a) (B1 b) = B1 (addBNat a b)
addBNat (B1 a) (B0 b) = B1 (addBNat a b)
addBNat (B1 a) (B1 b) = B0 (succBNat (addBNat a b))
addBNat BT     b      = b
addBNat a      BT     = a
-- | Multiply two binary naturals by shift-and-add on the left operand:
-- (2a) * b = 2(a * b) and (2a + 1) * b = 2(a * b) + b. A 'BT' on either
-- side yields 'BT'.
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat BT      _  = BT
mulBNat _       BT = BT
mulBNat (B0 a)  b  = B0 (mulBNat a b)
mulBNat (B1 a)  b  = addBNat (B0 (mulBNat a b)) b
-- | Exponentiation by squaring on binary naturals:
-- a ^ 0 = 1, a ^ (2h) = (a ^ h)^2 and a ^ (2h + 1) = a * (a ^ h)^2.
-- The recursion is on the exponent's bits only.
powBNat :: BNat n -> BNat m -> BNat (n^m)
powBNat base ex = case ex of
  BT   -> B1 BT
  B0 h -> let half = powBNat base h in mulBNat half half
  B1 h -> let half = powBNat base h in mulBNat base (mulBNat half half)
-- | Successor of a binary natural. A low 0 bit becomes 1. A low 1 bit
-- becomes 0, and the carry moves into the more significant bits.
succBNat :: BNat n -> BNat (n+1)
succBNat bn = case bn of
  BT   -> B1 BT
  B0 r -> B1 r
  B1 r -> B0 (succBNat r)
-- | Predecessor of a strictly positive binary natural.
--
-- For an odd value 2a + 1 the result is 2a. When @a@ is zero (possibly with
-- leading zero bits) the result collapses to 'BT' rather than @B0 BT@. For
-- an even value 2x the result is 2(x - 1) + 1. There is no 'BT' clause,
-- because index 0 contradicts the @1 <= n@ constraint.
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
predBNat (B1 a) = case stripZeros a of
  BT -> BT
  a' -> B0 a'
predBNat (B0 x) = B1 (predBNat x)
-- | Halve an even binary natural by dropping its low 0 bit.
-- The 'B1' clause contradicts the even index and exists only for
-- exhaustiveness.
div2BNat :: BNat (2*n) -> BNat n
div2BNat BT     = BT
div2BNat (B0 x) = x
div2BNat (B1 _) = error "div2BNat: impossible: 2*n ~ 2*n+1"
-- | Map an odd binary natural 2n + 1 to n by dropping its low 1 bit.
-- Any other constructor contradicts the odd index; that clause exists only
-- for exhaustiveness.
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat (B1 x) = x
div2Sub1BNat _      = error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"
-- | Base-2 logarithm of a binary natural that is a power of two.
--
-- Each low 'B0' bit adds one to the result. The only valid odd power of two
-- is 1, so a 'B1' must wrap a zero (after stripping leading zeros), which
-- gives log 1 = 0. Any other 'B1' contradicts the index. 'BT' (zero) has no
-- logarithm.
log2BNat :: BNat (2^n) -> BNat n
log2BNat BT = error "log2BNat: log2(0) not defined"
log2BNat (B1 x) = case stripZeros x of
  BT -> BT
  _  -> error "log2BNat: impossible: 2^n ~ 2x+1"
log2BNat (B0 x) = succBNat (log2BNat x)
-- | Remove leading (most significant) zero bits from a binary natural, so
-- that zero is represented by 'BT' alone. The type index does not change.
stripZeros :: BNat n -> BNat n
stripZeros BT      = BT
stripZeros (B1 x)  = B1 (stripZeros x)
stripZeros (B0 BT) = BT
-- A 'B0' survives only if something non-zero remains above it.
stripZeros (B0 x)  = case stripZeros x of
  BT -> BT
  k  -> B0 k
-- | Turn a @k <= n@ constraint into an explicit decomposition @n ~ k + m@.
-- The continuation is instantiated with @m = n - k@.
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n)
  => (forall m . (n ~ (k + m)) => r)
  -- ^ Continuation that can use @n ~ k + m@
  -> r
leToPlus cont = cont @(n - k)
{-# INLINE leToPlus #-}
-- | Like 'leToPlus', but for known naturals. The continuation also receives
-- a @KnownNat m@ dictionary for the difference @m = n - k@.
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n, KnownNat k, KnownNat n)
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -- ^ Continuation that can use @n ~ k + m@ and @KnownNat m@
  -> r
leToPlusKN cont = cont @(n - k)
{-# INLINE leToPlusKN #-}