{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- We need this otherwise GHC chokes on the export of
-- "type (*)"
#if MIN_VERSION_GLASGOW_HASKELL (8,6,0,0)
{-# LANGUAGE NoStarIsType #-}
#endif

-- | Type-level Nat
--
-- Re-exports the type-level naturals of "GHC.TypeNats" and adds a few
-- helpers: value reification ('natValue'), equality tests that reduce to
-- a 'Nat' ('NatEq', 'NatNotEq'), 'Max' and 'Min', zero tests and
-- 'NatBitCount'.
module Haskus.Utils.Types.Nat
  ( Nat
  , natValue
  , natValue'
  , KnownNat
  , SomeNat (..)
  , someNatVal
  , sameNat
    -- * Comparisons
  , CmpNat
  , type (<=?)
  , type (<=)
  , NatEq
  , NatNotEq
  , Max
  , Min
  , IsZero
  , IsNotZero
    -- * Operations
  , type (+)
  , type (-)
  , type (*)
  , type (^)
  , Mod
  , Log2
  , Div
    -- * Helpers
  , NatBitCount
  )
where

import GHC.TypeNats
import Haskus.Utils.Types.Bool
import Data.Proxy

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XTypeApplications
-- >>> :set -XFlexibleContexts
-- >>> :set -XTypeFamilies
-- >>> import Haskus.Utils.Types

-- | Get a Nat value
--
-- The natural @n@ is meant to be selected with a visible type application
-- (as 'natValue'' does with @natValue \@n@); it is the first quantified
-- variable, so it is the first type argument. The result is produced with
-- 'fromIntegral', so it can be any 'Num' type.
natValue :: forall (n :: Nat) a. (KnownNat n, Num a) => a
{-# INLINABLE natValue #-}
natValue = fromIntegral (natVal (Proxy :: Proxy n))

-- | Get a Nat value as a Word
--
-- This is 'natValue' with the result type fixed to 'Word'.
natValue' :: forall (n :: Nat). KnownNat n => Word
{-# INLINABLE natValue' #-}
natValue' = natValue @n

-- | Type equality to Nat
--
-- Reduces to 1 when both arguments are the same type and to 0 otherwise.
-- The equations are tried in order, so the second one only applies once
-- the arguments are known to differ.
type family NatEq a b :: Nat where
  NatEq a a = 1
  NatEq a b = 0

-- | Type inequality to Nat
--
-- Reduces to 0 when both arguments are the same type and to 1 otherwise.
type family NatNotEq a b :: Nat where
  NatNotEq a a = 0
  NatNotEq a b = 1

-- | Max of two naturals
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max a b = If (a <=? b) b a

-- | Min of two naturals
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min a b = If (a <=? b) a b

-- | Number of bits (>= 1) required to store a Nat value
--
-- >>> natValue' @(NatBitCount 0)
-- 1
--
-- >>> natValue' @(NatBitCount 1)
-- 1
--
-- >>> natValue' @(NatBitCount 2)
-- 2
--
-- >>> natValue' @(NatBitCount 5)
-- 3
--
-- >>> natValue' @(NatBitCount 15)
-- 4
--
-- >>> natValue' @(NatBitCount 16)
-- 5
--
type family NatBitCount (n :: Nat) :: Nat where
  -- Zero is special-cased so that the result is never smaller than 1.
  NatBitCount 0 = 1
  NatBitCount n = NatBitCount' (n + 1) (Log2 (n + 1))

-- | Helper for 'NatBitCount' (not exported)
--
-- Given @v@ and @log2@ (passed by 'NatBitCount' as @n + 1@ and
-- @Log2 (n + 1)@), reduces to @log2@ when @v@ is exactly @2 ^ log2@ and
-- to @log2 + 1@ otherwise.
type family NatBitCount' (v :: Nat) (log2 :: Nat) :: Nat where
  NatBitCount' v log2 = log2 + NatNotEq v (2 ^ log2)

-- | Return 1 if 0, and 0 otherwise
type family IsZero (n :: Nat) :: Nat where
  IsZero 0 = 1
  IsZero _ = 0

-- | Return 0 if 0, and 1 otherwise
type family IsNotZero (n :: Nat) :: Nat where
  IsNotZero 0 = 0
  IsNotZero _ = 1