{- 
 - 	Monadic Constraint Programming
 - 	http://www.cs.kuleuven.be/~toms/MCP/
 - 	Pieter Wuille
 -}



{-# LANGUAGE StandaloneDeriving #-}



module Data.Expr.Util (

  Expr(), BoolExpr(), ColExpr(),

  transform, colTransform, boolTransform,

  transformEx, colTransformEx, boolTransformEx,

  property, colProperty, boolProperty,

  propertyEx, colPropertyEx, boolPropertyEx,

  collapse, colCollapse, boolCollapse,

  simplify, colSimplify, boolSimplify,

  WalkPhase(..), WalkResult(..), walk, colWalk, boolWalk,

) where 



import Data.Expr.Data



-------------------------

-- | Helper functions |--

-------------------------



-- | Decide whether a relation holds between two integer constants.
--
-- The first and third arguments are the left and right operands; the
-- relation constructor selects equality, disequality or strict less-than.
relCheck :: Integer -> ExprRel -> Integer -> Bool
relCheck lhs rel rhs = case rel of
  EREqual -> lhs == rhs
  ERDiff  -> lhs /= rhs
  ERLess  -> lhs < rhs



-------------------------------------------------------------------------

-- | Transform expressions over one type to expressions over another | --

-------------------------------------------------------------------------



-- | Re-type an integer expression by mapping its integer, collection and
-- boolean terms with the first three functions of the tuple.
--
-- The last three functions are the reverse mappings; together the six are
-- lifted to term-producing functions and handed to 'transformEx'.
transform :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => (a->b,c->d,e->f,b->a,d->c,f->e) -> Expr a c e -> Expr b d f
transform (fwdInt, fwdCol, fwdBool, bwdInt, bwdCol, bwdBool) = transformEx lifted
  where
    lifted =
      ( Term . fwdInt, ColTerm . fwdCol, BoolTerm . fwdBool
      , Term . bwdInt, ColTerm . bwdCol, BoolTerm . bwdBool
      )



-- | Rebuild an integer expression, replacing every term by a whole
-- expression over the target term types.
--
-- The first three functions of the tuple substitute integer, collection and
-- boolean terms; the last three are the reverse substitutions.  The reverse
-- direction is needed for 'Fold': the arguments handed to the rebuilt fold
-- function are translated back before the original function is applied, and
-- its result is translated forward again.  Every rebuilt node is passed
-- through 'simplify'.
transformEx :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => ((a -> Expr b d f),(c -> ColExpr b d f),(e -> BoolExpr b d f),(b -> Expr a c e),(d -> ColExpr a c e),(f -> BoolExpr a c e)) -> Expr a c e -> Expr b d f
transformEx fs@(onInt, onCol, onBool, offInt, offCol, offBool) expr = case expr of
    Term v          -> onInt v
    Const n         -> Const n
    ExprHole n      -> ExprHole n
    Plus l r        -> simplify $ Plus (fwd l) (fwd r)
    Minus l r       -> simplify $ Minus (fwd l) (fwd r)
    Mult l r        -> simplify $ Mult (fwd l) (fwd r)
    Div l r         -> simplify $ Div (fwd l) (fwd r)
    Mod l r         -> simplify $ Mod (fwd l) (fwd r)
    Abs x           -> simplify $ Abs (fwd x)
    At col ix       -> simplify $ At (fwdCol col) (fwd ix)
    ColSize col     -> simplify $ ColSize $ fwdCol col
    Channel p       -> simplify $ Channel $ fwdBool p
    Cond p th el    -> simplify $ Cond (fwdBool p) (fwd th) (fwd el)
    Fold step z col -> simplify $ Fold (\acc x -> fwd (step (back acc) (back x))) (fwd z) (fwdCol col)
  where
    -- forward translation of the three expression sorts
    fwd x     = transformEx fs x
    fwdCol x  = colTransformEx fs x
    fwdBool x = boolTransformEx fs x
    -- reverse translation: the same tuple with its two halves swapped
    back x    = transformEx (offInt, offCol, offBool, onInt, onCol, onBool) x



-- | Re-type a collection expression by mapping its terms; the collection
-- counterpart of 'transform'.
--
-- The six plain functions are lifted to term-producing functions and handed
-- to 'colTransformEx'.
colTransform :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => (a->b,c->d,e->f,b->a,d->c,f->e) -> ColExpr a c e -> ColExpr b d f
colTransform (fwdInt, fwdCol, fwdBool, bwdInt, bwdCol, bwdBool) = colTransformEx lifted
  where
    lifted =
      ( Term . fwdInt, ColTerm . fwdCol, BoolTerm . fwdBool
      , Term . bwdInt, ColTerm . bwdCol, BoolTerm . bwdBool
      )



-- | Rebuild a collection expression, replacing every term by an expression
-- over the target term types.
--
-- Functions embedded in 'ColMap' and 'ColSlice' are wrapped so that their
-- argument is translated back with the reverse substitutions before the
-- original function runs, and the result is translated forward.  Every
-- rebuilt node is passed through 'colSimplify'.
colTransformEx :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => ((a -> Expr b d f),(c -> ColExpr b d f),(e -> BoolExpr b d f),(b -> Expr a c e),(d -> ColExpr a c e),f -> BoolExpr a c e) -> ColExpr a c e -> ColExpr b d f
colTransformEx fs@(onInt, onCol, onBool, offInt, offCol, offBool) col = case col of
    ColTerm c        -> onCol c
    ColList es       -> colSimplify $ ColList $ map fwd es
    ColMap m c       -> colSimplify $ ColMap (\x -> fwd (m (back x))) (fwdCol c)
    ColSlice p len c -> colSimplify $ ColSlice (\x -> fwd (p (back x))) (fwd len) (fwdCol c)
    ColCat l r       -> colSimplify $ ColCat (fwdCol l) (fwdCol r)
    ColRange lo hi   -> colSimplify $ ColRange (fwd lo) (fwd hi)
  where
    fwd x    = transformEx fs x
    fwdCol x = colTransformEx fs x
    -- reverse translation: the same tuple with its two halves swapped
    back x   = transformEx (offInt, offCol, offBool, onInt, onCol, onBool) x



-- | Re-type a boolean expression by mapping its terms; the boolean
-- counterpart of 'transform'.
--
-- The six plain functions are lifted to term-producing functions and handed
-- to 'boolTransformEx'.
boolTransform :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => (a->b,c->d,e->f,b->a,d->c,f->e) -> BoolExpr a c e -> BoolExpr b d f
boolTransform (fwdInt, fwdCol, fwdBool, bwdInt, bwdCol, bwdBool) = boolTransformEx lifted
  where
    lifted =
      ( Term . fwdInt, ColTerm . fwdCol, BoolTerm . fwdBool
      , Term . bwdInt, ColTerm . bwdCol, BoolTerm . bwdBool
      )



-- | Rebuild a boolean expression, replacing every term by an expression
-- over the target term types.
--
-- Predicates embedded in 'BoolAll' and 'BoolAny' are wrapped so that their
-- argument is translated back with the reverse substitutions before the
-- original predicate runs.  Every rebuilt node is passed through
-- 'boolSimplify'.
boolTransformEx :: (Eq a, Eq b, Eq c, Eq d, Eq e, Eq f) => ((a -> Expr b d f),(c -> ColExpr b d f),(e -> BoolExpr b d f),(b -> Expr a c e),(d -> ColExpr a c e),f -> BoolExpr a c e) -> BoolExpr a c e -> BoolExpr b d f
boolTransformEx fs@(onInt, onCol, onBool, offInt, offCol, offBool) expr = case expr of
    BoolTerm v       -> onBool v
    BoolConst k      -> BoolConst k
    BoolAnd l r      -> boolSimplify $ BoolAnd (fwdBool l) (fwdBool r)
    BoolOr l r       -> boolSimplify $ BoolOr (fwdBool l) (fwdBool r)
    BoolEqual l r    -> boolSimplify $ BoolEqual (fwdBool l) (fwdBool r)
    BoolNot p        -> boolSimplify $ BoolNot (fwdBool p)
    Rel l rel r      -> boolSimplify $ Rel (fwd l) rel (fwd r)
    BoolAll m c      -> boolSimplify $ BoolAll (\x -> fwdBool (m (back x))) (fwdCol c)
    BoolAny m c      -> boolSimplify $ BoolAny (\x -> fwdBool (m (back x))) (fwdCol c)
    ColEqual l r     -> boolSimplify $ ColEqual (fwdCol l) (fwdCol r)
    Sorted flag c    -> boolSimplify $ Sorted flag (fwdCol c)
    AllDiff flag c   -> boolSimplify $ AllDiff flag (fwdCol c)
    BoolCond p th el -> boolSimplify $ BoolCond (fwdBool p) (fwdBool th) (fwdBool el)
    Dom i c          -> boolSimplify $ Dom (fwd i) (fwdCol c)
  where
    fwd x     = transformEx fs x
    fwdCol x  = colTransformEx fs x
    fwdBool x = boolTransformEx fs x
    -- reverse translation: the same tuple with its two halves swapped
    back x    = transformEx (offInt, offCol, offBool, onInt, onCol, onBool) x



------------------------------------------------------------------------------------------

-- | Check whether an expression is possibly referring to terms with a given property | --

------------------------------------------------------------------------------------------



-- | Test whether an integer expression may contain a node with a given
-- property.
--
-- The integer probe of the tuple is consulted on the node first: a 'Just'
-- answer is final.  On 'Nothing' the children are inspected and the results
-- are or-ed together.  A 'Fold' is answered with 'True' without looking
-- inside; any node kind not listed answers 'False'.
propertyEx :: (Expr a b c -> Maybe Bool, ColExpr a b c -> Maybe Bool, BoolExpr a b c -> Maybe Bool) -> Expr a b c -> Bool
propertyEx probes@(intProbe, _, _) expr = maybe (descend expr) id (intProbe expr)
  where
    inInt  = propertyEx probes
    inCol  = colPropertyEx probes
    inBool = boolPropertyEx probes
    descend node = case node of
      Plus l r     -> inInt l || inInt r
      Minus l r    -> inInt l || inInt r
      Mult l r     -> inInt l || inInt r
      Div l r      -> inInt l || inInt r
      Mod l r      -> inInt l || inInt r
      Abs x        -> inInt x
      At col ix    -> inInt ix || inCol col
      ColSize col  -> inCol col
      Fold _ _ _   -> True
      Channel p    -> inBool p
      Cond p th el -> inBool p || inInt th || inInt el
      _            -> False



-- | Test whether a collection expression may contain a node with a given
-- property.
--
-- The collection probe is consulted first; on 'Nothing' the children are
-- inspected.  A 'ColMap' is answered with 'True' without looking inside.
-- For 'ColSlice' the index function is probed by applying it to
-- @ExprHole (-1)@.  Any node kind not listed answers 'False'.
colPropertyEx :: (Expr a b c -> Maybe Bool, ColExpr a b c -> Maybe Bool, BoolExpr a b c -> Maybe Bool) -> ColExpr a b c -> Bool
colPropertyEx probes@(_, colProbe, _) col = maybe (descend col) id (colProbe col)
  where
    inInt = propertyEx probes
    inCol = colPropertyEx probes
    descend node = case node of
      ColList es       -> any inInt es
      ColMap _ _       -> True
      ColSlice p len c -> inInt (p (ExprHole (-1))) || inInt len || inCol c
      ColRange lo hi   -> inInt lo || inInt hi
      ColCat l r       -> inCol l || inCol r
      _                -> False



-- | Test whether a boolean expression may contain a node with a given
-- property.
--
-- The boolean probe is consulted first; on 'Nothing' the children are
-- inspected and the results or-ed together.  'BoolAll' and 'BoolAny' are
-- answered with 'True' without looking inside; any node kind not listed
-- answers 'False'.
boolPropertyEx :: (Expr a b c -> Maybe Bool, ColExpr a b c -> Maybe Bool, BoolExpr a b c -> Maybe Bool) -> BoolExpr a b c -> Bool
boolPropertyEx probes@(_, _, boolProbe) expr = maybe (descend expr) id (boolProbe expr)
  where
    inInt  = propertyEx probes
    inCol  = colPropertyEx probes
    inBool = boolPropertyEx probes
    descend node = case node of
      BoolAnd l r      -> inBool l || inBool r
      BoolOr l r       -> inBool l || inBool r
      BoolNot p        -> inBool p
      BoolEqual l r    -> inBool l || inBool r
      Rel l _ r        -> inInt l || inInt r
      BoolAll _ _      -> True
      BoolAny _ _      -> True
      ColEqual l r     -> inCol l || inCol r
      AllDiff _ c      -> inCol c
      Sorted _ c       -> inCol c
      BoolCond p th el -> inBool p || inBool th || inBool el
      Dom i c          -> inInt i || inCol c
      _                -> False





-- | Test whether an integer expression may refer to a term satisfying one
-- of the given predicates (integer, collection and boolean terms
-- respectively).  Built on 'propertyEx' with term-only probes.
property :: (a -> Bool) -> (b -> Bool) -> (c -> Bool) -> Expr a b c -> Bool
property fit fct fbt = propertyEx probes
  where
    probes = (propInt fit, propCol fct, propBool fbt)

-- | Test whether a collection expression may refer to a term satisfying one
-- of the given predicates (integer, collection and boolean terms
-- respectively).  Built on 'colPropertyEx' with term-only probes.
colProperty :: (a -> Bool) -> (b -> Bool) -> (c -> Bool) -> ColExpr a b c -> Bool
colProperty fit fct fbt = colPropertyEx probes
  where
    probes = (propInt fit, propCol fct, propBool fbt)

-- | Test whether a boolean expression may refer to a term satisfying one of
-- the given predicates (integer, collection and boolean terms
-- respectively).  Built on 'boolPropertyEx' with term-only probes.
boolProperty :: (a -> Bool) -> (b -> Bool) -> (c -> Bool) -> BoolExpr a b c -> Bool
boolProperty fit fct fbt = boolPropertyEx probes
  where
    probes = (propInt fit, propCol fct, propBool fbt)



-- | Turn a predicate on integer terms into a probe for 'propertyEx':
-- a 'Term' is answered by the predicate, every other node defers
-- ('Nothing') to the structural traversal.
propInt :: (a -> Bool) -> Expr a b c -> Maybe Bool
propInt ft (Term x) = Just (ft x)
propInt _  _        = Nothing



-- | Turn a predicate on collection terms into a probe for 'colPropertyEx':
-- a 'ColTerm' is answered by the predicate, every other node defers
-- ('Nothing') to the structural traversal.
propCol :: (b -> Bool) -> ColExpr a b c -> Maybe Bool
propCol ft (ColTerm x) = Just (ft x)
propCol _  _           = Nothing



-- | Turn a predicate on boolean terms into a probe for 'boolPropertyEx':
-- a 'BoolTerm' is answered by the predicate, every other node defers
-- ('Nothing') to the structural traversal.
propBool :: (c -> Bool) -> BoolExpr a b c -> Maybe Bool
propBool ft (BoolTerm x) = Just (ft x)
propBool _  _            = Nothing





-------------------------------------------------------------------

-- | Count how many references to terms an expression contains | --

-------------------------------------------------------------------



-- | Count the term references occurring in an integer expression.
--
-- Each 'Term' counts as one; constants and holes count as zero.  The body
-- of a 'Fold' function is counted once by applying it to two placeholder
-- holes.
varrefs :: Expr a b c -> Int
varrefs expr = case expr of
    Term _          -> 1
    Const _         -> 0
    ExprHole _      -> 0
    Plus l r        -> both l r
    Minus l r       -> both l r
    Mult l r        -> both l r
    Div l r         -> both l r
    Mod l r         -> both l r
    Abs x           -> varrefs x
    At col ix       -> varrefs ix + colVarrefs col
    ColSize col     -> colVarrefs col
    Fold step z col -> varrefs z + colVarrefs col + varrefs (step (ExprHole 0) (ExprHole 1))
    Channel p       -> boolVarrefs p
    Cond p th el    -> boolVarrefs p + varrefs th + varrefs el
  where
    -- sum of the counts of the two operands of a binary arithmetic node
    both l r = varrefs l + varrefs r



-- | Count the term references occurring in a collection expression.
--
-- Each 'ColTerm' counts as one.  Functions embedded in 'ColMap' and
-- 'ColSlice' are counted once by applying them to a placeholder hole.
colVarrefs :: ColExpr a b c -> Int
colVarrefs col = case col of
  ColTerm _        -> 1
  ColList es       -> sum (map varrefs es)
  ColMap m c       -> colVarrefs c + varrefs (m (ExprHole 0))
  ColSlice p len c -> varrefs (p (ExprHole 0)) + varrefs len + colVarrefs c
  ColCat l r       -> colVarrefs l + colVarrefs r
  ColRange lo hi   -> varrefs lo + varrefs hi



-- | Count the term references occurring in a boolean expression.
--
-- Each 'BoolTerm' counts as one and boolean constants as zero.  Predicates
-- embedded in 'BoolAll' and 'BoolAny' are counted once by applying them to
-- a placeholder hole.
boolVarrefs :: BoolExpr a b c -> Int
boolVarrefs expr = case expr of
    BoolTerm _       -> 1
    BoolConst _      -> 0
    BoolAnd l r      -> both l r
    BoolOr l r       -> both l r
    BoolEqual l r    -> both l r
    BoolNot p        -> boolVarrefs p
    BoolAll m c      -> boolVarrefs (m $ ExprHole 0) + colVarrefs c
    BoolAny m c      -> boolVarrefs (m $ ExprHole 0) + colVarrefs c
    Rel l _ r        -> varrefs l + varrefs r
    ColEqual l r     -> colVarrefs l + colVarrefs r
    Sorted _ c       -> colVarrefs c
    AllDiff _ c      -> colVarrefs c
    BoolCond p th el -> boolVarrefs p + boolVarrefs th + boolVarrefs el
    Dom i c          -> varrefs i + colVarrefs c
  where
    -- sum of the counts of the two operands of a binary boolean connective
    both l r = boolVarrefs l + boolVarrefs r



------------------------------

-- | Simplify expressions | --

------------------------------



-- | Simplify and normalise the root of an integer expression.
--
-- The equations are tried top to bottom and the first match wins, so their
-- ORDER IS SIGNIFICANT: later rules rely on earlier ones having filtered out
-- the cases they would mishandle.  Only the root node (and whatever a rule
-- explicitly rebuilds) is rewritten; operands are not traversed, so callers
-- that want a fully simplified tree apply this bottom-up while building it.
--
-- Rules that fold constants with 'div' / 'mod', and the rule that indexes a
-- literal list, are guarded: a zero divisor or an out-of-range index leaves
-- the expression unchanged instead of producing a value that raises when
-- forced.
simplify :: (Eq s, Eq c, Eq b) => Expr s c b -> Expr s c b
-- dropout rules (things that won't ever be changed)
simplify a@(Const _) = a
simplify a@(Term _) = a
simplify a@(ExprHole _) = a
-- simplification rules (either decrease # of variable references, or leave that equal and decrease # of tree nodes)
--- level 0 (result in a final expression)
simplify (Mult a@(Const 0) _) = a
simplify (Div a@(Const 0) _) = a
simplify (Mod a@(Const 0) _) = a
simplify (Mod _ (Const 1)) = Const 0
simplify (Mod _ (Const (-1))) = Const 0
-- c /= 0 is checked first so that `mod` is never evaluated with a zero divisor
simplify (Mod (Mult (Const a) b) (Const c)) | c /= 0 && (a `mod` c)==0 = Const 0
simplify (Minus a b) | a==b = Const 0
simplify (Plus (Const a) (Const b)) = Const (a+b)
simplify (Minus (Const a) (Const b)) = Const (a-b)
simplify (Mult (Const a) (Const b)) = Const (a*b)
-- division by a constant zero is left unsimplified rather than folded
simplify (Div (Const a) (Const b)) | b /= 0 = Const $ (a `div` b)
simplify (Abs (Const a)) = Const (abs a)
-- modulo by a constant zero is left unsimplified rather than folded
simplify (Mod (Const a) (Const b)) | b /= 0 = Const $ (a `mod` b)
simplify (Plus (Const 0) a) = a
simplify (Mult (Const 1) a) = a
simplify (Div a (Const 1)) = a
-- total replacement for (!!): a negative or too-large index does not match,
-- and no later At rule matches a ColList, so the fallback keeps the node
simplify (At (ColList l) (Const c))
  | c >= 0, (e:_) <- drop (fromInteger c) l = e
simplify (ColSize (ColList l)) = Const $ toInteger $ length l
simplify (ColSize (ColSlice _ l _)) = l
simplify (Channel (BoolConst False)) = Const 0
simplify (Channel (BoolConst True)) = Const 1
simplify (Cond (BoolConst True) t _) = t
simplify (Cond (BoolConst False) _ f) = f
--- level 1 (result in one recursive call to simplify)
simplify (Plus a b) | a==b = simplify $ Mult (Const 2) a
simplify (Div a (Const (-1))) = simplify $ Minus (Const 0) a
simplify (Plus (Const c) (Plus (Const a) b)) = simplify $ Plus (Const $ c+a) b
simplify (Plus (Const c) (Minus (Const a) b)) = simplify $ Minus (Const $ c+a) b
simplify (Minus (Const c) (Plus (Const a) b)) = simplify $ Minus (Const $ c-a) b
simplify (Minus (Const c) (Minus (Const a) b)) = simplify $ Plus (Const $ c-a) b
simplify (Mult (Const c) (Mult (Const a) b)) = simplify $ Mult (Const $ a*c) b
-- c /= 0 is checked first so that `mod` / `div` never see a zero divisor
simplify (Div (Mult (Const a) b) (Const c)) | c /= 0 && (a `mod` c)==0 = simplify $ Mult (Const (a `div` c)) b
simplify (ColSize (ColMap _ c)) = simplify $ ColSize c
simplify (Fold f1 i (ColMap f2 c)) = simplify $ Fold (\a b -> f1 a (f2 b)) i c
simplify (At (ColRange l h) p) = simplify $ Plus l p
simplify (Cond (BoolNot c) t f) = simplify $ Cond c f t
--- level 2 (result in two recursive calls to simplify)
simplify (Plus a (Mult b c)) | a==b && ((varrefs a)>0) = simplify $ Mult (simplify $ Plus c (Const 1)) a
simplify (Plus a (Mult b c)) | a==c && ((varrefs a)>0) = simplify $ Mult (simplify $ Plus b (Const 1)) a
simplify (Plus (Mult b c) a) | a==b && ((varrefs a)>0) = simplify $ Mult (simplify $ Plus c (Const 1)) a
simplify (Plus (Mult b c) a) | a==c && ((varrefs a)>0) = simplify $ Mult (simplify $ Plus b (Const 1)) a
simplify (Plus (Mult a b) (Mult c d)) | a==c = simplify $ Mult (simplify $ Plus b d) a
simplify (Plus (Mult a b) (Mult c d)) | a==d = simplify $ Mult (simplify $ Plus b c) a
simplify (Plus (Mult a b) (Mult c d)) | b==c = simplify $ Mult (simplify $ Plus a d) b
simplify (Plus (Mult a b) (Mult c d)) | b==d = simplify $ Mult (simplify $ Plus a c) b
simplify (Minus a (Mult b c)) | a==b && ((varrefs a)>0) = simplify $ Mult (simplify $ Minus (Const 1) c) a
simplify (Minus a (Mult b c)) | a==c && ((varrefs a)>0) = simplify $ Mult (simplify $ Minus (Const 1) b) a
simplify (Minus (Mult b c) a) | a==b && ((varrefs a)>0) = simplify $ Mult (simplify $ Minus c (Const 1)) a
simplify (Minus (Mult b c) a) | a==c && ((varrefs a)>0) = simplify $ Mult (simplify $ Minus b (Const 1)) a
simplify (Minus (Mult a b) (Mult c d)) | a==c = simplify $ Mult (simplify $ Minus b d) a
simplify (Minus (Mult a b) (Mult c d)) | a==d = simplify $ Mult (simplify $ Minus b c) a
simplify (Minus (Mult a b) (Mult c d)) | b==c = simplify $ Mult (simplify $ Minus a d) b
simplify (Minus (Mult a b) (Mult c d)) | b==d = simplify $ Mult (simplify $ Minus a c) b
simplify (Mult (Abs a) (Abs b)) = simplify $ Abs (simplify $ Mult a b)
-- NOTE(review): |a| / |b| == |a / b| holds for truncating division but not
-- for the flooring `div` used by the constant-folding rule above (e.g. a=-7,
-- b=2) - confirm which semantics Div is meant to have.
simplify (Div (Abs a) (Abs b)) = simplify $ Abs (simplify $ Div a b)
simplify (ColSize (ColRange l h)) = simplify $ Plus (Const 1) $ simplify $ Minus h l
simplify (At (ColSlice f _ c) i) = simplify $ At c (f i)
simplify (At (ColMap m c) i) = simplify $ m $ simplify $ At c i
simplify t@(At (ColCat c1 c2) c@(Const p)) = case simplify (ColSize c1) of
  Const l | p<l -> simplify $ At c1 c
  Const l | p>=l -> simplify $ At c2 (Const $ p-l)
  _ -> t    {- no further (At _ _) rules may follow after this -}
--- level 3 (results in three recursive calls to simplify)
simplify (ColSize (ColCat a b)) = simplify $ Plus (simplify $ ColSize a) (simplify $ ColSize b)
-- reordering rules (do not decrease # of variables or # of tree nodes, but normalize an expression in such a way that the same normalization cannot be applied anymore - possibly because that can only occur in a case already matched by a simplification rule above)
--- level 1
simplify (Plus a (Const c)) = simplify $ Plus (Const c) a
simplify (Minus a (Const c)) = simplify $ Plus (Const (-c)) a
simplify (Mult a (Const c)) = simplify $ Mult (Const c) a
simplify (Mult (Const (-1)) a) = simplify $ Minus (Const 0) a
--- level 2
simplify (Mult t@(Const c) (Plus (Const a) b)) = simplify $ Plus (Const (a*c)) (simplify $ Mult t b)
simplify (Mult t@(Const c) (Minus (Const a) b)) = simplify $ Minus (Const (a*c)) (simplify $ Mult t b)
simplify (Plus a (Plus t@(Const b) c)) = simplify $ Plus t (simplify $ Plus a c)
simplify (Plus a (Minus t@(Const b) c)) = simplify $ Plus t (simplify $ Minus a c)
simplify (Minus a (Plus (Const b) c)) = simplify $ Plus (Const (-b)) (simplify $ Minus a c)
simplify (Minus a (Minus (Const b) c)) = simplify $ Plus (Const (-b)) (simplify $ Plus a c)
simplify (Mult a (Mult t@(Const b) c)) = simplify $ Mult t (simplify $ Mult a c)
simplify (Plus (Plus t@(Const a) b) c) = simplify $ Plus t (simplify $ Plus b c)
simplify (Plus (Minus t@(Const a) b) c) = simplify $ Plus t (simplify $ Minus c b)
simplify (Minus (Plus t@(Const a) b) c) = simplify $ Plus t (simplify $ Minus b c)
simplify (Minus (Minus t@(Const a) b) c) = simplify $ Minus t (simplify $ Plus b c)
simplify (Mult (Mult t@(Const a) b) c) = simplify $ Mult t (simplify $ Mult b c)
simplify (Mult a (Minus t@(Const 0) b)) = simplify $ Minus t (simplify $ Mult a b)
simplify (Mult (Minus t@(Const 0) b) a) = simplify $ Minus t (simplify $ Mult a b)
-- NOTE(review): pulling a negation out of a division, as the next two rules
-- do, is likewise only exact for truncating division - confirm.
simplify (Div (Minus t@(Const 0) a) b) = simplify $ Minus t (simplify $ Div a b)
simplify (Div a (Minus t@(Const 0) b)) = simplify $ Minus t (simplify $ Div a b)
-- fallback rule
simplify a = a



-- | Simplify and normalise the root of a collection expression.
--
-- The equations are tried top to bottom; only the root node (and whatever a
-- rule explicitly rebuilds) is rewritten.
colSimplify :: (Eq s, Eq c, Eq b) => ColExpr s c b -> ColExpr s c b
-- dropout rules
colSimplify t@(ColTerm _) = t
-- simplify rules
--- level 1
colSimplify (ColMap f1 (ColMap f2 c)) = colSimplify $ ColMap (f1.f2) c
colSimplify (ColMap f (ColList l)) = colSimplify $ ColList (map f l)
--- level 2
-- Element i of the outer slice is element (p1 i) of the inner slice, which
-- in turn is element (p2 (p1 i)) of c (cf. the rule
-- @At (ColSlice f _ c) i = At c (f i)@ in 'simplify').  The merged index
-- map is therefore p2 . p1; composing the other way round selects the wrong
-- elements.
colSimplify (ColSlice p1 l1 (ColSlice p2 l2 c)) = colSimplify $ ColSlice (p2 . p1) l1 c
-- reordering rules
--- level 2
colSimplify (ColCat (ColCat c1 c2) c3) = colSimplify $ ColCat c1 (colSimplify $ ColCat c2 c3)
colSimplify (ColSlice p l (ColMap f c)) = colSimplify $ ColMap f $ colSimplify $ ColSlice p l c
-- fallback rule
colSimplify x = x



-- | Simplify and normalise the root of a boolean expression.
--
-- The equations are tried top to bottom and the first match wins, so their
-- order is significant.  Only the root node (and whatever a rule explicitly
-- rebuilds) is rewritten.
boolSimplify :: (Eq s, Eq c, Eq b) => BoolExpr s c b -> BoolExpr s c b
-- dropout rules
boolSimplify t@(BoolTerm _) = t
boolSimplify t@(BoolConst _) = t
-- simplify rules
--- level 0
boolSimplify (BoolAnd (BoolConst False) _) = BoolConst False
boolSimplify (BoolAnd (BoolConst True) a) = a
boolSimplify (BoolAnd _ (BoolConst False)) = BoolConst False
boolSimplify (BoolAnd a (BoolConst True)) = a
boolSimplify (BoolOr (BoolConst True) _) = BoolConst True
boolSimplify (BoolOr (BoolConst False) a) = a
boolSimplify (BoolOr _ (BoolConst True)) = BoolConst True
boolSimplify (BoolOr a (BoolConst False)) = a
boolSimplify (BoolNot (BoolConst a)) = BoolConst (not a)
boolSimplify (BoolEqual (BoolConst True) a) = a
boolSimplify (BoolEqual a (BoolConst True)) = a
boolSimplify (BoolNot (BoolNot a)) = a
boolSimplify (BoolOr a b) | a==b = a
boolSimplify (BoolAnd a b) | a==b = a
-- an expression is always equivalent to itself
boolSimplify (BoolEqual a b) | a==b = BoolConst True
boolSimplify (Rel (Const a) r (Const b)) = BoolConst $ relCheck a r b
boolSimplify (BoolAll f (ColList [])) = BoolConst True
boolSimplify (BoolAny f (ColList [])) = BoolConst False
boolSimplify (BoolAll f (ColList [a])) = f a
boolSimplify (BoolAny f (ColList [a])) = f a
boolSimplify (ColEqual (ColList []) (ColList [])) = BoolConst True
-- the both-empty case was handled above, so exactly one side is empty here
boolSimplify (ColEqual (ColList []) (ColList _)) = BoolConst False
boolSimplify (ColEqual (ColList _) (ColList [])) = BoolConst False
boolSimplify (BoolCond (BoolConst True) t _) = t
boolSimplify (BoolCond (BoolConst False) _ f) = f
--- level 1
boolSimplify (BoolEqual (BoolNot a) (BoolNot b)) = boolSimplify $ BoolEqual a b
boolSimplify (BoolEqual (BoolConst False) a) = boolSimplify $ BoolNot a
boolSimplify (BoolEqual a (BoolConst False)) = boolSimplify $ BoolNot a
boolSimplify (BoolNot (Rel a EREqual b)) = boolSimplify $ Rel a ERDiff b
boolSimplify (BoolNot (Rel a ERDiff b)) = boolSimplify $ Rel a EREqual b
boolSimplify (BoolAll f (ColList [a,b])) = boolSimplify $ f a `BoolAnd` f b
boolSimplify (BoolAny f (ColList [a,b])) = boolSimplify $ f a `BoolOr` f b
boolSimplify (ColEqual (ColList [a]) (ColList [b])) = boolSimplify $ Rel a EREqual b
boolSimplify (Rel (Channel a) EREqual (Channel b)) = boolSimplify $ BoolEqual a b
boolSimplify (BoolCond (BoolNot c) t f) = boolSimplify $ BoolCond c f t
--- level 2
boolSimplify (BoolAnd (BoolNot a) (BoolNot b)) = boolSimplify $ BoolNot $ boolSimplify $ BoolOr a b
boolSimplify (BoolOr (BoolNot a) (BoolNot b)) = boolSimplify $ BoolNot $ boolSimplify $ BoolAnd a b
boolSimplify (Rel (Channel a) ERDiff (Channel b)) = boolSimplify $ BoolNot $ boolSimplify $ BoolEqual a b
boolSimplify (Rel (Channel a) ERLess (Channel b)) = boolSimplify $ BoolAnd b $ boolSimplify $ BoolNot a     -- int(b1) < int(b2)   <=>  !b1 && b2
-- fallback
boolSimplify a = a



-------------------------------------------------------------------

-- | Turn expressions over expressions into simply expressions | --

-------------------------------------------------------------------



-- | Flatten an integer expression whose terms are themselves expressions
-- into a plain expression, by splicing each term in place.
--
-- Every rebuilt node is passed through 'simplify'.  For 'Fold', the rebuilt
-- function wraps its arguments in 'Term' before calling the original
-- function and collapses the result.
collapse :: (Eq t, Eq c, Eq b) => Expr (Expr t c b) (ColExpr t c b) (BoolExpr t c b) -> Expr t c b
collapse (Term t) = t
collapse (Const i) = Const i
-- holes carry no terms and pass through unchanged; without this equation a
-- hole in the outer expression is a pattern-match failure
collapse (ExprHole i) = ExprHole i
collapse (Plus a b) = simplify $ Plus (collapse a) (collapse b)
collapse (Minus a b) = simplify $ Minus (collapse a) (collapse b)
collapse (Mult a b) = simplify $ Mult (collapse a) (collapse b)
collapse (Div a b) = simplify $ Div (collapse a) (collapse b)
collapse (Mod a b) = simplify $ Mod (collapse a) (collapse b)
collapse (Abs a) = simplify $ Abs (collapse a)
collapse (At c a) = simplify $ At (colCollapse c) (collapse a)
collapse (ColSize c) = simplify $ ColSize (colCollapse c)
collapse (Fold f i c) = simplify $ Fold (\a b -> collapse $ f (Term a) (Term b)) (collapse i) (colCollapse c)
collapse (Channel b) = simplify $ Channel (boolCollapse b)
collapse (Cond c t e) = simplify $ Cond (boolCollapse c) (collapse t) (collapse e)



-- | Flatten a collection expression whose terms are themselves expressions
-- into a plain collection expression.
--
-- Every rebuilt node is passed through 'colSimplify'.  Functions embedded
-- in 'ColMap' and 'ColSlice' are rebuilt so that their argument is wrapped
-- in 'Term' and their result collapsed.
colCollapse :: (Eq t, Eq c, Eq b) => ColExpr (Expr t c b) (ColExpr t c b) (BoolExpr t c b) -> ColExpr t c b
colCollapse col = case col of
  ColTerm t        -> t
  ColList es       -> colSimplify $ ColList (map collapse es)
  ColMap m c       -> colSimplify $ ColMap (collapse . m . Term) (colCollapse c)
  ColSlice p len c -> colSimplify $ ColSlice (collapse . p . Term) (collapse len) (colCollapse c)
  ColCat l r       -> colSimplify $ ColCat (colCollapse l) (colCollapse r)
  ColRange lo hi   -> colSimplify $ ColRange (collapse lo) (collapse hi)



-- | Flatten a boolean expression whose terms are themselves expressions
-- into a plain boolean expression.
--
-- Every rebuilt node is passed through 'boolSimplify'.  Predicates embedded
-- in 'BoolAll' and 'BoolAny' are rebuilt so that their argument is wrapped
-- in 'Term' and their result collapsed.
boolCollapse :: (Eq t, Eq c, Eq b) => BoolExpr (Expr t c b) (ColExpr t c b) (BoolExpr t c b) -> BoolExpr t c b
boolCollapse expr = case expr of
  BoolTerm t       -> t
  BoolConst k      -> BoolConst k
  BoolAnd l r      -> boolSimplify $ BoolAnd (boolCollapse l) (boolCollapse r)
  BoolOr l r       -> boolSimplify $ BoolOr (boolCollapse l) (boolCollapse r)
  BoolEqual l r    -> boolSimplify $ BoolEqual (boolCollapse l) (boolCollapse r)
  BoolNot p        -> boolSimplify $ BoolNot (boolCollapse p)
  Rel l rel r      -> boolSimplify $ Rel (collapse l) rel (collapse r)
  BoolAll m c      -> boolSimplify $ BoolAll (boolCollapse . m . Term) (colCollapse c)
  BoolAny m c      -> boolSimplify $ BoolAny (boolCollapse . m . Term) (colCollapse c)
  ColEqual l r     -> boolSimplify $ ColEqual (colCollapse l) (colCollapse r)
  Sorted flag c    -> boolSimplify $ Sorted flag (colCollapse c)
  AllDiff flag c   -> boolSimplify $ AllDiff flag (colCollapse c)
  BoolCond p th el -> boolSimplify $ BoolCond (boolCollapse p) (boolCollapse th) (boolCollapse el)
  Dom i c          -> boolSimplify $ Dom (collapse i) (colCollapse c)



-----------------------------------------

-- | walk through expressions

-----------------------------------------



-- | The moment at which a walk callback is invoked for a node.
data WalkPhase
  = WalkPre     -- ^ before the children of a node with children are visited
  | WalkSingle  -- ^ the only visit of a node without children
  | WalkPost    -- ^ after the children of a node have been visited
  deriving (Ord, Eq, Enum, Show)



-- | The answer of a walk callback, inspected after the 'WalkPre' visit.
data WalkResult
  = WalkSkip     -- ^ do not visit the children (and do not issue 'WalkPost')
  | WalkDescend  -- ^ visit the children, then issue 'WalkPost'
  deriving (Ord, Eq, Enum, Show)



-- | Shared driver for visiting one node.
--
-- The first argument is the callback already applied to the node itself;
-- the second is the full callback triple used for the children; the third
-- lists the node's integer, collection and boolean children.
--
-- A node without any children gets a single 'WalkSingle' call, whose result
-- is ignored.  Otherwise the callback is called with 'WalkPre'; on
-- 'WalkDescend' the children are walked (integer, then collection, then
-- boolean) and the callback is called once more with 'WalkPost', while on
-- 'WalkSkip' nothing further happens.
xwalker :: (Eq t, Eq c, Eq b, Monad m) => (WalkPhase -> m WalkResult) -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> ([Expr t c b],[ColExpr t c b],[BoolExpr t c b]) -> m ()
xwalker visit _ ([],[],[]) = visit WalkSingle >> return ()
xwalker visit fs (ints, cols, bools) = visit WalkPre >>= continue
  where
    continue WalkSkip    = return ()
    continue WalkDescend = do
      mapM_ (`walk` fs) ints
      mapM_ (`colWalk` fs) cols
      mapM_ (`boolWalk` fs) bools
      visit WalkPost >> return ()



-- | Visit an integer node: select the integer callback of the triple, bind
-- it to the node and hand over to 'xwalker' with the given children.
walker :: (Eq t, Eq c, Eq b, Monad m) => Expr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> ([Expr t c b],[ColExpr t c b],[BoolExpr t c b]) -> m ()
walker node fs@(onInt, _, _) children = xwalker (onInt node) fs children

-- | Visit a collection node: select the collection callback of the triple,
-- bind it to the node and hand over to 'xwalker' with the given children.
colWalker :: (Eq t, Eq c, Eq b, Monad m) => ColExpr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> ([Expr t c b],[ColExpr t c b],[BoolExpr t c b]) -> m ()
colWalker node fs@(_, onCol, _) children = xwalker (onCol node) fs children

-- | Visit a boolean node: select the boolean callback of the triple, bind
-- it to the node and hand over to 'xwalker' with the given children.
boolWalker :: (Eq t, Eq c, Eq b, Monad m) => BoolExpr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> ([Expr t c b],[ColExpr t c b],[BoolExpr t c b]) -> m ()
boolWalker node fs@(_, _, onBool) children = xwalker (onBool node) fs children



-- | Traverse an integer expression, invoking the callbacks of the triple on
-- every node reached.
--
-- Each constructor is mapped to its integer, collection and boolean
-- children, which are then visited through 'walker'.  Functions stored
-- inside a 'Fold' are not entered.  An 'ExprHole' is skipped silently: no
-- callback is invoked for it.
walk :: (Eq t, Eq c, Eq b, Monad m) => Expr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> m ()
walk x f = case x of
    ExprHole _ -> return ()
    Term _     -> leaf
    Const _    -> leaf
    Plus a b   -> ints [a,b]
    Minus a b  -> ints [a,b]
    Mult a b   -> ints [a,b]
    Div a b    -> ints [a,b]
    Mod a b    -> ints [a,b]
    Abs a      -> ints [a]
    At c a     -> visit ([a],[c],[])
    ColSize c  -> visit ([],[c],[])
    Fold _ i c -> visit ([i],[c],[])
    Channel b  -> visit ([],[],[b])
    Cond c t e -> visit ([t,e],[],[c])
  where
    -- visit this node with the given children
    visit kids = walker x f kids
    -- node without children
    leaf = visit ([],[],[])
    -- node with integer children only
    ints es = visit (es,[],[])



-- | Traverse a collection expression, invoking the callbacks of the triple
-- on every node reached.
--
-- Each constructor is mapped to its integer and collection children, which
-- are then visited through 'colWalker'.  Functions stored inside 'ColMap'
-- and 'ColSlice' are not entered.  A 'ColList' with no elements has no
-- children and is therefore reported with a single 'WalkSingle' call.
colWalk :: (Eq t, Eq c, Eq b, Monad m) => ColExpr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> m ()
colWalk x@(ColTerm _) f = colWalker x f ([],[],[])
colWalk x@(ColList l) f = colWalker x f (l,[],[])
colWalk x@(ColMap _ c) f = colWalker x f ([],[c],[])
colWalk x@(ColSlice _ l c) f = colWalker x f ([l],[c],[])
colWalk x@(ColCat a b) f = colWalker x f ([],[a,b],[])
colWalk x@(ColRange a b) f = colWalker x f ([a,b],[],[])



-- | Traverse a boolean expression, invoking the callbacks of the triple on
-- every node reached.
--
-- Each constructor is mapped to its integer, collection and boolean
-- children, which are then visited through 'boolWalker'.  Predicates stored
-- inside 'BoolAll' and 'BoolAny' are not entered; only their collection is
-- walked.
boolWalk :: (Eq t, Eq c, Eq b, Monad m) => BoolExpr t c b -> (Expr t c b -> WalkPhase -> m WalkResult, ColExpr t c b -> WalkPhase -> m WalkResult, BoolExpr t c b -> WalkPhase -> m WalkResult) -> m ()
boolWalk x@(BoolTerm _) f = boolWalker x f ([],[],[])
boolWalk x@(BoolConst _) f = boolWalker x f ([],[],[])
boolWalk x@(BoolAnd a b) f = boolWalker x f ([],[],[a,b])
boolWalk x@(BoolOr a b) f = boolWalker x f ([],[],[a,b])
boolWalk x@(BoolEqual a b) f = boolWalker x f ([],[],[a,b])
boolWalk x@(BoolNot a) f = boolWalker x f ([],[],[a])
boolWalk x@(Rel a _ b) f = boolWalker x f ([a,b],[],[])
boolWalk x@(BoolAll _ c) f = boolWalker x f ([],[c],[])
boolWalk x@(BoolAny _ c) f = boolWalker x f ([],[c],[])
boolWalk x@(ColEqual a b) f = boolWalker x f ([],[a,b],[])
boolWalk x@(Sorted _ c) f = boolWalker x f ([],[c],[])
boolWalk x@(AllDiff _ c) f = boolWalker x f ([],[c],[])
boolWalk x@(BoolCond c t e) f = boolWalker x f ([],[],[c,t,e])
boolWalk x@(Dom i c) f = boolWalker x f ([i],[c],[])