module Data.ByteString.IsoBaseFileFormat.Util.BitRecords where
import Data.Kind
import Data.Word
import Data.Type.Bool
import GHC.TypeLits
import Data.Bits
import Data.Proxy
import Test.TypeSpecCrazy
-- | A field that is @n@ bits wide.
data Field :: Nat -> Type

-- | A field that is exactly one bit wide.
type Flag = Field 1

-- | Attach a label to a field, e.g. @"foo" :=> Flag@.
data (:=>) :: k -> Type -> Type
infixr 6 :=>

-- | Compose two fields one after the other. In the position computations of
-- this module the left operand occupies the lower bit positions.
data (:*:) :: Type -> Type -> Type
infixl 5 :*:

-- | A label path: @label :/ rest@ addresses @rest@ inside the field that is
-- labelled @label@.
data (:/) :: Symbol -> k -> Type
infixr 7 :/

-- | An inclusive bit range @'(first, last)@.
type FieldPosition = (Nat, Nat)
-- | Total width in bits of a field description. Labels add no width, and a
-- composition is as wide as both of its operands together.
type family
  GetFieldSize (f :: l) :: Nat where
  GetFieldSize (Field bits) = bits
  GetFieldSize (lhs :*: rhs) = GetFieldSize lhs + GetFieldSize rhs
  GetFieldSize (lbl :=> inner) = GetFieldSize inner
-- | Type-level test of whether a label, or a label path built with ':/',
-- occurs in a field description. Equations are tried top to bottom, so the
-- final catch-all only applies when no structural case matches.
type family
  HasField (f :: fk) (l :: lk) :: Bool where
  -- the label is attached directly to this field
  HasField (lbl :=> inner) lbl = 'True
  -- follow the path into the sub-field carrying the first label
  HasField (lbl :=> inner) (lbl :/ rest) = HasField inner rest
  -- search both operands of a composition
  HasField (lhs :*: rhs) lbl = HasField lhs lbl || HasField rhs lbl
  HasField other lbl = 'False
-- | Constraint that holds when @label@ occurs in @field@ (per 'HasField').
-- When 'HasField' reduces to 'False, it becomes a custom type error naming the
-- missing label and the field description.
type family
  HasFieldConstraint (label :: lk) (field :: fk) :: Constraint where
  HasFieldConstraint lbl fld =
    If (HasField fld lbl)
       (HasField fld lbl ~ 'True)
       (TypeError ('Text "Label not found: '"
                   ':<>: 'ShowType lbl
                   ':<>: 'Text "' in:"
                   ':$$: 'ShowType fld))
-- | Select the sub-field addressed by a label (path).
--
-- The result is @'Right@ of the sub-field when the label exists, and @'Left@
-- of an error message otherwise.
type family
  FocusOn (l :: lk) (f :: fk) :: Result fk where
  FocusOn label field =
    If (HasField field label)
       ('Right (FocusOnUnsafe label field))
       ('Left ('Text "Label not found. Cannot focus '"
               ':<>: 'ShowType label
               ':<>: 'Text "' in:"
               ':$$: 'ShowType field))
-- | Sub-field addressed by a label (path), without the existence check that
-- 'FocusOn' performs. If the label is absent, the application does not fully
-- reduce.
type family
  FocusOnUnsafe (l :: lk) (f :: fk) :: fk where
  FocusOnUnsafe lbl (lbl :=> inner) = inner
  FocusOnUnsafe (lbl :/ rest) (lbl :=> inner) = FocusOnUnsafe rest inner
  -- descend into whichever operand contains the label
  FocusOnUnsafe lbl (lhs :*: rhs) =
    FocusOnUnsafe lbl (If (HasField lhs lbl) lhs rhs)
-- | Inclusive bit range of the field addressed by label (path) @l@ in @f@.
--
-- The result is @'Right '(first, last)@ when the label exists, and @'Left@ of
-- an error message otherwise.
type family
  GetFieldPosition (f :: field) (l :: label) :: Result FieldPosition where
  GetFieldPosition record lbl =
    If (HasField record lbl)
       ('Right (GetFieldPositionUnsafe record lbl))
       ('Left ('Text "Label not found. Cannot get bit range for '"
               ':<>: 'ShowType lbl
               ':<>: 'Text "' in:"
               ':$$: 'ShowType record))
-- | Inclusive bit range @'(first, last)@ of the field addressed by label
-- (path) @l@, counted from position 0 at the start of @f@.
--
-- This does not check that the label exists; 'GetFieldPosition' is the
-- checked variant.
type family
  GetFieldPositionUnsafe (f :: field) (l :: label) :: FieldPosition where
  -- A labelled field spans its whole width. The last index is inclusive,
  -- hence the "- 1".
  GetFieldPositionUnsafe (l :=> f) l = '(0, GetFieldSize f - 1)
  GetFieldPositionUnsafe (l :=> f) (l :/ p) = GetFieldPositionUnsafe f p
  -- A field found in the right operand is shifted by the width of the left
  -- operand.
  GetFieldPositionUnsafe (f :*: f') l =
    If (HasField f l)
       (GetFieldPositionUnsafe f l)
       (AddToFieldPosition (GetFieldSize f) (GetFieldPositionUnsafe f' l))
-- | Shift both ends of a bit range by the same offset.
type family
  AddToFieldPosition (v :: Nat) (e :: (Nat, Nat)) :: (Nat, Nat) where
  AddToFieldPosition offset '(lo, hi) = '(lo + offset, hi + offset)
-- | Constraint requiring a well-formed bit range: the first index must not
-- exceed the last, and both indices must be 'KnownNat'. A reversed range
-- produces a custom type error.
type family
  IsFieldPostition (pos :: FieldPosition) :: Constraint where
  IsFieldPostition '(lo, hi) =
    If (lo <=? hi)
       (lo <= hi, KnownNat lo, KnownNat hi)
       (TypeError
          ('Text "Bad field position: " ':<>: 'ShowType '(lo, hi)
           ':$$: 'Text "First index greater than last: "
           ':<>: 'ShowType lo
           ':<>: 'Text " > "
           ':<>: 'ShowType hi))
-- | Every index of an inclusive bit range, in ascending order.
--
-- NOTE(review): for a range whose first index exceeds the last, the
-- terminating equation is never reached. Check it with 'IsFieldPostition'
-- first.
type family
  FieldPostitionToList (pos :: FieldPosition) :: [Nat] where
  FieldPostitionToList '(i, i) = '[i]
  FieldPostitionToList '(i, j) = i ': FieldPostitionToList '(i + 1, j)
-- | Pad a field description with trailing unlabelled bits so that its total
-- width becomes a multiple of @a@. An alignment of 0 is rejected with a
-- @'Left@ error message.
type family
  AlignField (a :: Nat) (f :: field) :: Result field where
  AlignField 0 f = 'Left ('Text "Invalid alignment of 0")
  -- Padding is (a - size mod a) mod a. The outer 'Rem' turns a full block of
  -- padding (when the size is already aligned) back into 0.
  AlignField a f =
    'Right (AddPadding ((a - (GetFieldSize f `Rem` a)) `Rem` a) f)
-- | Append @n@ unlabelled bits to a field description. Zero padding leaves the
-- description unchanged.
type family
  AddPadding (n :: Nat) (f :: field) :: field where
  AddPadding 0 fld = fld
  AddPadding bits fld = fld :*: Field bits
-- | Type-level remainder @x mod y@. It counts @x@ down to zero while counting
-- an accumulator up modulo @y@.
--
-- NOTE(review): this takes O(x) reduction steps, so a very large @x@ may hit
-- GHC's reduction-depth limit.
type Rem (x :: Nat) (y :: Nat) = RemImpl x y 0

-- | Worker for 'Rem'. @acc@ is the remainder of the amount counted so far.
type family
  RemImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  RemImpl 0 y acc = acc
  -- Without this equation, the wrap-around case below would loop forever
  -- when y is 0.
  RemImpl x 0 acc = TypeError ('Text "Rem: division by zero")
  -- The accumulator reached y: wrap it back to 0.
  RemImpl x y y = RemImpl x y 0
  -- Exactly y is left to count, which contributes nothing modulo y.
  RemImpl y y acc = acc
  RemImpl x y acc = RemImpl (x - 1) y (acc + 1)
-- | Read a one-bit field from a bit record.
--
-- 'IsFieldC' is instantiated with @first@ as both ends of the range, so the
-- addressed field must occupy exactly one bit position; that position is
-- tested with 'testBit'.
getFlag
  :: forall a (path :: k) (first :: Nat) field p1 p2
   . ( IsFieldC path field first first
     , Bits a )
  => p1 path -> p2 field -> a -> Bool
getFlag _ _ = (`testBit` bitIndex)
  where
    bitIndex = fromIntegral (natVal (Proxy :: Proxy first))
-- | Set (for 'True') or clear (for 'False') a one-bit field in a bit record.
-- All other bits are left unchanged.
setFlag
  :: forall a (path :: k) (first :: Nat) field p1 p2
   . ( IsFieldC path field first first
     , Bits a )
  => p1 path -> p2 field -> Bool -> a -> a
setFlag _ _ enable word
  | enable = setBit word bitIndex
  | otherwise = clearBit word bitIndex
  where
    bitIndex = fromIntegral (natVal (Proxy :: Proxy first))
-- | Extract the value of a labelled field from a bit record.
--
-- The inclusive bit range @[first, last]@ is resolved at compile time through
-- 'IsFieldC'. The record is shifted right by @first@, masked to
-- @last - first + 1@ bits, and converted to the result type.
getField
  :: forall a b (path :: k) (first :: Nat) (last :: Nat) field pxy1 pxy2
   . ( IsFieldC path field first last
     , Integral a
     , Bits a
     , Num b)
  => pxy1 path -> pxy2 field -> a -> b
getField _ _ a = fromIntegral ((a `shiftR` posFirst) .&. bitMask)
  where
    -- the lowest (last - first + 1) bits set; the range is inclusive
    bitMask =
      let bitCount = 1 + posLast - posFirst
      in (2 ^ bitCount) - 1
    posFirst = fromIntegral $ natVal (Proxy :: Proxy first)
    posLast = fromIntegral $ natVal (Proxy :: Proxy last)
-- | Overwrite a labelled field in a bit record, leaving every other bit
-- unchanged.
--
-- The new value is truncated to the field's width
-- (@last - first + 1@ bits) before it is shifted into place at @first@.
setField
  :: forall a b (path :: k) (first :: Nat) (last :: Nat) field pxy1 pxy2
   . ( IsFieldC path field first last
     , Num a
     , Bits a
     , Integral b)
  => pxy1 path -> pxy2 field -> b -> a -> a
setField _ _ v x = (x .&. bitMaskField) .|. (v' `shiftL` posFirst)
  where
    -- new value, truncated to the field width
    v' = bitMaskValue .&. fromIntegral v
    -- every bit set except those belonging to the field
    bitMaskField = complement (bitMaskValue `shiftL` posFirst)
    -- the lowest (last - first + 1) bits set; the range is inclusive
    bitMaskValue =
      let bitCount = 1 + posLast - posFirst
      in (2 ^ bitCount) - 1
    posFirst = fromIntegral $ natVal (Proxy :: Proxy first)
    posLast = fromIntegral $ natVal (Proxy :: Proxy last)
-- | Example record layout, 28 bits in total:
--
-- * a 1-bit flag @"foo"@
-- * 4 unlabelled bits
-- * a 2-bit field @"bar"@
-- * 4 unlabelled bits
-- * a 17-bit field @"baz"@
--
-- The parentheses match the fixities of ':=>' (6) and ':*:' (5).
type Foo =
  "foo" :=> Flag
    :*: Field 4
    :*: ("bar" :=> Field 2)
    :*: Field 4
    :*: ("baz" :=> Field 17)
-- | Constraint bundle used by the accessors. It requires the following:
--
-- * label (path) @name@ exists in @field@;
-- * its inclusive bit range is @'(first, last)@;
-- * both bounds are known at compile time.
type IsFieldC name field first last =
  ( HasFieldConstraint name field
  , GetFieldPosition field name ~ 'Right '(first, last)
  , KnownNat first
  , KnownNat last
  )
-- | Read the field addressed by @name@ from a 'Word64' laid out as 'Foo'.
getFooField
  :: IsFieldC name Foo first last
  => proxy name -> Word64 -> Word64
getFooField label word = getField label fooRecord word
  where
    fooRecord = Proxy :: Proxy Foo