module Data.Type.BitRecords.Arithmetic where
import Data.Type.Bool
import GHC.TypeLits
-- | Type-level remainder: @Rem x y@ is @x@ modulo @y@.
--
-- A zero divisor is reported as a custom type error and a divisor of one
-- short-circuits to @0@. Every other case is counted out one unit at a time
-- by 'RemImpl', so reduction depth grows linearly with @x@.
type family Rem (x :: Nat) (y :: Nat) :: Nat where
  Rem dividend 0 =
    TypeError ('Text "divide by zero: "
               ':<>: 'ShowType dividend
               ':<>: 'Text " `Rem` 0")
  Rem dividend 1 = 0
  Rem dividend divisor = RemImpl dividend divisor 0
-- | Worker for 'Rem': counts @x@ down to zero while stepping a counter
-- @acc@ that wraps back to @0@ whenever it reaches @y@.
--
-- * @x@ is exhausted: the counter is the remainder.
-- * the counter reached @y@: wrap it back to @0@.
-- * the remaining @x@ equals @y@: one more full lap of @y@ steps would bring
--   the counter back to its current value, so return it directly.
-- * otherwise: move one unit from @x@ to the counter.
--
-- @y@ is kinded 'Nat' because it is compared against the 'Nat' counter.
type family RemImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  RemImpl 0 y acc = acc
  RemImpl x y y = RemImpl x y 0
  RemImpl y y acc = acc
  RemImpl x y acc = RemImpl (x - 1) y (acc + 1)
-- | @RemPow2 x p@ is @x@ modulo @2 ^ p@. It converts @x@ to bits, keeps only
-- the @p@ least significant ones and converts back.
--
-- The conversion uses a fixed width of 'RemPow2Bits' bit positions, so @x@
-- is expected to be below @2 ^ RemPow2Bits@.
type RemPow2 (x :: Nat) (p :: Nat) =
  FromBits
    (TakeLastN p
      (ToBits x RemPow2Bits))
-- | The last @n@ elements of @xs@, in their original order; all of @xs@
-- when it has fewer than @n@ elements.
type TakeLastN (n :: Nat) xs =
  TakeLastNImplRev n xs '[]
-- | First phase of 'TakeLastN': reverse @xs@ onto the accumulator. Once the
-- input is exhausted, hand the reversed list (last element first) to
-- 'TakeLastNImplTakeNRev'.
type family TakeLastNImplRev (n :: Nat) (xs :: [t]) (acc :: [t]) :: [t] where
  TakeLastNImplRev count (y ': ys) reversed =
    TakeLastNImplRev count ys (y ': reversed)
  TakeLastNImplRev count '[] reversed =
    TakeLastNImplTakeNRev count reversed '[]
-- | Second phase of 'TakeLastN': move up to @n@ elements from the front of
-- @rs@ onto @acc@.
--
-- @rs@ holds the original list in reverse, so consing onto @acc@ restores
-- the original order. The result is the last @n@ elements of the original
-- list, or all of them if there are fewer.
type family TakeLastNImplTakeNRev (n :: Nat) (rs :: [t]) (acc :: [t]) :: [t] where
  TakeLastNImplTakeNRev n '[] acc = acc
  TakeLastNImplTakeNRev 0 rs acc = acc
  TakeLastNImplTakeNRev n (r ': rs) acc =
    TakeLastNImplTakeNRev (n - 1) rs (r ': acc)
-- | Number of bit positions 'RemPow2' uses when converting its argument
-- with 'ToBits'.
type RemPow2Bits = (32 :: Nat)
-- | Type-level floor division: @Div x y@ is @x@ divided by @y@, rounded down.
--
-- The remainder is subtracted first, so 'DivImpl' only has to count how many
-- times @y@ fits into an exact multiple of @y@. A zero divisor hits the
-- custom type error raised by 'Rem'.
type Div (x :: Nat) (y :: Nat) = DivImpl (x - (x `Rem` y)) y 0
-- | Worker for 'Div': repeatedly subtract @y@ from @x@, counting the
-- subtractions in @acc@. It stops when @x@ reaches zero or drops below @y@
-- (@x + 1 <= y@).
--
-- @y@ is kinded 'Nat' because it is compared with and subtracted from the
-- 'Nat' argument @x@.
type family DivImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  DivImpl 0 y acc = acc
  DivImpl x y acc = If (x + 1 <=? y) acc (DivImpl (x - y) y (acc + 1))
-- | 'True exactly when @2 ^ n <= x@.
--
-- 'ToBits_' applies it after every bit above position @n@ has been cleared
-- from @x@. At that point the test tells whether bit @n@ is set.
type family TestHighBit (x :: Nat) (n :: Nat) :: Bool where
  TestHighBit value bitIndex = 2 ^ bitIndex <=? value
-- | Binary digits of @x@, most significant first, examining bit positions
-- @n - 1@ down to @0@. Leading zeros are not emitted. @x@ is expected to be
-- below @2 ^ n@.
type ToBits (x :: Nat) (n :: Nat) =
  ToBits_ x n 'False
-- | Worker for 'ToBits': examines bit positions @n - 1@ down to @0@ of @x@.
--
-- @started@ records whether a set bit has been emitted yet. 'ToBitsInner'
-- uses it to skip leading zeros while keeping zeros that follow the first
-- set bit.
type family ToBits_ (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  ToBits_ x 0 started = '[]
  ToBits_ x n started =
    ToBitsInner (TestHighBit x (n - 1)) x (n - 1) started
-- | Emits the digit for bit position @n@ of @x@ and continues with the lower
-- positions via 'ToBits_'.
--
-- @highBitSet@ is the result of 'TestHighBit' for position @n@. @started@
-- tells whether a set bit has already been emitted.
type family
  ToBitsInner (highBitSet :: Bool) (x :: Nat) (n :: Nat) (started :: Bool) :: [Bool] where
  -- Bit set: emit it, clear it from x, and mark the output as started.
  ToBitsInner 'True x n started = 'True ': ToBits_ (x - (2 ^ n)) n 'True
  -- Leading zero: skip it.
  ToBitsInner 'False x n 'False = ToBits_ x n 'False
  -- Zero after the first set bit: keep it.
  ToBitsInner 'False x n 'True = 'False ': ToBits_ x n 'True
-- | The natural number denoted by a list of bits, most significant first.
type FromBits (bits :: [Bool]) =
  FromBits_ bits 0
-- | Worker for 'FromBits' using Horner's scheme. Each bit doubles the
-- accumulator, and a 'True bit also adds one. The accumulator is returned
-- once the list is empty.
type family FromBits_ (bits :: [Bool]) (acc :: Nat) :: Nat where
  FromBits_ ('True ': bs) n = FromBits_ bs (1 + n + n)
  FromBits_ ('False ': bs) n = FromBits_ bs (n + n)
  FromBits_ '[] n = n
-- | Logical right shift of a bit list (most significant bit first) by @n@
-- positions: the last @n@ elements are dropped.
--
-- The list gets shorter instead of being padded with leading 'False's. This
-- does not change the number 'FromBits' computes. A shift by more positions
-- than the list has yields the empty list.
type family ShiftBitsR (bits :: [Bool]) (n :: Nat) :: [Bool] where
  ShiftBitsR bits 0 = bits
  ShiftBitsR '[] n = '[]
  -- Single-position shift: drop the last element.
  ShiftBitsR '[e] 1 = '[]
  ShiftBitsR (e ': rest) 1 = e ': ShiftBitsR rest 1
  -- Multi-position shift: shift by one, then by the remaining n - 1.
  ShiftBitsR (e ': rest) n = ShiftBitsR (ShiftBitsR (e ': rest) 1) (n - 1)
-- | Index of the most significant set bit of @n@, searching downward from
-- bit position @highestBit@.
--
-- A custom type error is raised when @n >= 2 ^ (highestBit + 1)@, i.e. when
-- @n@ does not fit in @highestBit + 1@ bits.
--
-- When the search reaches position 0 the result is 1 rather than 0. So
-- @n = 0@ and @n = 1@ both yield 1, and a direct call with @highestBit = 0@
-- performs no range check. 'ShiftR' only uses the result to size its
-- 'ToBits' conversion, where one extra position is harmless.
-- NOTE(review): confirm that returning 1 for @n = 1@ is intended.
type family
  GetMostSignificantBitIndex (highestBit :: Nat) (n :: Nat) :: Nat where
  GetMostSignificantBitIndex 0 n = 1
  GetMostSignificantBitIndex highestBit n =
    If (2 ^ (highestBit + 1) <=? n)
       (TypeError ('Text "number to big: "
                   ':<>: 'ShowType n
                   ':<>: 'Text " >= "
                   ':<>: 'ShowType (2 ^ (highestBit + 1))))
       (If (2 ^ highestBit <=? n)
           highestBit
           (GetMostSignificantBitIndex (highestBit - 1) n))
-- | @ShiftR xMaxBits x bits@ is @x@ shifted right by @bits@ positions.
--
-- The steps are:
--
-- 1. 'GetMostSignificantBitIndex' finds the highest set bit of @x@,
--    searching down from position @xMaxBits@. It raises a custom type error
--    when @x >= 2 ^ (xMaxBits + 1)@ and @xMaxBits > 0@.
-- 2. 'ToBits' turns @x@ into a bit list wide enough to hold that bit.
-- 3. 'ShiftBitsR' drops the last @bits@ digits.
-- 4. 'FromBits' converts the remaining digits back to a number.
type family
  ShiftR (xMaxBits :: Nat) (x :: Nat) (bits :: Nat) :: Nat where
  ShiftR maxBit value shiftBy =
    FromBits
      (ShiftBitsR
        (ToBits value (1 + GetMostSignificantBitIndex maxBit value))
        shiftBy)