{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Type-level arithmetic on 'Nat's, implemented by converting numbers to
-- type-level lists of 'Bool' bits (most significant bit first) and back.
module Data.Type.BitRecords.Arithmetic where

import Data.Type.Bool
import GHC.TypeLits

-- | @ModPow2 value power@ computes @value@ modulo @2^power@ by keeping only the
-- lowest @power@ bits of @value@.
--
-- @value@ must be representable in 'ModPow2Bits' bits, i.e. it must be less
-- than @2 ^ 'ModPow2Bits'@; larger values yield an incorrect result rather
-- than a type error.
type ModPow2 value power =
  FromBits (TakeLastN power (ToBits value ModPow2Bits))

-- | Keep the last @n@ elements of the type-level list @xs@ (all of @xs@ when it
-- has fewer than @n@ elements).
type TakeLastN n xs = TakeLastNImplRev n xs '[]

-- | First phase of 'TakeLastN': reverse @xs@ onto @acc@, then hand the
-- reversed list to 'TakeLastNImplTakeNRev'.
type family TakeLastNImplRev (n :: Nat) (xs :: [t]) (acc :: [t]) :: [t] where
  TakeLastNImplRev n '[] acc = TakeLastNImplTakeNRev n acc '[]
  TakeLastNImplRev n (x ': xs) acc = TakeLastNImplRev n xs (x ': acc)

-- | Second phase of 'TakeLastN': move up to @n@ elements from the front of the
-- reversed list @rs@ onto @acc@, which restores their original order.
type family TakeLastNImplTakeNRev (n :: Nat) (rs :: [t]) (acc :: [t]) :: [t] where
  TakeLastNImplTakeNRev n '[] acc = acc
  TakeLastNImplTakeNRev 0 rs acc = acc
  TakeLastNImplTakeNRev n (r ': rs) acc =
    TakeLastNImplTakeNRev (n - 1) rs (r ': acc)

-- | Maximum number of bits the @value@ argument of 'ModPow2' may occupy.
type ModPow2Bits = 32

-- * Bit manipulation

-- | @TestHighBit x n@ is @'True@ exactly when @x >= 2^n@. When
-- @x < 2^(n+1)@ (as 'ToBits' assumes) this is the same as testing bit @n@.
type family TestHighBit (x :: Nat) (n :: Nat) :: Bool where
  TestHighBit x n = (2 ^ n) <=? x

-- | @ToBits x n@ lists bits @n-1@ down to @0@ of @x@, most significant first,
-- with leading 'False' bits dropped (so @ToBits 0 n@ is @'[]@).
-- @x@ should be less than @2^n@; otherwise the result is incorrect.
type ToBits x n = ToBits_ x n 'False

-- | Worker for 'ToBits'. @started@ records whether a 'True' bit has already
-- been emitted, i.e. whether subsequent 'False' bits are significant.
type family ToBits_ (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  ToBits_ x 0 started = '[]
  ToBits_ x n started =
    ToBitsInner (TestHighBit x (n - 1)) x (n - 1) started

-- | Emit bit @n@ of @x@ (precomputed as @highBitSet@) unless it is a leading
-- zero, then continue with the lower bits.
type family ToBitsInner (highBitSet :: Bool) (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  -- Bit n is set: emit it and clear it from x before looking at lower bits.
  ToBitsInner 'True x n started = 'True ': ToBits_ (x - 2 ^ n) n 'True
  -- Leading zero: drop it.
  ToBitsInner 'False x n 'False = ToBits_ x n 'False
  -- Zero after the first set bit: significant, keep it.
  ToBitsInner 'False x n 'True = 'False ': ToBits_ x n 'True

-- | Interpret a most-significant-bit-first list of bits as a 'Nat'.
type FromBits bits = FromBits_ bits 0

-- | Worker for 'FromBits': for each bit, double the accumulator and add the
-- bit's value.
type family FromBits_ (bits :: [Bool]) (acc :: Nat) :: Nat where
  FromBits_ '[] acc = acc
  FromBits_ ('False ': rest) acc = FromBits_ rest (acc + acc)
  FromBits_ ('True ': rest) acc = FromBits_ rest (1 + acc + acc)

-- | Drop the last (least significant) @n@ elements of a
-- most-significant-bit-first list; dropping more bits than the list holds
-- yields @'[]@.
type family ShiftBitsR (bits :: [Bool]) (n :: Nat) :: [Bool] where
  ShiftBitsR bits 0 = bits
  ShiftBitsR '[] n = '[]
  ShiftBitsR '[e] 1 = '[]
  ShiftBitsR (e ': rest) 1 = e ': ShiftBitsR rest 1
  -- Shift by a single bit, then by the remaining n - 1 bits.
  ShiftBitsR (e ': rest) n = ShiftBitsR (ShiftBitsR (e ': rest) 1) (n - 1)

-- | @GetMostSignificantBitIndex highestBit n@ is the index of the highest set
-- bit of @n@, where @highestBit@ is the highest bit index @n@ is allowed to
-- use. It is a type error if @n >= 2^(highestBit + 1)@; this check is also
-- performed when @highestBit@ is 0.
--
-- NOTE(review): for @n@ equal to 0 or 1 the result is 1 rather than 0. This is
-- kept for compatibility; 'ShiftR' is unaffected because 'ToBits' drops
-- leading zero bits.
type family GetMostSignificantBitIndex (highestBit :: Nat) (n :: Nat) :: Nat where
  GetMostSignificantBitIndex highestBit n =
    If (2 ^ (highestBit + 1) <=? n)
       (TypeError ('Text "number too big: "
                   ':<>: 'ShowType n
                   ':<>: 'Text " >= "
                   ':<>: 'ShowType (2 ^ (highestBit + 1))))
       (GetMostSignificantBitIndexInRange highestBit n)

-- | Worker for 'GetMostSignificantBitIndex', used once @n < 2^(highestBit + 1)@
-- is established: scans downwards from @highestBit@ for the first set bit.
-- Each recursive step only happens when @n < 2^highestBit@, so the range
-- invariant is preserved and need not be re-checked.
type family GetMostSignificantBitIndexInRange (highestBit :: Nat) (n :: Nat) :: Nat where
  GetMostSignificantBitIndexInRange 0 n = 1
  GetMostSignificantBitIndexInRange highestBit n =
    If (2 ^ highestBit <=? n)
       highestBit
       (GetMostSignificantBitIndexInRange (highestBit - 1) n)

-- | @ShiftR xMaxBits x n@ shifts the type-level natural @x@ right by @n@ bits,
-- i.e. computes @x@ divided by @2^n@, rounded down. This is useful for
-- division by powers of two.
--
-- @xMaxBits@ is the highest bit index @x@ may use; an @x@ of
-- @2^(xMaxBits + 1)@ or more is a type error (see
-- 'GetMostSignificantBitIndex').
type family ShiftR (xMaxBits :: Nat) (x :: Nat) (bits :: Nat) :: Nat where
  ShiftR xMaxBits x n =
    FromBits (ShiftBitsR (ToBits x (1 + GetMostSignificantBitIndex xMaxBits x)) n)