module CLaSH.Promoted.Nat
  ( SNat (..), snat, withSNat, snatToInteger, addSNat, subSNat, mulSNat, powSNat
  , UNat (..), toUNat, addUNat, multUNat, powUNat
  )
where
import Data.Proxy      (Proxy (..))
import Data.Reflection (reifyNat)
import GHC.TypeLits    (KnownNat, Nat, type (+), type (-), type (*), type (^),
                        natVal)
import Unsafe.Coerce   (unsafeCoerce)
-- | Singleton value for a type-level natural @n@.
--
-- The constructor captures the 'KnownNat' @n@ dictionary, so pattern matching
-- on 'SNat' brings it back into scope (e.g. to call 'natVal' on the proxy).
data SNat (n :: Nat) where
  SNat :: KnownNat n => Proxy n -> SNat n
-- | Renders the natural with a @d@ prefix: the 'SNat' for 4 shows as @d4@.
instance Show (SNat n) where
  show (SNat proxy) = "d" ++ show (natVal proxy)
-- | The 'SNat' for any @n@ whose 'KnownNat' instance is in scope.
snat :: KnownNat n => SNat n
snat = SNat proxy
  where
    proxy = Proxy
-- | Continuation-passing variant of 'snat': build the 'SNat' for @n@ and
-- pass it to the supplied function.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat k = k $ SNat Proxy
-- | Reflect the type-level natural carried by an 'SNat' to a term-level
-- 'Integer', using the 'KnownNat' dictionary stored in the constructor.
snatToInteger :: SNat n -> Integer
snatToInteger s = case s of
  SNat proxy -> natVal proxy
-- | Unary (Peano) representation of a type-level natural @n@: 'UZero' is
-- indexed by @0@, and each 'USucc' layer increments the index by one.
data UNat (n :: Nat) where
  UZero :: UNat 0
  USucc :: UNat n -> UNat (n + 1)
-- | Convert a singleton natural to its unary ('UNat') representation.
--
-- The structure is built by counting down from the reflected 'Integer', one
-- 'USucc' per step. The index of each intermediate value is only known at run
-- time, so every step is retyped with 'unsafeCoerce'. This is sound because
-- the finished value has exactly @n@ 'USucc' layers on top of 'UZero'.
toUNat :: SNat n -> UNat n
toUNat (SNat p) = fromI (natVal p)
  where
    fromI :: Integer -> UNat m
    fromI 0 = unsafeCoerce UZero
    fromI n
      -- A genuine 'KnownNat' is never negative; fail loudly on a bogus value
      -- instead of counting down forever.
      | n < 0     = error ("toUNat: negative natural " ++ show n)
      | otherwise = unsafeCoerce (USucc (fromI (n - 1)))
-- | Type-indexed addition of unary naturals.
--
-- Recurses on the first argument. When the first argument is non-zero and the
-- second is 'UZero', the first argument is returned unchanged.
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat lhs rhs = case lhs of
  UZero    -> rhs
  USucc l' -> case rhs of
    UZero -> lhs
    _     -> USucc (addUNat l' rhs)
-- | Type-indexed multiplication of unary naturals.
--
-- A 'UZero' on either side yields 'UZero'; otherwise @(1 + x) * y@ is
-- computed as @y + x * y@ via 'addUNat'.
multUNat :: UNat n -> UNat m -> UNat (n * m)
multUNat lhs rhs = case lhs of
  UZero    -> UZero
  USucc l' -> case rhs of
    UZero -> UZero
    _     -> addUNat rhs (multUNat l' rhs)
-- | Type-indexed exponentiation of unary naturals.
--
-- Recurses on the exponent: @n ^ 0@ is one, and @n ^ (1 + e)@ is computed as
-- @n * n ^ e@ via 'multUNat'.
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat base expo = case expo of
  UZero    -> USucc UZero
  USucc e' -> multUNat base (powUNat base e')
-- | Add two singleton naturals.
--
-- The sum is computed on 'Integer's and reflected back to the type level with
-- 'reifyNat'. 'unsafeCoerce' then relabels the fresh index as @a + b@, which
-- holds because the reflected value is exactly that sum.
addSNat :: SNat a -> SNat b -> SNat (a+b)
addSNat x y = reifyNat total (\p -> unsafeCoerce (SNat p))
  where
    total = snatToInteger x + snatToInteger y
-- | Subtract two singleton naturals.
--
-- The difference is computed on 'Integer's and reflected back to the type
-- level with 'reifyNat'. 'unsafeCoerce' then relabels the fresh index as
-- @a - b@. Type-level naturals cannot be negative, so subtracting a larger
-- 'SNat' from a smaller one is reported with 'error'. It is not reflected
-- into a bogus negative 'KnownNat' instance.
subSNat :: SNat a -> SNat b -> SNat (a-b)
subSNat x y
  | diff < 0  = error ("subSNat: negative result " ++ show diff)
  | otherwise = reifyNat diff (unsafeCoerce . SNat)
  where
    diff = snatToInteger x - snatToInteger y
-- | Multiply two singleton naturals.
--
-- The product is computed on 'Integer's and reflected back with 'reifyNat'.
-- 'unsafeCoerce' relabels the fresh index as @a * b@.
mulSNat :: SNat a -> SNat b -> SNat (a*b)
mulSNat x y = reifyNat prod (\p -> unsafeCoerce (SNat p))
  where
    prod = snatToInteger x * snatToInteger y
-- | Raise one singleton natural to the power of another.
--
-- The power is computed on 'Integer's and reflected back with 'reifyNat'.
-- 'unsafeCoerce' relabels the fresh index as @a ^ b@.
powSNat :: SNat a -> SNat b -> SNat (a^b)
powSNat x y = reifyNat result (\p -> unsafeCoerce (SNat p))
  where
    b      = snatToInteger x
    e      = snatToInteger y
    result = b ^ e