{-# LANGUAGE UndecidableInstances #-}
module Data.Type.BitRecords.Arithmetic where

import Data.Type.Bool
import GHC.TypeLits

-- | Get the remainder of the integer division of @x@ by @y@, such that
-- @forall x y. exists k. (Rem x y) == x - y * k@.
--
-- A divisor of @1@ yields @0@ directly, a divisor of @0@ is rejected with a
-- custom 'TypeError', and every other divisor is delegated to 'RemImpl' with
-- an accumulator starting at @0@.
--
-- The algorithm is: count @x@ down one step at a time, incrementing the
-- accumulator at each step. Whenever the accumulator is equal to @y@ it is
-- set back to zero. Counting stops as soon as @x@ is zero or @x@ is equal to
-- @y@; the accumulator is then the remainder.
--
-- It is important to reset the accumulator BEFORE checking if @x == y@ and
-- returning the accumulator, for the case where @x = k * y@ with @k > 0@.
-- For example:
--
-- @
--  6 `Rem` 3     = RemImpl 6 3 0
--  RemImpl 6 3 0 = RemImpl (6-1) 3 (0+1)   -- RemImpl Clause 4
--  RemImpl 5 3 1 = RemImpl (5-1) 3 (1+1)   -- RemImpl Clause 4
--  RemImpl 4 3 2 = RemImpl (4-1) 3 (2+1)   -- RemImpl Clause 4
--  RemImpl 3 3 3 = RemImpl 3 3 0           -- RemImpl Clause 2 !!!
--  RemImpl 3 3 0 = 0                       -- RemImpl Clause 3 !!!
-- @
type family Rem (x :: Nat) (y :: Nat) :: Nat where
  Rem x 1 = 0
  Rem x 0 = TypeError ('Text "divide by zero: "
                       ':<>: 'ShowType x
                       ':<>: 'Text " `Rem` 0")
  Rem x y = RemImpl x y 0

-- | Worker for 'Rem': @x@ is the value still to be counted down, @y@ the
-- divisor and @acc@ the running remainder.
--
-- The equations are tried from top to bottom, so their order matters: the
-- accumulator reset (clause 2) must come before the @x == y@ check
-- (clause 3). Each application of clause 4 lowers @x@ by one, so the number
-- of reduction steps grows with @x@.
type family RemImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  -- Clause 1: finished if x was < y, the accumulator now holds the original x.
  RemImpl 0 y acc = acc
  -- Clause 2: the accumulator has reached y, wrap it around to zero.
  RemImpl x y y   = RemImpl x y 0
  -- Clause 3: finished if x was >= y, x has been counted down to y.
  RemImpl y y acc = acc
  -- Clause 4: the base case, one counting step.
  RemImpl x y acc = RemImpl (x - 1) y (acc + 1)

-- | Efficient 'Rem' operation for power of 2 values. Note that x must be
-- representable by 'RemPow2Bits' bits.
type RemPow2 (x :: Nat) (p :: Nat) =
  FromBits (TakeLastN p (ToBits x RemPow2Bits))

-- | Keep only the last @n@ elements of @xs@, in their original order. If
-- @xs@ has fewer than @n@ elements, all of them are returned.
type TakeLastN (n :: Nat) (xs :: [t]) = TakeLastNImplRev n xs '[]

-- | First phase of 'TakeLastN': reverse @xs@ into @acc@, then hand the
-- reversed list to 'TakeLastNImplTakeNRev'.
type family TakeLastNImplRev (n :: Nat) (xs :: [t]) (acc :: [t]) :: [t] where
  TakeLastNImplRev n '[]       acc = TakeLastNImplTakeNRev n acc '[]
  TakeLastNImplRev n (x ': xs) acc = TakeLastNImplRev n xs (x ': acc)

-- | Second phase of 'TakeLastN': move up to @n@ elements from the front of
-- the reversed list @rs@ onto @acc@. Consing onto @acc@ reverses them again,
-- which restores the original order.
type family TakeLastNImplTakeNRev (n :: Nat) (rs :: [t]) (acc :: [t]) :: [t] where
  -- Ran out of elements before n was reached.
  TakeLastNImplTakeNRev n '[]       acc = acc
  -- Took the requested number of elements.
  TakeLastNImplTakeNRev 0 rs        acc = acc
  TakeLastNImplTakeNRev n (r ': rs) acc =
    TakeLastNImplTakeNRev (n - 1) rs (r ': acc)

-- | Maximum number of bits an argument @x@ of 'RemPow2' may occupy.
type RemPow2Bits = 32

-- | Integer division of x and y: @Div x y ==> x / y@,
-- NOTE This only works for small numbers currently
--
-- The remainder (@x `Rem` y@) is subtracted first, so 'DivImpl' only ever
-- sees a multiple of @y@.
--
-- NOTE(review): newer versions of @base@ appear to export their own @Div@
-- from "GHC.TypeLits"; check for ambiguity when this name is used together
-- with an unqualified import of that module.
type Div (x :: Nat) (y :: Nat) = DivImpl (x - (x `Rem` y)) y 0

-- | Worker for 'Div': repeatedly subtract @y@ from @x@ and count the number
-- of subtractions in @acc@.
type family DivImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  -- Nothing left to subtract from.
  DivImpl 0 y acc = acc
  -- @x + 1 <=? y@ is @x < y@: stop instead of subtracting below zero.
  DivImpl x y acc = If (x + 1 <=? y) acc (DivImpl (x - y) y (acc + 1))

-- * Bit manipulation

-- | 'True if @x@ is at least @2^n@. For an @x@ that is smaller than
-- @2^(n+1)@ this is the same as testing whether bit @n@ of @x@ is set.
type family TestHighBit (x :: Nat) (n :: Nat) :: Bool where
  TestHighBit x n = 2 ^ n <=? x -- x >= 2^n

-- | Convert @x@ to a list of bits, most significant bit first, looking at
-- the bit positions @n - 1@ down to @0@. Leading zero bits are omitted, so
-- @0@ is converted to the empty list.
--
-- @x@ is expected to be smaller than @2^n@; larger values are not rejected.
type ToBits (x :: Nat) (n :: Nat) = ToBits_ x n 'False

-- | Worker for 'ToBits'. @started@ records whether a 'True bit has already
-- been emitted, i.e. whether 'False bits must be emitted from now on.
type family ToBits_ (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  ToBits_ x 0 started = '[]
  ToBits_ x n started =
    ToBitsInner (TestHighBit x (n - 1)) x (n - 1) started

-- | Emit the bit at position @n@ according to @highBitSet@ and continue with
-- the lower positions.
type family ToBitsInner (highBitSet :: Bool) (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  -- Bit is set: emit it and remove its value from x.
  ToBitsInner 'True  x n started = 'True ': ToBits_ (x - 2 ^ n) n 'True
  -- Bit is clear and nothing was emitted yet: skip the leading zero.
  ToBitsInner 'False x n 'False  = ToBits_ x n 'False
  -- Bit is clear after the first set bit: emit it.
  ToBitsInner 'False x n 'True   = 'False ': ToBits_ x n 'True

-- | Convert a list of bits, most significant bit first, to a natural number.
type FromBits (bits :: [Bool]) = FromBits_ bits 0

-- | Worker for 'FromBits': for each bit, double the accumulator and add one
-- if the bit is set.
type family FromBits_ (bits :: [Bool]) (acc :: Nat) :: Nat where
  FromBits_ '[]              acc = acc
  FromBits_ ('False ': rest) acc = FromBits_ rest (acc + acc)
  FromBits_ ('True ': rest)  acc = FromBits_ rest (1 + acc + acc)

-- | Shift a list of bits (most significant bit first) to the right by @n@
-- positions, i.e. drop the last @n@ elements. Shifting by more positions
-- than there are bits yields the empty list.
type family ShiftBitsR (bits :: [Bool]) (n :: Nat) :: [Bool] where
  ShiftBitsR bits        0 = bits
  ShiftBitsR '[]         n = '[]
  -- A shift by one drops the last element ...
  ShiftBitsR '[e]        1 = '[]
  ShiftBitsR (e ': rest) 1 = e ': ShiftBitsR rest 1
  -- ... and a shift by n is a shift by one followed by a shift by n - 1.
  ShiftBitsR (e ': rest) n = ShiftBitsR (ShiftBitsR (e ': rest) 1) (n - 1)

-- | Search downwards from bit index @highestBit@ for the highest index @i@
-- with @2^i <= n@.
--
-- A 'TypeError' is produced if @n@ is @2^(highestBit + 1)@ or greater. When
-- the search reaches index @0@ the result is @1@ without looking at @n@, so
-- for @n < 2@ the result is an upper bound rather than the exact index.
type family GetMostSignificantBitIndex (highestBit :: Nat) (n :: Nat) :: Nat where
  GetMostSignificantBitIndex 0 n = 1
  GetMostSignificantBitIndex highestBit n =
    If (2 ^ (highestBit + 1) <=? n)
       (TypeError ('Text "number to big: "
                   ':<>: 'ShowType n
                   ':<>: 'Text " >= "
                   ':<>: 'ShowType (2 ^ (highestBit + 1))))
       (If (2 ^ highestBit <=? n)
           highestBit
           (GetMostSignificantBitIndex (highestBit - 1) n))

-- | Shift a type level natural to the right. This is useful for division by
-- powers of two.
--
-- @xMaxBits@ is the highest bit index that @x@ may use (it is passed to
-- 'GetMostSignificantBitIndex'), @x@ is the number to shift and @bits@ is the
-- number of bit positions to shift by.
type family ShiftR (xMaxBits :: Nat) (x :: Nat) (bits :: Nat) :: Nat where
  ShiftR xMaxBits x n =
    FromBits
      (ShiftBitsR (ToBits x (1 + GetMostSignificantBitIndex xMaxBits x)) n)