module Camfort.Specification.Stencils.Syntax where
import Camfort.Helpers
import Prelude hiding (sum)
import Data.Data
import Data.Generics.Uniplate.Data
import Data.List hiding (sum)
import Data.Function
import Data.Maybe
import Debug.Trace
import Control.Applicative
-- | A variable name, represented as a plain 'String'.
type Variable = String
-- | A value that is either known exactly or only known up to bounds.
--
-- 'Exact' carries a single value.  'Bound' carries two optional values;
-- 'lowerBound' fills only the first field and 'upperBound' only the second.
data Approximation a =
  Exact a | Bound (Maybe a) (Maybe a)
   deriving (Eq, Data, Typeable, Show)
-- | Extract the payload of an 'Exact' approximation.
--
-- Partial: calls 'error' when applied to a 'Bound'.
fromExact :: Approximation a -> a
fromExact approx = case approx of
  Exact value -> value
  _           -> error "Exception: fromExact on a non-exact result"
-- | Build a 'Bound' whose first field is 'Nothing' and whose second field
-- holds the given value.
upperBound :: a -> Approximation a
upperBound = Bound Nothing . Just
-- | Build a 'Bound' whose first field holds the given value and whose
-- second field is 'Nothing'.
lowerBound :: a -> Approximation a
lowerBound = (`Bound` Nothing) . Just
-- | Map over the exact value, or over whichever bounds are present.
instance Functor Approximation where
  fmap f approx = case approx of
    Exact x     -> Exact (f x)
    Bound lo hi -> Bound (f <$> lo) (f <$> hi)
-- | The constant 100.
--
-- NOTE(review): presumably a sentinel used to represent an absolute
-- (non-relative) index; its use sites are outside this module - confirm.
absoluteRep :: Int
absoluteRep = 100
-- | Association list from a name to a 'RegionSum'.
type RegionEnv = [(String, RegionSum)]
-- | Association list pairing a group of names with the 'Specification'
-- they share.
type SpecDecls = [([String], Specification)]
-- | Render specification declarations, one per line, in the form
-- @\<spec\> :: name1,name2,...@.  Every line, including the last, is
-- terminated by a newline.
pprintSpecDecls :: SpecDecls -> String
pprintSpecDecls decls = concat
  [ show spec ++ " :: " ++ intercalate "," names ++ "\n"
  | (names, spec) <- decls
  ]
-- | Collect, in order, the values of every entry whose key list contains
-- the given key.
lookupAggregate :: Eq a => [([a], b)] -> a -> [b]
lookupAggregate entries name =
  [ spec | (names, spec) <- entries, name `elem` names ]
-- | A stencil specification: a (possibly approximate) 'Spatial'
-- description tagged with a 'Multiplicity'.
data Specification =
  Specification (Multiplicity (Approximation Spatial))
    deriving (Eq, Data, Typeable)
-- | A specification is empty when the approximation inside its
-- multiplicity wrapper satisfies 'isUnit'.
isEmpty :: Specification -> Bool
isEmpty spec = case spec of
  Specification mult -> isUnit (fromMult mult)
-- | The spatial part of a specification: a wrapper around a 'RegionSum'.
data Spatial = Spatial RegionSum
  deriving (Eq, Data, Typeable)
-- | Convert a flag to a 'Linearity': 'True' is 'NonLinear', 'False' is
-- 'Linear'.
fromBool :: Bool -> Linearity
fromBool isNonLinear
  | isNonLinear = NonLinear
  | otherwise   = Linear
-- | Return the list with duplicates removed (first occurrences kept),
-- together with a flag that is 'True' when the input contained duplicates.
hasDuplicates :: Eq a => [a] -> ([a], Bool)
hasDuplicates xs =
  let uniques = nub xs
  in (uniques, uniques /= xs)
-- | Discard the multiplicity tag and return the wrapped value.
fromMult :: Multiplicity a -> a
fromMult mult = case mult of
  Multiple payload -> payload
  Single payload   -> payload
-- | Re-tag a specification: 'Linear' wraps the payload in 'Single',
-- 'NonLinear' wraps it in 'Multiple'.  The payload itself is unchanged.
setLinearity :: Linearity -> Specification -> Specification
setLinearity l (Specification mult) =
  case l of
    Linear    -> Specification (Single payload)
    NonLinear -> Specification (Multiple payload)
  where
    payload = fromMult mult
-- | Linearity tag.  'setLinearity' maps 'Linear' to 'Single' and
-- 'NonLinear' to 'Multiple'.
data Linearity = Linear | NonLinear deriving (Eq, Data, Typeable)
-- | A value tagged as 'Multiple' or 'Single'.  When a specification is
-- shown, the 'Single' tag is rendered as @readOnce@ and 'Multiple' adds
-- no modifier.
data Multiplicity a = Multiple a | Single a
    deriving (Eq, Data, Typeable, Functor, Show)
-- | Dimension field of a 'Region' (shown as @dim=@).
type Dimension  = Int 
-- | Depth field of a 'Region' (shown as @depth=@).
type Depth      = Int
-- | Reflexivity flag of a 'Region'; a 'False' flag is shown as
-- @irreflexive@.
type IsRefl     = Bool
-- | A single stencil region: a kind ('Forward', 'Backward' or 'Centered')
-- with a depth, a dimension and a reflexivity flag.
--
-- A 'Centered' region of depth 0 is shown as @reflexive(dim=...)@.
data Region where
    Forward  :: Depth -> Dimension -> IsRefl -> Region
    Backward :: Depth -> Dimension -> IsRefl -> Region
    Centered :: Depth -> Dimension -> IsRefl -> Region
  deriving (Eq, Data, Typeable)
-- | Project the dimension out of a region, whatever its kind.
getDimension :: Region -> Dimension
getDimension region = case region of
  Forward  _ d _ -> d
  Backward _ d _ -> d
  Centered _ d _ -> d
-- | Total order on regions.
--
-- Regions are ordered first by kind ('Forward' < 'Backward' < 'Centered'),
-- then by depth, then by dimension.  The reflexivity flag is the final
-- tie-breaker so that the order agrees with the derived 'Eq' instance:
-- two regions compare as 'EQ' exactly when they are '=='.
--
-- Without that tie-breaker, two regions differing only in reflexivity
-- would each be @<=@ the other while not being equal, so 'compare' would
-- answer 'LT' in both directions and the output of 'sort' would depend on
-- the order of its input.
instance Ord Region where
  compare = compare `on` rank
    where
      -- Sort key: constructor index, depth, dimension, reflexivity.
      rank :: Region -> (Int, Depth, Dimension, IsRefl)
      rank (Forward  dep dim refl) = (0, dep, dim, refl)
      rank (Backward dep dim refl) = (1, dep, dim, refl)
      rank (Centered dep dim refl) = (2, dep, dim, refl)
-- | A product of regions; shown as the regions joined by @*@.
newtype RegionProd = Product {unProd :: [Region]}
  deriving (Eq, Data, Typeable)
-- | A sum of region products; shown as the products joined by @+@.
newtype RegionSum = Sum {unSum :: [RegionProd]}
  deriving (Eq, Data, Typeable)
-- | Products are ordered by the list order of their underlying regions.
instance Ord RegionProd where
   lhs <= rhs = unProd lhs <= unProd rhs
-- | Try to combine two regions into a single region.
--
-- A 'Forward' and a 'Backward' region (in either order) with equal depth
-- and dimension combine into a 'Centered' region that is reflexive if
-- either input is.  Two equal regions combine into themselves.  Any other
-- pair yields 'Nothing'.
regionPlus :: Region -> Region -> Maybe Region
regionPlus x y = case (x, y) of
    (Forward dep dim reflx, Backward dep' dim' reflx')
      | dep == dep' && dim == dim' -> centred dep dim reflx reflx'
    (Backward dep dim reflx, Forward dep' dim' reflx')
      | dep == dep' && dim == dim' -> centred dep dim reflx reflx'
    _ | x == y    -> Just x
      | otherwise -> Nothing
  where
    centred dep dim r r' = Just (Centered dep dim (r || r'))
-- | Partial combination of region products.
--
-- The empty product is the identity.  Two singleton products combine via
-- 'regionPlus'.  Otherwise equal products combine to themselves, and
-- unequal ones are merged by first trying 'absorbReflexive' and then
-- 'distAndOverlaps'; the merged region list is sorted.
instance PartialMonoid RegionProd where
   emptyM = Product []

   appendM (Product []) rhs = Just rhs
   appendM lhs (Product []) = Just lhs
   appendM (Product [r]) (Product [r']) =
       fmap (\combined -> Product [combined]) (regionPlus r r')
   appendM (Product rs) (Product rs')
       | rs == rs' = Just (Product rs)
       | otherwise = fmap (Product . sort) merged
     where
       merged = case absorbReflexive rs rs' of
         -- After absorption, a failed overlap merge falls back to plain
         -- concatenation of the two absorbed lists.
         Just (rs0, rs1) ->
           Just (fromMaybe (rs0 ++ rs1) (distAndOverlaps rs0 rs1))
         -- Without absorption, a failed overlap merge is a failure.
         Nothing -> distAndOverlaps rs rs'
-- | Sort both region lists by dimension and try 'absorbReflexive'' on
-- them, first in the given order and, if that fails, with the arguments
-- swapped.
absorbReflexive :: [Region] -> [Region] -> Maybe ([Region], [Region])
absorbReflexive a b =
    absorbReflexive' sortedA sortedB <|> absorbReflexive' sortedB sortedA
  where
    sortedA     = byDimension a
    sortedB     = byDimension b
    byDimension = sortBy (compare `on` getDimension)
-- | One-directional helper for 'absorbReflexive'.
--
-- When the second list is exactly one depth-0 'Centered' region whose
-- dimension equals that of the head of the first list, the head is marked
-- reflexive and the second list is consumed.  A 'Centered' head is only
-- accepted when its own depth is non-zero.  Two empty lists succeed
-- trivially; everything else fails.
absorbReflexive' :: [Region] -> [Region] -> Maybe ([Region], [Region])
absorbReflexive' [] [] = Just ([], [])
absorbReflexive' (r : rest) [Centered 0 dim' _] = case r of
    Forward d dim _  | dim == dim'           -> absorbed (Forward d dim True)
    Backward d dim _ | dim == dim'           -> absorbed (Backward d dim True)
    Centered d dim _ | dim == dim' && d /= 0 -> absorbed (Centered d dim True)
    _                                        -> Nothing
  where
    absorbed r' = Just (r' : rest, [])
absorbReflexive' _ _ = Nothing
-- | Symmetric entry point for merging two region lists.
--
-- Fails outright when either list has at most one element; otherwise
-- tries 'distAndOverlaps'' in the given argument order and then with the
-- arguments swapped.
distAndOverlaps :: [Region] -> [Region] -> Maybe [Region]
distAndOverlaps x y
  | tooShort x || tooShort y = Nothing
  | otherwise                = distAndOverlaps' x y <|> distAndOverlaps' y x
  where
    tooShort rs = length rs <= 1
-- | One-directional helper for 'distAndOverlaps'.
--
-- Tries to merge two region lists by inspecting their leading region(s).
-- Clauses are tried top to bottom; a clause whose guard fails falls
-- through to the following clauses, and the final clause returns
-- 'Nothing' when nothing applies.

-- An empty list on either side leaves the other list unchanged.
distAndOverlaps' [] xs = Just xs
distAndOverlaps' xs [] = Just xs
-- Heads of the same kind and dimension with identical tails: keep the
-- larger depth and OR the reflexivity flags.  For 'Centered' heads both
-- depths must be non-zero.
distAndOverlaps' (Forward d dim refl : rs) (Forward d' dim' refl' : rs')
  | rs == rs' && dim == dim'
      = Just (Forward (max d d') dim (refl || refl') : rs)
distAndOverlaps' (Backward d dim refl : rs) (Backward d' dim' refl' : rs')
  | rs == rs' && dim == dim'
      = Just (Backward (max d d') dim (refl || refl') : rs)
distAndOverlaps' (Centered d dim refl : rs) (Centered d' dim' refl' : rs')
  | rs == rs' && dim == dim' && d /= 0 && d' /= 0
      = Just (Centered (max d d') dim (refl || refl') : rs)
-- A 'Forward' or 'Backward' head against a non-zero-depth 'Centered' head
-- of at least the same depth, same dimension, identical tails: the
-- 'Centered' region is kept.
distAndOverlaps' (Forward d dim refl : rs) (Centered d' dim' refl' : rs')
  | rs == rs' && dim == dim' && d <= d' && d' /= 0
      = Just (Centered d' dim (refl || refl') : rs)
distAndOverlaps' (Backward d dim refl : rs) (Centered d' dim' refl' : rs')
  | rs == rs' && dim == dim' && d <= d' && d' /= 0
      = Just (Centered d' dim (refl || refl') : rs)
-- 'Forward' and 'Backward' heads of equal depth and dimension with
-- identical tails become a single 'Centered' head.
distAndOverlaps' (Forward d dim reflx : rs) (Backward d' dim' reflx' : rs')
    | rs == rs' && d == d' && dim == dim'
      = Just (Centered d dim (reflx || reflx') : rs)
-- A reflexive depth-0 'Centered' head on the second list, in the same
-- dimension and with identical tails, marks the first list's head
-- reflexive.
distAndOverlaps' (Centered d dim reflx : rs) (Centered 0 dim' True : rs')
    | rs == rs' && dim == dim' && d /= 0
      = Just (Centered d dim True : rs)
distAndOverlaps' (Forward d dim reflx : rs) (Centered 0 dim' True : rs')
    | rs == rs' && dim == dim'
      = Just (Forward d dim True : rs)
distAndOverlaps' (Backward d dim reflx : rs) (Centered 0 dim' True : rs')
    | rs == rs' && dim == dim'
      = Just (Backward d dim True : rs)
-- Two leading regions of the same kind on both sides whose dimensions and
-- depths agree (position-wise in the first guard, swapped in the second)
-- and whose reflexivity flags are complementary: both regions become
-- reflexive.
-- NOTE(review): the non-zero-depth check (d1 * d2 * d1' * d2' /= 0) is
-- present on both 'Forward' guards, only on the first 'Centered' guard,
-- and on neither 'Backward' guard - confirm whether that is intentional.
distAndOverlaps' p1@(Backward d1 dim1 refl1 : Backward d2 dim2 refl2 : rs)
                 p2@(Backward d1' dim1' refl1' : Backward d2' dim2' refl2' : rs')
    | rs == rs' && dim1 == dim1' && dim2 == dim2'
      && d1 == d1' && d2 == d2' && refl1 == not refl1' && refl2 == not refl2'
      = Just $ [Backward d1 dim1 True, Backward d2 dim2 True] ++ rs
    | rs == rs' && dim1 == dim2' && dim2 == dim1'
      && d1 == d2' && d2 == d1' && refl1 == not refl2' && refl2 == not refl1'
      = Just $ [Backward d1 dim1 True, Backward d2 dim2 True] ++ rs
distAndOverlaps' p1@(Centered d1 dim1 refl1 : Centered d2 dim2 refl2 : rs)
                 p2@(Centered d1' dim1' refl1' : Centered d2' dim2' refl2' : rs')
    | rs == rs' && dim1 == dim1' && dim2 == dim2' && (d1 * d2 * d1' * d2' /= 0)
      && d1 == d1' && d2 == d2' && refl1 == not refl1' && refl2 == not refl2'
      = Just $ [Centered d1 dim1 True, Centered d2 dim2 True] ++ rs
    | rs == rs' && dim1 == dim2' && dim2 == dim1'
      && d1 == d2' && d2 == d1' && refl1 == not refl2' && refl2 == not refl1'
      = Just $ [Centered d1 dim1 True, Centered d2 dim2 True] ++ rs
distAndOverlaps' p1@(Forward d1 dim1 refl1 : Forward d2 dim2 refl2 : rs)
                 p2@(Forward d1' dim1' refl1' : Forward d2' dim2' refl2' : rs')
    | rs == rs' && dim1 == dim1' && dim2 == dim2' && (d1 * d2 * d1' * d2' /= 0)
      && d1 == d1' && d2 == d2' && refl1 == not refl1' && refl2 == not refl2'
      = Just $ [Forward d1 dim1 True, Forward d2 dim2 True] ++ rs
    | rs == rs' && dim1 == dim2' && dim2 == dim1' && (d1 * d2 * d1' * d2' /= 0)
      && d1 == d2' && d2 == d1' && refl1 == not refl2' && refl2 == not refl1'
      = Just $ [Forward d1 dim1 True, Forward d2 dim2 True] ++ rs
-- Identical heads: keep the head and merge the tails.  The recursion goes
-- through 'distAndOverlaps' (not this helper), so it fails when either
-- tail has at most one element.
-- NOTE(review): confirm that recursing via the length-checked wrapper is
-- intended.
distAndOverlaps' (r:rs) (r':rs')
    | r == r'   = do rs'' <- distAndOverlaps rs rs'
                     return $ r : rs''
-- No clause applied.
distAndOverlaps' _ _ = Nothing
-- | Rig-like interface with two binary operations and two distinguished
-- elements.
class RegionRig t where
  -- | Additive combination.
  sum  :: t -> t -> t
  -- | Multiplicative combination.
  prod :: t -> t -> t
  -- | Distinguished element 'one'.
  one  :: t
  -- | Distinguished element 'zero'.
  zero :: t
  -- | Unit test; for 'RegionSum' it holds for both 'zero' and 'one'.
  isUnit :: t -> Bool
-- | Lift a 'RegionRig' through 'Maybe'.  'Nothing' is neutral for both
-- 'sum' and 'prod' and counts as a unit.
instance RegionRig a => RegionRig (Maybe a) where
  sum l r = case (l, r) of
    (Just x, Just y) -> Just (sum x y)
    (_, Nothing)     -> l
    (Nothing, _)     -> r
  prod l r = case (l, r) of
    (Just x, Just y) -> Just (prod x y)
    (_, Nothing)     -> l
    (Nothing, _)     -> r
  one  = pure one
  zero = pure zero
  isUnit = maybe True isUnit
-- | 'Spatial' delegates every operation to the wrapped 'RegionSum'.
instance RegionRig Spatial where
  sum l r = case (l, r) of
    (Spatial s, Spatial s') -> Spatial (s `sum` s')
  prod l r = case (l, r) of
    (Spatial s, Spatial s') -> Spatial (s `prod` s')
  one  = Spatial one
  zero = Spatial zero
  isUnit spatial = case spatial of
    Spatial ss -> isUnit ss
-- | Combine approximations component-wise.  An 'Exact' value combined
-- with a 'Bound' is applied to each of the two bounds.  A 'Bound' on the
-- left of an 'Exact' is handled by swapping the arguments.
instance RegionRig (Approximation Spatial) where
  sum a b = case (a, b) of
    (Exact s, Exact s')      -> Exact (s `sum` s')
    (Exact s, Bound l u)     -> Bound (Just s `sum` l) (Just s `sum` u)
    (Bound l u, Bound l' u') -> Bound (l `sum` l') (u `sum` u')
    (Bound _ _, Exact _)     -> sum b a
  prod a b = case (a, b) of
    (Exact s, Exact s')      -> Exact (s `prod` s')
    (Exact s, Bound l u)     -> Bound (Just s `prod` l) (Just s `prod` u)
    (Bound l u, Bound l' u') -> Bound (l `prod` l') (u `prod` u')
    (Bound _ _, Exact _)     -> prod b a
  one  = Exact one
  zero = Exact zero
  isUnit approx = case approx of
    Exact s   -> isUnit s
    Bound x y -> isUnit x && isUnit y
-- | 'prod' pairs every product on the left with every product on the
-- right, concatenating, sorting and de-duplicating their regions, then
-- removes duplicate products.  'sum' concatenates the two sums and
-- passes the result through 'normalise'.
instance RegionRig RegionSum where
  prod (Sum ps) (Sum qs) = Sum (nub combined)
    where
      combined =
        [ Product (nub (sort (rs ++ rs')))
        | Product rs  <- ps
        , Product rs' <- qs
        ]
  sum lhs rhs = Sum (normalise (unSum lhs ++ unSum rhs))
  zero = Sum []
  one  = Sum [Product []]
  isUnit s@(Sum ps) = or [s == zero, s == one, all (== Product []) ps]
-- | Show every element and join the results with commas (no spaces).
showL :: Show a => [a] -> String
showL xs = intercalate "," [show x | x <- xs]
-- | Show every element and join the results with @*@ ('showProdSpecs')
-- or @+@ ('showSumSpecs').
showProdSpecs, showSumSpecs :: Show a => [a] -> String
showProdSpecs xs = intercalate "*" [show x | x <- xs]
showSumSpecs  xs = intercalate "+" [show x | x <- xs]
-- | A specification is shown as @stencil @ followed by its body.
instance Show Specification where
  show spec = case spec of
    Specification body -> "stencil " ++ show body
-- | Pretty-printer for the body of a specification.
--
-- A 'Single' multiplicity adds the @readOnce, @ modifier; 'Multiple' adds
-- nothing.  Bounds are introduced by @atLeast, @ / @atMost, @, and a
-- 'Bound' with neither side present is shown as @empty@.
--
-- This instance is more specific than the derived @Show (Multiplicity a)@,
-- so it is marked @OVERLAPS@; without the pragma, uses of 'show' at this
-- type are rejected as ambiguous.
instance {-# OVERLAPS #-} Show (Multiplicity (Approximation Spatial)) where
  show mult = case mult of
      Multiple appr -> apprStr "" appr
      Single appr   -> apprStr "readOnce, " appr
    where
      -- Render the approximation, placing the linearity modifier directly
      -- before each spatial description.
      apprStr linearity appr =
        case appr of
          Exact s -> linearity ++ show s
          Bound Nothing Nothing -> "empty"
          Bound Nothing (Just s) -> "atMost, " ++ linearity ++ show s
          Bound (Just s) Nothing -> "atLeast, " ++ linearity ++ show s
          Bound (Just sL) (Just sU) ->
            "atLeast, " ++ linearity ++ show sL ++
            "; atMost, " ++ linearity ++ show sU
-- | Pretty-printer for approximate spatial descriptions.
--
-- Bounds are introduced by @atLeast, @ / @atMost, @, and a 'Bound' with
-- neither side present is shown as @empty@.
--
-- This instance is more specific than the derived @Show (Approximation a)@,
-- so it is marked @OVERLAPS@; without the pragma, uses of 'show' at this
-- type are rejected as ambiguous.
instance {-# OVERLAPS #-} Show (Approximation Spatial) where
  show (Exact s) = show s
  show (Bound Nothing Nothing) = "empty"
  show (Bound Nothing (Just s)) = "atMost, " ++ show s
  show (Bound (Just s) Nothing) = "atLeast, " ++ show s
  show (Bound (Just sL) (Just sU)) =
      "atLeast, " ++ show sL ++ "; atMost, " ++ show sU
-- | Show the wrapped region sum, except that a sum rendered as @empty@
-- becomes the empty string.
instance Show Spatial where
  show (Spatial region)
    | rendered == "empty" = ""
    | otherwise           = rendered
    where
      rendered = show region
-- | A sum with no products, or with only the empty product, is shown as
-- @empty@.  Otherwise the non-empty renderings of its products are
-- joined by @ + @.
instance Show RegionSum where
    show (Sum prods) = case prods of
      []           -> "empty"
      [Product []] -> "empty"
      _            ->
        intercalate " + " [s | s <- map show prods, not (null s)]
-- | Each region is shown in parentheses and the results are joined by
-- @*@; the empty product is therefore shown as the empty string.
instance Show RegionProd where
    show (Product regions) =
       intercalate "*" ["(" ++ show r ++ ")" | r <- regions]
-- | Regions are rendered with 'showRegion' under the names @forward@,
-- @backward@ and @centered@.  A depth-0 'Centered' region is instead
-- rendered as @reflexive(dim=N)@.
instance Show Region where
   show region = case region of
     Forward dep dim reflx  -> showRegion "forward" dep dim reflx
     Backward dep dim reflx -> showRegion "backward" dep dim reflx
     Centered 0 dim _       -> "reflexive(dim=" ++ show dim ++ ")"
     Centered dep dim reflx -> showRegion "centered" dep dim reflx
-- | Render a region as @typ(depth=D, dim=N)@, appending @, irreflexive@
-- before the closing parenthesis when the reflexivity flag is 'False'.
showRegion :: (Show a, Show b) => String -> a -> b -> Bool -> String
showRegion typ depS dimS reflx = concat
    [ typ, "(depth=", show depS, ", dim=", show dimS, reflexivity, ")" ]
  where
    reflexivity
      | reflx     = ""
      | otherwise = ", irreflexive"
-- | Merge runs of adjacent pairs that share the same value, collecting
-- their keys in order.  Pairs with equal values that are not adjacent
-- stay in separate groups.
groupKeyBy :: Eq b => [(a, b)] -> [([a], b)]
groupKeyBy pairs =
  [ (k : map fst rest, v)
  | (k, v) : rest <- groupBy ((==) `on` snd) pairs
  ]