module Mezzo.Model.Prim
(
Vector (..)
, Times (..)
, Elem (..)
, type (**)
, OptVector (..)
, Head
, Head'
, Last
, Tail'
, Init'
, Length
, Length'
, Matrix
, type (++)
, type (++.)
, type (:-|)
, type (+*+)
, type (+|+)
, type (+-+)
, Align
, VectorToColMatrix
, If
, Not
, type (.&&.)
, type (.||.)
, type (.~.)
, MaxN
, MinN
, Valid
, Invalid
, AllSatisfy
, AllPairsSatisfy
, AllPairsSatisfy'
, SatisfiesAll
, AllSatisfyAll
) where
import Data.Kind
import GHC.TypeLits
-- Fixities for the constructors and type-level operators defined below.
-- Entry construction (':*', '**') binds tightest, then ':-', then ':--', so
-- that @v ** n :- End :-- None@ groups as @((v ** n) :- End) :-- None@.
infixr 7 :*, **
infixr 6 :-
infixr 5 :--
-- Concatenation of vectors and matrices.
infixl 4 ++, +*+, +|+, +-+
-- Type-level equality test and boolean connectives.
infixl 5 .~.
infixl 3 .&&., .||.
-- | A length-indexed vector: @Vector t n@ holds exactly @n@ elements of type @t@.
--
-- A cons cell of length @n@ carries a tail indexed by @n - 1@.
data Vector :: Type -> Nat -> Type where
    -- | The empty vector.
    None :: Vector t 0
    -- | Prepend an element to a vector that is one element shorter.
    (:--) :: t -> Vector t (n - 1) -> Vector t n
-- | A proxy carrying a type-level natural @n@; its only value is 'T'.
data Times (n :: Nat) = T
-- | A value of type @t@ paired with a type-level count @n@. Inside an
-- 'OptVector', an entry of count @n@ contributes @n@ to the length index.
data Elem (t :: Type) (n :: Nat) where
    (:*) :: e -> Times k -> Elem e k
-- | Build an 'Elem' entry: @v ** d@ tags the value @v@ with the count @d@.
type family (v :: t) ** (d :: Nat) :: Elem t d where
    x ** k = x ':* ('T :: Times k)
-- | A run-length style vector: each entry is an 'Elem' carrying a count, and
-- the length index @n@ is the sum of all entry counts.
data OptVector :: Type -> Nat -> Type where
    -- | The empty vector.
    End :: OptVector t 0
    -- | Prepend an entry of count @l@; the remaining entries account for the
    -- other @n - l@ slots.
    (:-) :: Elem t l -> OptVector t (n - l) -> OptVector t n
-- | The value stored in the first entry of an 'OptVector'. This is a custom
-- type error when the vector is empty.
type family Head (v :: OptVector t n) :: t where
    Head 'End              = TypeError ('Text "Vector has no head element.")
    Head ((x ':* _) ':- _) = x
-- | The first element of a 'Vector'. This is a custom type error when the
-- vector is empty.
type family Head' (v :: Vector t n) :: t where
    Head' 'None      = TypeError ('Text "Vector has no head element.")
    Head' (x ':-- _) = x
-- | The value stored in the final entry of an 'OptVector', found by walking
-- the entries. This is a custom type error when the vector is empty.
type family Last (v :: OptVector t n) :: t where
    Last 'End                 = TypeError ('Text "Vector has no last element.")
    Last ((x ':* _) ':- 'End) = x
    Last (_ ':- rest)         = Last rest
-- | Drop the first element of a 'Vector', leaving a vector indexed by @n - 1@.
-- This is a custom type error when the vector is empty.
type family Tail' (v :: Vector t n) :: Vector t (n - 1) where
    Tail' None = TypeError (Text "Vector has no tail.")
    Tail' (_ :-- vs) = vs
-- | Drop the last element of a 'Vector', leaving a vector indexed by @n - 1@.
-- This is a custom type error when the vector is empty.
type family Init' (v :: Vector t n) :: Vector t (n - 1) where
    Init' None = TypeError (Text "Vector is empty.")
    -- A single element: dropping it leaves the empty vector.
    Init' (p :-- None) = None
    -- Keep the head and drop the last element of the tail.
    Init' (p :-- ps) = p :-- Init' ps
-- | The length index of an 'OptVector', read directly off its kind.
type family Length (v :: OptVector t n) :: Nat where
    Length (vec :: OptVector e len) = len
-- | The length index of a 'Vector', read directly off its kind.
type family Length' (v :: Vector t n) :: Nat where
    Length' (vec :: Vector e len) = len
-- | Concatenate two 'OptVector's; the result is indexed by the sum of the
-- two lengths.
--
-- The @_ ++ End@ equation is listed first. Appending 'End' therefore reduces
-- even when the left operand is an unknown type variable.
type family (x :: OptVector t n) ++ (y :: OptVector t m) :: OptVector t (n + m) where
    front ++ 'End        = front
    'End ++ back         = back
    (e ':- rest) ++ back = e ':- (rest ++ back)
-- | Concatenate two 'Vector's by recursing on the first operand.
type family (x :: Vector t n) ++. (y :: Vector t m) :: Vector t (n + m) where
    'None ++. back          = back
    (e ':-- rest) ++. back  = e ':-- (rest ++. back)
-- | Append a single element at the end of a 'Vector' (snoc).
type family (v :: Vector t n) :-| (e :: t) :: Vector t (n + 1) where
    xs :-| x = xs ++. (x ':-- 'None)
-- | @x +*+ n@ is an 'OptVector' of length @n@ holding @x@ as a single entry of
-- count @n@. It is 'End' when @n@ is 0.
type family (a :: t) +*+ (n :: Nat) :: OptVector t n where
    _ +*+ 0 = 'End
    x +*+ k = (x ** k) ':- 'End
-- | A matrix with @rows@ rows: a 'Vector' of rows, where each row is an
-- 'OptVector' of length @cols@.
type Matrix t rows cols = Vector (OptVector t cols) rows
-- | Horizontal concatenation: join two matrices with the same number of rows
-- by concatenating corresponding rows with '++'.
type family (a :: Matrix t p q) +|+ (b :: Matrix t p r) :: Matrix t p (q + r) where
    'None +|+ 'None = 'None
    (left ':-- lefts) +|+ (right ':-- rights) =
        (left ++ right) ':-- (lefts +|+ rights)
-- | Vertical concatenation: stack two matrices that have the same row length.
-- 'Align' first re-fragments each matrix, then the two row vectors of the
-- resulting pair are joined.
type family (a :: Matrix t p r) +-+ (b :: Matrix t q r) :: Matrix t (p + q) r where
    top +-+ bottom = ConcatPair (Align top bottom)
-- | Concatenate the two components of a promoted pair of 'Vector's.
type family ConcatPair (vs :: (Vector t p, Vector t q)) :: Vector t (p + q) where
    ConcatPair '(xs, ys) = xs ++. ys
-- | Prepare two matrices for vertical concatenation.
--
-- If either matrix is empty, the pair is returned unchanged. Otherwise every
-- row of each matrix is fragmented by the first row of the other matrix, via
-- 'FragmentMatByVec'.
type family Align (a :: Matrix t p r) (b :: Matrix t q r) :: (Matrix t p r, Matrix t q r) where
    Align None m = '(None, m)
    Align m None = '(m, None)
    Align (topRow :-- topRest) (botRow :-- botRest) =
        '(FragmentMatByVec (topRow :-- topRest) botRow
         , FragmentMatByVec (botRow :-- botRest) topRow)
-- | Fragment every row of a matrix by the same reference vector, using
-- 'FragmentVecByVec' on each row.
type family FragmentMatByVec (m :: Matrix t q p) (v :: OptVector t p) :: Matrix t q p where
    FragmentMatByVec 'None _ = 'None
    FragmentMatByVec (row ':-- rows) ref =
        FragmentVecByVec row ref ':-- FragmentMatByVec rows ref
-- | Split the entries of @v@ at the entry boundaries of @u@, where both vectors
-- have length @p@. Only @v@'s values appear in the result.
--
-- At each step the smaller of the two current counts is emitted, and the
-- leftover count of the longer entry is pushed back onto its vector. The
-- result therefore has an entry boundary wherever either input has one.
type family FragmentVecByVec (v :: OptVector t p) (u :: OptVector t p) :: OptVector t p where
    FragmentVecByVec End _ = End
    -- Equal counts: emit the whole entry and advance both vectors.
    FragmentVecByVec (v :* (T :: Times k) :- vs) (u :* (T :: Times k) :- us) =
        v ** k :- (FragmentVecByVec vs us)
    -- Different counts: emit the smaller count and re-queue the remainder of
    -- the longer entry (l - k of u's entry, or k - l of v's entry).
    FragmentVecByVec (v :* (T :: Times k) :- vs) (u :* (T :: Times l) :- us) =
        If (k <=? l)
            ((v ** k) :- (FragmentVecByVec vs (u ** (l - k) :- us)))
            ((v ** l) :- (FragmentVecByVec (v ** (k - l) :- vs) us))
-- | Turn a 'Vector' into a single-column matrix. Each element becomes a
-- one-entry row of count @l@.
--
-- Each head element is appended after the converted tail, so the rows come
-- out in reverse order: the vector's first element becomes the last row.
type family VectorToColMatrix (v :: Vector t n) (l :: Nat) :: Matrix t n l where
    VectorToColMatrix 'None _ = 'None
    VectorToColMatrix (x ':-- xs) len =
        VectorToColMatrix xs len :-| ((x ** len) ':- 'End)
-- | Type-level conditional: picks the second argument when the condition is
-- 'True and the third when it is 'False.
type family If (b :: Bool) (t :: k) (e :: k) :: k where
    If 'True  yes _  = yes
    If 'False _   no = no
-- | Type-level boolean negation.
type family Not (a :: Bool) :: Bool where
    Not 'False = 'True
    Not 'True  = 'False
-- | Type-level conjunction, expressed through 'If': the second operand when the
-- first is 'True, otherwise 'False.
type family (b1 :: Bool) .&&. (b2 :: Bool) :: Bool where
    p .&&. q = If p q 'False
-- | Type-level disjunction, expressed through 'If': 'True when the first
-- operand is 'True, otherwise the second operand.
type family (b1 :: Bool) .||. (b2 :: Bool) :: Bool where
    p .||. q = If p 'True q
-- | Type-level equality test: 'True when both sides are the same type, and
-- 'False once GHC can show they are apart.
type family (a :: k) .~. (b :: k) :: Bool where
    x .~. x = 'True
    _ .~. _ = 'False
-- | The larger of two type-level naturals. Zero and equal operands are
-- handled by dedicated equations before falling back to a '<=?' comparison.
type family MaxN (n1 :: Nat) (n2 :: Nat) :: Nat where
    MaxN 0 y = y
    MaxN x 0 = x
    MaxN x x = x
    MaxN x y = If (x <=? y) y x
-- | The smaller of two type-level naturals. Zero and equal operands are
-- handled by dedicated equations before falling back to a '<=?' comparison.
type family MinN (n1 :: Nat) (n2 :: Nat) :: Nat where
    MinN 0 _ = 0
    MinN _ 0 = 0
    MinN x x = x
    MinN x y = If (x <=? y) x y
-- | The trivially satisfied (empty) constraint.
type Valid = (() :: Constraint)
-- | A constraint that can never be satisfied, since 'True never equals 'False.
type Invalid = 'True ~ 'False
-- | @AllSatisfy c xs@ requires @c x@ for the value @x@ of every entry of the
-- 'OptVector' @xs@. Entry counts are ignored.
type family AllSatisfy (c :: a -> Constraint) (xs :: OptVector a n) :: Constraint where
    AllSatisfy _ 'End                    = Valid
    AllSatisfy c ((x ':* _) ':- rest)    = (c x, AllSatisfy c rest)
-- | @AllPairsSatisfy c xs ys@ requires @c x y@ for the values of the entries of
-- @xs@ and @ys@, paired entry by entry. Entry counts are ignored.
--
-- No equation covers vectors with different numbers of entries, so in that
-- case the constraint does not reduce.
type family AllPairsSatisfy (c :: a -> b -> Constraint)
                            (xs :: OptVector a n) (ys :: OptVector b n)
                            :: Constraint where
    AllPairsSatisfy _ 'End 'End = Valid
    AllPairsSatisfy c ((x ':* _) ':- restX) ((y ':* _) ':- restY) =
        (c x y, AllPairsSatisfy c restX restY)
-- | @AllPairsSatisfy' c xs ys@ requires @c x y@ for corresponding elements of
-- the two equal-length 'Vector's @xs@ and @ys@.
type family AllPairsSatisfy' (c :: a -> b -> Constraint)
                             (xs :: Vector a n) (ys :: Vector b n)
                             :: Constraint where
    AllPairsSatisfy' _ 'None 'None = Valid
    AllPairsSatisfy' c (x ':-- restX) (y ':-- restY) =
        (c x y, AllPairsSatisfy' c restX restY)
-- | @SatisfiesAll cs x@ requires every constraint in the type-level list @cs@
-- to hold for the single type @x@.
type family SatisfiesAll (cs :: [a -> Constraint]) (x :: a) :: Constraint where
    SatisfiesAll '[] _         = Valid
    SatisfiesAll (c ': rest) x = (c x, SatisfiesAll rest x)
-- | @AllSatisfyAll cs xs@ requires every constraint in @cs@ to hold for every
-- element of the 'Vector' @xs@. It applies 'SatisfiesAll' element-wise.
type family AllSatisfyAll (cs :: [a -> Constraint]) (xs :: Vector a n) :: Constraint where
    AllSatisfyAll _ 'None            = Valid
    AllSatisfyAll cs (x ':-- rest)   = (SatisfiesAll cs x, AllSatisfyAll cs rest)