-----------------------------------------------------------------------------
-- |
-- Module      :  Mezzo.Model.Prim
-- Description :  Mezzo type primitives
-- Copyright   :  (c) Dima Szamozvancev
-- License     :  MIT
--
-- Maintainer  :  ds709@cam.ac.uk
-- Stability   :  experimental
-- Portability :  portable
--
-- Primitive types that make up the base for the Mezzo type model.
--
-----------------------------------------------------------------------------

module Mezzo.Model.Prim
    (
    -- * Vectors and matrices
      Vector (..)
    , Times (..)
    , Elem (..)
    , type (**)
    , OptVector (..)
    , Head
    , Head'
    , Last
    , Tail'
    , Init'
    , Length
    , Length'
    , Matrix
    , type (++)
    , type (++.)
    , type (:-|)
    , type (+*+)
    , type (+|+)
    , type (+-+)
    , Align
    , VectorToColMatrix
    -- * Logic and arithmetic
    , If
    , Not
    , type (.&&.)
    , type (.||.)
    , type (.~.)
    , MaxN
    , MinN
    -- * Constraints
    , Valid
    , Invalid
    , AllSatisfy
    , AllPairsSatisfy
    , AllPairsSatisfy'
    , SatisfiesAll
    , AllSatisfyAll
    ) where

import Data.Kind
import GHC.TypeLits

-- Fixities of the operators defined in this module (higher binds tighter).
-- Element construction binds before consing a run, which binds before
-- consing a matrix row, so @v ** n :- End :-- None@ reads as
-- @((v ** n) :- End) :-- None@.
-- Likewise @a .~. b .&&. c@ reads as @(a .~. b) .&&. c@.
-- (++.) and (:-|) are not given a fixity here, so they get Haskell's
-- default of infixl 9.
infixr 7 :*, **
infixr 6 :-
infixr 5 :--
infixl 5 .~.
infixl 4 ++, +*+, +|+, +-+
infixl 3 .&&., .||.

-------------------------------------------------------------------------------
-- Type-level vectors and matrices
-------------------------------------------------------------------------------

-- | A plain vector whose length is tracked in its type index.
data Vector :: Type -> Nat -> Type where
    -- | The empty vector.
    None  :: Vector a 0
    -- | Prepend an element to a vector that is one element shorter.
    (:--) :: a -> Vector a (len - 1) -> Vector a len

-- | A marker that records how many times an element repeats.
-- It carries no runtime data: the count lives only in the type index.
data Times :: Nat -> Type where
    T :: Times reps

-- | One entry of a run-length encoded vector: a value paired with the
-- number of times it is repeated.
data Elem :: Type -> Nat -> Type where
    -- | Attach a repetition count to a value.
    (:*) :: a -> Times reps -> Elem a reps

-- | Build an 'Elem' that repeats a value the given number of times;
-- for example, @x ** 3@ is @x@ repeated three times.
type family (v :: t) ** (d :: Nat) :: Elem t d where
    x ** reps = x ':* ('T :: Times reps)

-- | A length-indexed vector stored in run-length encoded form.
-- Each cell is an 'Elem' covering some number of positions, and the tail
-- covers the remaining positions.
data OptVector :: Type -> Nat -> Type where
    -- | The empty optimised vector.
    End  :: OptVector a 0
    -- | Prepend a run; the tail covers @len - run@ positions.
    (:-) :: Elem a run -> OptVector a (len - run) -> OptVector a len

-- | The value of the first run in an optimised vector.
-- Produces a type error for the empty vector.
type family Head (v :: OptVector t n) :: t where
    Head 'End                  = TypeError ('Text "Vector has no head element.")
    Head ((first ':* _) ':- _) = first

-- | The first element of a simple vector.
-- Produces a type error for the empty vector.
type family Head' (v :: Vector t n) :: t where
    Head' 'None          = TypeError ('Text "Vector has no head element.")
    Head' (first ':-- _) = first

-- | The value of the final run in an optimised vector.
-- Walks the runs until it reaches one followed by 'End'; produces a type
-- error for the empty vector.
type family Last (v :: OptVector t n) :: t where
    Last 'End                     = TypeError ('Text "Vector has no last element.")
    Last ((final ':* _) ':- 'End) = final
    Last (_ ':- rest)             = Last rest

-- | A simple vector without its first element.
-- Produces a type error for the empty vector.
type family Tail' (v :: Vector t n) :: Vector t (n - 1) where
    Tail' 'None         = TypeError ('Text "Vector has no tail.")
    Tail' (_ ':-- rest) = rest

-- | A simple vector without its last element.
-- Produces a type error for the empty vector.
type family Init' (v :: Vector t n) :: Vector t (n - 1) where
    Init' 'None          = TypeError ('Text "Vector is empty.")
    Init' (_ ':-- 'None) = 'None
    Init' (x ':-- rest)  = x ':-- Init' rest

-- | The length index of an optimised vector: the total number of positions
-- it covers, not the number of runs.
type family Length (v :: OptVector t n) :: Nat where
    Length (vec :: OptVector elt len) = len

-- | The length index of a simple vector.
type family Length' (v :: Vector t n) :: Nat where
    Length' (vec :: Vector elt len) = len

-- | Append two optimised vectors.
-- Runs are concatenated as they are: adjacent runs of the same value are
-- not merged into one.
type family (x :: OptVector t n) ++ (y :: OptVector t m) :: OptVector t (n + m) where
    front      ++ 'End = front
    'End       ++ back = back
    (r ':- rs) ++ back = r ':- (rs ++ back)

-- | Append two simple vectors, recursing down the first one.
type family (x :: Vector t n) ++. (y :: Vector t m) :: Vector t (n + m) where
    'None       ++. back = back
    (e ':-- es) ++. back = e ':-- (es ++. back)

-- | Snoc: append a single element to the end of a simple vector.
type family (v :: Vector t n) :-| (e :: t) :: Vector t (n + 1) where
    xs :-| x = xs ++. (x ':-- 'None)

-- | Repeat a value the given number of times.
-- The result is an 'OptVector' holding a single run, or 'End' when the
-- count is 0.
type family (a :: t) +*+ (n :: Nat) :: OptVector t n where
    _ +*+ 0    = 'End
    x +*+ reps = (x ** reps) ':- 'End

-- | A dimension-indexed matrix: a simple vector of @rows@ rows, where each
-- row is an optimised vector of length @cols@.
type Matrix t rows cols = Vector (OptVector t cols) rows

-- | Horizontal concatenation of type-level matrices.
-- The first matrix goes on the left: each row of the result is the
-- corresponding row of the first matrix followed by that of the second.
type family (a :: Matrix t p q) +|+ (b :: Matrix t p r) :: Matrix t p (q + r) where
    'None             +|+ 'None               = 'None
    (left ':-- lefts) +|+ (right ':-- rights) = (left ++ right) ':-- (lefts +|+ rights)

-- | Vertical concatenation of type-level matrices; the first matrix goes on
-- top.
-- Both inputs are first passed through 'Align' so that their run boundaries
-- coincide, and the resulting rows are then stacked.
type family (a :: Matrix t p r) +-+ (b :: Matrix t q r) :: Matrix t (p + q) r where
    upper +-+ lower = ConcatPair (Align upper lower)

-- | Append the two components of a promoted pair of simple vectors, with
-- the first component in front.
type family ConcatPair (vs :: (Vector t p, Vector t q)) :: Vector t (p + q) where
    ConcatPair '(front, back) = front ++. back

-- | Vertically aligns two matrices by splitting runs so that the element
-- boundaries line up.
--
-- If either matrix has no rows, the pair is returned unchanged. Otherwise
-- every row of the first matrix is fragmented by the first row of the
-- second, and every row of the second by the first row of the first.
-- NOTE(review): only the first row of each matrix is consulted, which
-- presumably assumes the rows within each input already share their
-- boundaries — confirm.
type family Align (a :: Matrix t p r) (b :: Matrix t q r) :: (Matrix t p r, Matrix t q r) where
    Align None mat = '(None, mat)
    Align mat None = '(mat, None)
    Align (top1 ':-- rest1) (top2 ':-- rest2) =
        '( FragmentMatByVec (top1 ':-- rest1) top2
         , FragmentMatByVec (top2 ':-- rest2) top1
         )

-- | Fragments every row of a matrix by a guide vector, applying
-- 'FragmentVecByVec' to each row in turn.
-- Afterwards, every run boundary of the guide also appears in every row.
type family FragmentMatByVec (m :: Matrix t q p) (v :: OptVector t p) :: Matrix t q p where
    FragmentMatByVec 'None           _     = 'None
    FragmentMatByVec (row ':-- rows) guide =
        FragmentVecByVec row guide ':-- FragmentMatByVec rows guide

-- | Splits the runs of the first vector so that every run boundary of the
-- second vector is also a boundary in the result.
-- Values always come from the first vector; the second only supplies
-- boundary positions.
type family FragmentVecByVec (v :: OptVector t p) (u :: OptVector t p) :: OptVector t p where
    FragmentVecByVec 'End _ = 'End
    -- Leading runs have the same length: keep the first vector's run whole.
    FragmentVecByVec ((x ':* ('T :: Times k)) ':- xs) ((_ ':* ('T :: Times k)) ':- ys) =
        (x ** k) ':- FragmentVecByVec xs ys
    -- Leading runs differ in length: emit the shorter length using the first
    -- vector's value, push the leftover of the longer run back onto the front
    -- of its own vector, and continue.
    FragmentVecByVec ((x ':* ('T :: Times k)) ':- xs) ((y ':* ('T :: Times l)) ':- ys) =
        If (k <=? l)
            ((x ** k) ':- FragmentVecByVec xs ((y ** (l - k)) ':- ys))
            ((x ** l) ':- FragmentVecByVec ((x ** (k - l)) ':- xs) ys)

-- | Convert a simple vector to a column matrix in which each row is a single
-- run of length @l@.
--
-- The rows come out in reverse order: the head of the vector becomes the
-- last row of the matrix, because each element's row is appended after the
-- rows built from the rest of the vector.
-- NOTE(review): confirm that callers rely on this reversed ordering.
type family VectorToColMatrix (v :: Vector t n) (l :: Nat) :: Matrix t n l where
    VectorToColMatrix 'None       _   = 'None
    VectorToColMatrix (x ':-- xs) len =
        VectorToColMatrix xs len ++. (((x ** len) ':- 'End) ':-- 'None)

-------------------------------------------------------------------------------
-- Type-level logic and arithmetic
-------------------------------------------------------------------------------

-- | Type-level conditional.
-- Selects the second argument when the condition is @'True@ and the third
-- when it is @'False@.
type family If (b :: Bool) (t :: k) (e :: k) :: k where
    If 'True  yes _  = yes
    If 'False _   no = no

-- | Boolean negation at the type level.
type family Not (a :: Bool) :: Bool where
    Not 'False = 'True
    Not 'True  = 'False

-- | Type-level conjunction.
-- Yields @'False@ when the left operand is @'False@; otherwise it yields
-- the right operand.
type family (b1 :: Bool) .&&. (b2 :: Bool) :: Bool where
    p .&&. q = If p q 'False

-- | Type-level disjunction.
-- Yields @'True@ when the left operand is @'True@; otherwise it yields the
-- right operand.
type family (b1 :: Bool) .||. (b2 :: Bool) :: Bool where
    p .||. q = If p 'True q

-- | Decide whether two types of the same kind are equal.
-- The answer is a type-level Boolean rather than a constraint.
type family (a :: k) .~. (b :: k) :: Bool where
    x .~. x = 'True
    _ .~. _ = 'False

-- | The larger of two natural numbers.
-- The cases for 0 and for two identical arguments let the family reduce
-- even when one argument is not yet a known literal.
type family MaxN (n1 :: Nat) (n2 :: Nat) :: Nat where
    MaxN 0 b = b
    MaxN a 0 = a
    MaxN a a = a
    MaxN a b = If (a <=? b) b a

-- | The smaller of two natural numbers.
-- The cases for 0 and for two identical arguments let the family reduce
-- even when one argument is not yet a known literal.
type family MinN (n1 :: Nat) (n2 :: Nat) :: Nat where
    MinN 0 _ = 0
    MinN _ 0 = 0
    MinN a a = a
    MinN a b = If (a <=? b) a b

-------------------------------------------------------------------------------
-- Constraints
-------------------------------------------------------------------------------

-- | The trivially satisfied constraint: the empty constraint tuple.
-- The constraint families below use it as the base case of their recursion.
type Valid = (() :: Constraint)

-- | A constraint that can never be satisfied, because it demands that the
-- promoted Booleans @'True@ and @'False@ be equal.
type Invalid = 'True ~ 'False

-- | A constraint that holds only if every value in the optimised vector
-- satisfies the given unary constraint.
-- Each run's value is checked once, however many times it repeats.
-- A constraint-level analogue of 'all'.
type family AllSatisfy (c  :: a -> Constraint)
                       (xs :: OptVector a n)
                           :: Constraint where
    AllSatisfy _    'End                 = Valid
    AllSatisfy cond ((x ':* _) ':- rest) = (cond x, AllSatisfy cond rest)

-- | A constraint that holds only if the given binary constraint holds for
-- each pair of runs drawn from the two optimised vectors.
-- The first run is paired with the first run, the second with the second,
-- and so on.
--
-- Repetition counts are ignored, so runs are matched by their position in
-- the encoding, not by position within the vector.
-- NOTE(review): presumably callers align the two vectors first — confirm.
-- A constraint-level analogue of 'zipWith'.
type family AllPairsSatisfy (c  :: a -> b -> Constraint)
                            (xs :: OptVector a n) (ys :: OptVector b n)
                                :: Constraint where
    AllPairsSatisfy _   'End               'End               = Valid
    AllPairsSatisfy rel ((x ':* _) ':- xs) ((y ':* _) ':- ys) =
        (rel x y, AllPairsSatisfy rel xs ys)

-- | A constraint that holds only if the given binary constraint holds for
-- each pair of elements at the same position in two simple vectors.
-- A constraint-level analogue of 'zipWith'.
type family AllPairsSatisfy' (c  :: a -> b -> Constraint)
                            (xs :: Vector a n) (ys :: Vector b n)
                                :: Constraint where
    AllPairsSatisfy' _   'None       'None       = Valid
    AllPairsSatisfy' rel (x ':-- xs) (y ':-- ys) = (rel x y, AllPairsSatisfy' rel xs ys)

-- | A constraint that holds only if a single value satisfies every unary
-- constraint in the given type-level list.
type family SatisfiesAll (cs :: [a -> Constraint])
                         (x  :: a)
                             :: Constraint where
    SatisfiesAll '[]         _   = Valid
    SatisfiesAll (c ': rest) val = (c val, SatisfiesAll rest val)

-- | A constraint that holds only if every element of the simple vector
-- satisfies every unary constraint in the given list.
type family AllSatisfyAll (cs :: [a -> Constraint])
                          (xs :: Vector a n)
                              :: Constraint where
    AllSatisfyAll _   'None         = Valid
    AllSatisfyAll cls (x ':-- rest) = (SatisfiesAll cls x, AllSatisfyAll cls rest)