This file gives an executable implementation of the model for
abstract stencil specifications. This model is used to drive both
the specification checking and program synthesis features.


{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE MultiWayIf #-}

module Camfort.Specification.Stencils.Model ( Interval(..)
                                            , Bound(..)
                                            , approxVec
                                            , Offsets(..)
                                            , UnionNF
                                            , vecLength
                                            , unfCompare
                                            , optimise
                                            , maximas
                                            , Approximation(..)
                                            , lowerBound, upperBound
                                            , fromExact
                                            , Multiplicity(..)
                                            , Peelable(..)
                                            ) where

import qualified Control.Monad as CM

import           Algebra.Lattice
import           Data.Semigroup
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S
import           Data.Foldable
import           Data.SBV
import           Data.Data
import           Data.List (sortBy, nub)
import           Data.Maybe (fromJust)
import qualified Data.PartialOrd as PO

import qualified Camfort.Helpers.Vec as V
import System.IO.Unsafe

-- | Utility class for set-like containers that offer both a concrete
-- membership test and a membership predicate over SBV values.
class Container a where
  -- | Type of the concrete values that can be tested for membership.
  type MemberTyp a
  -- | Type of the values the compiled predicate ranges over.
  type CompTyp a

  -- | Concrete membership test.
  member :: MemberTyp a -> a -> Bool
  -- | Turn the container into a membership predicate yielding an 'SBool'.
  compile :: a -> CompTyp a -> SBool

-- | Arbitrary sets representing offsets: either an explicitly enumerated
-- set of 'Int64' values or the set of all integers.
data Offsets
  = Offsets (S.Set Int64) -- ^ Explicitly enumerated offsets.
  | SetOfIntegers         -- ^ Every integer offset.
  deriving (Eq)

-- | Total order on offsets: explicit sets are ordered by the 'Ord' instance
-- of the underlying set, and every explicit set sorts before
-- 'SetOfIntegers'.
instance Ord Offsets where
  compare lhs rhs =
    case (lhs, rhs) of
      (Offsets s, Offsets s')        -> compare s s'
      (Offsets _, SetOfIntegers)     -> LT
      (SetOfIntegers, Offsets _)     -> GT
      (SetOfIntegers, SetOfIntegers) -> EQ

-- | Offsets as a container of 'Int64' values.
instance Container Offsets where
  type MemberTyp Offsets = Int64
  type CompTyp Offsets = SInt64

  -- An explicit set is queried directly; 'SetOfIntegers' contains everything.
  member i offs =
    case offs of
      Offsets s     -> S.member i s
      SetOfIntegers -> True

  -- An explicit set becomes an "is one of these literals" predicate;
  -- 'SetOfIntegers' is the constantly true predicate.
  compile offs i =
    case offs of
      Offsets s     -> i `sElem` [ fromIntegral o | o <- S.toList s ]
      SetOfIntegers -> true

-- | Join is set union; joining with 'SetOfIntegers' gives 'SetOfIntegers'.
instance JoinSemiLattice Offsets where
  lhs \/ rhs =
    case (lhs, rhs) of
      (Offsets s, Offsets s') -> Offsets (S.union s s')
      _                       -> SetOfIntegers

-- | Meet is set intersection; meeting with 'SetOfIntegers' returns the other
-- operand unchanged.
instance MeetSemiLattice Offsets where
  lhs /\ rhs =
    case (lhs, rhs) of
      (Offsets s, Offsets s')    -> Offsets (S.intersection s s')
      (Offsets{}, SetOfIntegers) -> lhs
      (SetOfIntegers, _)         -> rhs

-- | 'Offsets' is a lattice under the join and meet defined for it.
instance Lattice Offsets

-- | The empty set of offsets is the least element.
instance BoundedJoinSemiLattice Offsets where
  bottom = Offsets mempty

-- | The set of all integers is the greatest element.
instance BoundedMeetSemiLattice Offsets where
  top = SetOfIntegers

-- | With both extremes available, 'Offsets' is a bounded lattice.
instance BoundedLattice Offsets

-- Interval as defined in the paper

-- | Tag used to index 'Interval' by its flavour.
data Bound
  = Arbitrary
  | Standard

-- | Interval data structure. It assumes the following:
--
-- 1. the first numeric parameter is less than the second;
-- 2. for holed intervals, first numeric parameter <= 0 <= second numeric
--    parameter.
data Interval a where
  -- | Interval given by two arbitrary integer bounds.
  IntervArbitrary
    :: Int -> Int -> Interval Arbitrary
  -- | Unbounded interval in the arbitrary flavour.
  IntervInfiniteArbitrary
    :: Interval Arbitrary
  -- | Interval around zero; the 'Bool' says whether zero itself is a member.
  IntervHoled
    :: Int64 -> Int64 -> Bool -> Interval Standard
  -- | Unbounded interval in the standard flavour.
  IntervInfinite
    :: Interval Standard

deriving instance Eq (Interval a)

-- | Textual rendering of standard intervals.
instance Show (Interval Standard) where
  show interv =
    case interv of
      IntervInfinite      -> "IntervInfinite"
      IntervHoled lb up p ->
        concat [ "Interv [", show lb, ",", show up, "]^", show p ]

-- | Convert an arbitrary interval into a standard (zero-anchored) one.
--
-- * Intervals containing zero, ending at -1, or starting at 1 have an exact
--   standard representation.
-- * Intervals lying strictly beyond 1 (or strictly below -1) cannot be
--   represented exactly; only an upper bound stretching to zero is returned.
--
-- Calls 'error' when the lower bound exceeds the upper bound.
approxInterv :: Interval Arbitrary -> Approximation (Interval Standard)
approxInterv (IntervArbitrary a b)
  | a > b = error
    "Interval condition violated: lower bound is bigger than the upper bound."
  | a <=  0, b >=  0 = Exact  $ IntervHoled a' b' True
  | a <= -1, b == -1 = Exact  $ IntervHoled a' 0  False
  | a ==  1, b >=  1 = Exact  $ IntervHoled 0  b' False
  | a >   1, b >   1 = Bound Nothing $ Just $ IntervHoled 0  b' False
  | a <  -1, b <  -1 = Bound Nothing $ Just $ IntervHoled a' 0  False
  | otherwise = error "Impossible: All posibilities are covered."
  where
    -- Bounds converted from 'Int' to the 'Int64' used by standard intervals.
    a' = fromIntegral a
    b' = fromIntegral b
approxInterv IntervInfiniteArbitrary = Exact IntervInfinite

-- | Approximate every interval of a vector. The result is 'Exact' when each
-- component has an exact standard representation; otherwise only an upper
-- bound is returned, assembled from the upper bound of every component.
approxVec :: forall n .
             V.Vec n (Interval Arbitrary)
          -> Approximation (V.Vec n (Interval Standard))
approxVec v =
  case findApproxIntervs stdVec of
    ([],_) -> Exact . fmap fromExact $ stdVec
    _      -> Bound Nothing (Just $ upperBound <$> stdVec)
  where
    -- Component-wise approximation of the input vector.
    stdVec :: V.Vec n (Approximation (Interval Standard))
    stdVec = fmap approxInterv v

    -- Indices of the bounded components (first list) and of the exact
    -- components (second list).
    findApproxIntervs :: forall m . V.Vec m (Approximation (Interval Standard))
                      -> ([ Int ], [ Int ])
    findApproxIntervs vec = findApproxIntervs' 0 vec ([],[])

    -- Worker: walks the vector while tracking the current index.
    findApproxIntervs' :: forall m . Int
                       -> V.Vec m (Approximation (Interval Standard))
                       -> ([ Int ], [ Int ])
                       -> ([ Int ], [ Int ])
    findApproxIntervs' _ V.Nil acc = acc
    findApproxIntervs' i (V.Cons x xs) (bixs, eixs) =
      findApproxIntervs' (i+1) xs $
        case x of
          Bound{} -> (i:bixs, eixs)
          Exact{} -> (bixs, i:eixs)

-- | Standard intervals as containers of 'Int64' values.
instance Container (Interval Standard) where
  type MemberTyp (Interval Standard) = Int64
  type CompTyp (Interval Standard) = SInt64

  -- Zero is a member exactly when the interval's flag says so; any other
  -- value is a member when it lies within the bounds. The infinite interval
  -- contains everything.
  member 0 (IntervHoled _ _ b) = b
  member i (IntervHoled a b _) = i >= a && i <= b
  member _ _ = True

  -- Same membership rule as a predicate over 'SInt64': a range check, with
  -- zero additionally excluded when the flag is 'False'.
  compile (IntervHoled i1 i2 b) i
    | b = inRange i range
    | otherwise = inRange i range &&& i ./= 0
    where
      range = (fromIntegral i1, fromIntegral i2)
  compile IntervInfinite _ = true

-- | Join of two holed intervals widens both bounds and contains zero when
-- either operand does; joining with the infinite interval gives 'top'.
instance JoinSemiLattice (Interval Standard) where
  lhs \/ rhs =
    case (lhs, rhs) of
      (IntervHoled lb ub noHole, IntervHoled lb' ub' noHole') ->
        IntervHoled (min lb lb') (max ub ub') (noHole || noHole')
      _ -> top

-- | Meet of two holed intervals narrows both bounds and contains zero only
-- when both operands do; meeting with the infinite interval returns the
-- other operand unchanged.
instance MeetSemiLattice (Interval Standard) where
  lhs /\ rhs =
    case (lhs, rhs) of
      (IntervHoled lb ub noHole, IntervHoled lb' ub' noHole') ->
        IntervHoled (max lb lb') (min ub ub') (noHole && noHole')
      (IntervHoled{}, IntervInfinite) -> lhs
      (IntervInfinite, _)             -> rhs

-- | Standard intervals form a lattice under the join and meet defined for
-- them.
instance Lattice (Interval Standard)

-- | Least element: the interval [0,0] with zero excluded, which has no
-- members.
instance BoundedJoinSemiLattice (Interval Standard) where
  bottom = IntervHoled 0 0 False

-- | Greatest element: the infinite interval.
instance BoundedMeetSemiLattice (Interval Standard) where
  top = IntervInfinite

-- | With both extremes available, standard intervals are a bounded lattice.
instance BoundedLattice (Interval Standard)

-- Union of cartesian products normal form

-- | A non-empty collection of @n@-dimensional vectors, read as the union of
-- the products they denote.
type UnionNF n a = NE.NonEmpty (V.Vec n a)

-- | Length of the vectors in a union, read off its first member.
vecLength :: UnionNF n a -> V.Natural n
vecLength unf = V.lengthN (NE.head unf)

-- | A union is a container of index vectors: a vector belongs to the union
-- when it belongs, component by component, to at least one of its products.
instance Container a => Container (UnionNF n a) where
  type MemberTyp (UnionNF n a) = V.Vec n (MemberTyp a)
  type CompTyp (UnionNF n a) = V.Vec n (CompTyp a)
  member is = any (member' is)
    where
      -- Membership in a single product: every component must be a member.
      member' ixs space = and $ V.zipWith member ixs space

  compile spaces is = foldr1 (|||) $ NE.map (`compile'` is) spaces
    where
      -- Predicate of a single product: conjunction of component predicates.
      compile' space ixs =
        foldr' (\(set, i) -> (&&&) $ compile set i) true $ V.zip space ixs

-- | Join of two unions is the concatenation of their products.
instance JoinSemiLattice (UnionNF n a) where
  (\/) = (<>)

-- | Meet of two unions pairs every product of the first with every product
-- of the second and meets them component by component.
instance BoundedLattice a => MeetSemiLattice (UnionNF n a) where
  xs /\ ys = CM.liftM2 (V.zipWith (/\)) xs ys

instance BoundedLattice a => Lattice (UnionNF n a)

-- | Compare two unions by asking the SBV prover whether their compiled
-- membership predicates agree on every index vector.
--
-- * 'EQ' when the prover finds no model (no counterexample).
-- * 'GT' when the counterexample is a member of the first union.
-- * 'LT' when the counterexample is a member of the second union.
--
-- Fails (via 'fail' in 'IO') on a proof error, on a "probable" model, or if
-- the counterexample does not have the expected shape.
--
-- NOTE(review): 'unsafePerformIO' is used to present the prover call as a
-- pure function. This is only sound if the prover's answer depends on
-- nothing but the two arguments -- confirm, and consider a NOINLINE pragma.
unfCompare :: forall a b n . ( Container a,          Container b
                             , MemberTyp a ~ Int64,  MemberTyp b ~ Int64
                             , CompTyp a ~ SInt64,   CompTyp b ~ SInt64
                             )
           => UnionNF n a -> UnionNF n b -> Ordering
unfCompare oi oi' = unsafePerformIO $ do
    thmRes <- prove pred
    case thmRes of
      -- Tell the user if there was a hard proof error (e.g., if
      -- z3 is not installed/accessible).
      -- TODO: give more information
      ThmResult (ProofError _ msgs) -> fail $ unlines msgs
      _ ->
        if modelExists thmRes
        then do
          ce <- counterExample thmRes
          case V.fromList ce of
             V.VecBox cev ->
               case V.proveEqSize (NE.head oi) cev of
                 Just V.ReflEq ->
                   -- TODO: The second branch is defensive programming the
                   -- member check is not necessary unless the counter example
                   -- is bogus (it shouldn't be). Delete if it adversely
                   -- effects the performance.
                   if | cev `member` oi  -> return GT
                      | cev `member` oi' -> return LT
                      | otherwise -> fail
                         "Impossible: counter example is in \
                          \neither of the operands"
                 Nothing -> fail
                    "Impossible: Counter example size doesn't \
                    \match the original vector size."
        else return EQ
  where
    -- Extract the counterexample assignment from the prover's answer.
    counterExample :: ThmResult -> IO [ Int64 ]
    counterExample thmRes =
      case getModelAssignment thmRes of
        Right (False, ce) -> return ce
        Right (True, _) -> fail "Returned probable model."
        Left str -> fail str

    -- The property to prove: both unions agree on one free variable per
    -- dimension.
    pred :: Predicate
    pred = do
      freeVars <- (mkFreeVars . dimensionality) oi :: Symbolic [ SInt64 ]
      case V.fromList freeVars of
        V.VecBox freeVarVec ->
          case V.proveEqSize (NE.head oi) freeVarVec of
            Just V.ReflEq -> return $
              compile oi freeVarVec .== compile oi' freeVarVec
            Nothing -> fail $
              "Impossible: Free variables size doesn't match that of the " ++
              "union parameter."
    -- Number of dimensions, taken from the first product of the union.
    dimensionality = V.length . NE.head

-- Optimise unions

-- | Offsets are partially ordered by set inclusion, with 'SetOfIntegers'
-- above every explicit set. This agrees with the lattice structure, where
-- join is union and meet is intersection.
instance PO.PartialOrd Offsets where
  -- The unqualified '<=' here would be the lexicographic 'Ord' on sets, not
  -- inclusion (e.g. it makes {1} <= {2} hold), so inclusion is spelled out.
  (Offsets s) <= (Offsets s') = s `S.isSubsetOf` s'
  SetOfIntegers <= Offsets{} = False
  _ <= SetOfIntegers = True

-- | Standard intervals are ordered by inclusion: a holed interval is below
-- another when its bounds lie within the other's and it contains zero only
-- if the other does. The infinite interval is above everything.
instance PO.PartialOrd (Interval Standard) where
  lhs <= rhs =
    case (lhs, rhs) of
      (IntervHoled lb ub p, IntervHoled lb' ub' p') ->
        (p' || not p) && lb >= lb' && ub <= ub'
      (IntervInfinite, IntervHoled{}) -> False
      (_, IntervInfinite)             -> True

-- | Vectors are ordered component-wise: one vector is below another when
-- every component is.
instance PO.PartialOrd a => PO.PartialOrd (V.Vec n a) where
  xs <= ys = all id (V.zipWith (PO.<=) xs ys)

-- | Simplify a union: repeatedly keep only the maximal products and merge
-- products with 'unionLemma' until nothing changes, then keep the maximal
-- products of the result.
optimise :: UnionNF n (Interval Standard) -> UnionNF n (Interval Standard)
optimise = NE.fromList . maximas . fixedPointUnion . NE.toList
  where
    -- Iterate "maximas then unionLemma" until the list stops changing.
    fixedPointUnion unf =
      let unf' = unionLemma . maximas $ unf
      in if unf' == unf then unf' else fixedPointUnion unf'

-- | For every element of the list, collect all elements related to it by
-- the given predicate, sort that group with the given comparison, and
-- finally drop duplicate groups.
sensibleGroupBy :: Eq a =>
                   (a -> a -> Ordering)
                -> (a -> a -> Bool)
                -> [ a ]
                -> [ [ a ] ]
sensibleGroupBy ord p l =
  nub [ sortBy ord (filter (p pivot) l) | pivot <- l ]

-- | Keep only the maximal vectors of a list (with respect to the
-- component-wise partial order), without duplicates.
--
-- Each vector is grouped with everything above it, and a maximal element of
-- that group is picked. Every group contains at least the vector it was
-- built from, so 'head' is applied to a non-empty list provided 'PO.maxima'
-- returns at least one element for non-empty input.
maximas :: [ V.Vec n (Interval Standard) ] -> [ V.Vec n (Interval Standard) ]
maximas = nub
        . fmap (head . PO.maxima)
        . sensibleGroupBy ord (PO.<=)
  where
    -- Two vectors above the same element need not be comparable with each
    -- other, so 'PO.compare' may return 'Nothing'. Such pairs are treated
    -- as equal for sorting purposes instead of crashing; the sort only
    -- canonicalises groups before duplicates are removed.
    ord a b = maybe EQ id $ a `PO.compare` b

-- | Union lemma says that if we have a product of intervals (as defined in
-- the paper) and we union two that agree in each dimension except one, the
-- union is again a product of intervals that agrees with the originals in
-- all dimensions except the differing one. In that dimension it is the
-- union of the two intervals, which is itself still an interval.
--
-- Each vector is grouped with every vector that agrees with it in all but
-- at most one dimension, and each group is joined component by component.
unionLemma :: [ V.Vec n (Interval Standard) ] -> [ V.Vec n (Interval Standard) ]
unionLemma = map (foldr1 (V.zipWith (\/)))
           . sensibleGroupBy (\a b -> if a == b then EQ else LT) agreeButOne
  where
    -- This function returns true if two vectors agree at all points but one.
    -- It also holds if two vectors are identical.
    agreeButOne :: Eq a => V.Vec n a -> V.Vec n a -> Bool
    agreeButOne = go False
      where
        -- The flag records whether a differing position was already seen.
        go :: Eq a => Bool -> V.Vec n a -> V.Vec n a -> Bool
        go _ V.Nil V.Nil = True
        go False (V.Cons x xs) (V.Cons y ys)
          | x == y = go False xs ys
          | otherwise = go True xs ys
        go True (V.Cons x xs) (V.Cons y ys)
          | x == y = go True xs ys
          | otherwise = False

-- Injections for multiplicity and exactness

-- | A value that is either known exactly or only through an optional lower
-- and an optional upper bound.
data Approximation a
  = Exact a
    -- ^ The value itself.
  | Bound (Maybe a) (Maybe a)
    -- ^ Lower bound (first field) and upper bound (second field), each of
    -- which may be absent.
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)

-- | Extract the value of an exact approximation; calls 'error' on a bounded
-- one.
fromExact :: Approximation a -> a
fromExact approx =
  case approx of
    Exact a -> a
    Bound{} -> error "Can't retrieve from bounded as if it was exact."

-- | Lower bound of an approximation; an exact value is its own lower bound.
-- Calls 'error' when the lower bound is absent.
lowerBound :: Approximation a -> a
lowerBound approx =
  case approx of
    Exact a          -> a
    Bound (Just a) _ -> a
    Bound Nothing _  -> error "Approximation doesn't have a lower bound."

-- | Upper bound of an approximation; an exact value is its own upper bound.
-- Calls 'error' when the upper bound is absent.
upperBound :: Approximation a -> a
upperBound approx =
  case approx of
    Exact a          -> a
    Bound _ (Just a) -> a
    Bound _ Nothing  -> error "Approximation doesn't have a upper bound."

-- | Wrappers from which an underlying value can be extracted.
class Peelable a where
  -- | Type of the value inside the wrapper.
  type CoreTyp a

  -- | Strip the wrapper and return the value inside.
  peel :: a -> CoreTyp a

-- | A value tagged with one of two multiplicities, 'Mult' or 'Once'.
data Multiplicity a
  = Mult a
  | Once a
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)

-- | Peeling a 'Multiplicity' discards the tag and returns the wrapped value.
instance Peelable (Multiplicity a) where
  type CoreTyp (Multiplicity a) = a

  peel m =
    case m of
      Mult a -> a
      Once a -> a

