-- | Various properties and transformations of Boogie program elements
module Language.Boogie.Util ( 
  -- * Types
  TypeBinding,
  typeSubst,
  renameTypeVars,
  fromTVNames,
  isFreeIn,
  isTypeVar,
  unifier,
  freshTVName,
  tupleType,  
  -- * Expressions
  freeVarsTwoState,
  freeVars,
  freeOldVars,
  VarBinding,
  exprSubst,
  paramSubst,
  freeSelections,
  applications,
  -- * Specs
  preconditions,
  postconditions,
  modifies,
  assumePreconditions,
  assumePostconditions,
  -- * Functions and procedures
  FSig (..),
  fsigType,
  fsigFromType,
  FDef (..),
  ConstraintSet,
  AbstractStore,
  asUnion,
  PSig (..),
  psigParams,
  psigArgTypes,
  psigRetTypes,
  psigModifies,
  psigRequires,
  psigEnsures,
  psigType,
  PDef (..),
  pdefLocals,
  -- * Code generation
  num, eneg, enot,
  (|+|), (|-|), (|*|), (|/|), (|%|), (|=|), (|!=|), (|<|), (|<=|), (|>|), (|>=|), (|&|), (|||), (|=>|), (|<=>|),
  conjunction,
  assume,
  -- * Misc
  interval,
  fromRight,
  deleteAll,
  restrictDomain,
  removeDomain,
  mapItwType,
  anyM,
  changeState,
  withLocalState,
  internalError
) where

import Language.Boogie.AST
import Language.Boogie.Position
import Language.Boogie.Tokens
import Data.Maybe
import Data.List
import Data.Map (Map, (!))
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Control.Applicative
import Control.Monad.State
import Control.Monad.Stream
import Control.Lens

{- Types -}

-- | Mapping from type variables to types (a type substitution, as applied by 'typeSubst')
type TypeBinding = Map Id Type

-- | 'typeSubst' @binding t@ :
-- Replace every free type variable of @t@ that is in the domain of @binding@ by its image.
-- Every key of @binding@ is treated as free unless a map type inside @t@ binds it;
-- such bound variables shadow the corresponding entries of @binding@ within that map type.
typeSubst :: TypeBinding -> Type -> Type
typeSubst binding ty = case ty of
    BoolType -> BoolType
    IntType -> IntType
    IdType name [] -> M.findWithDefault ty name binding
    IdType name args -> IdType name (map (typeSubst binding) args)
    MapType bound domains range ->
      let inner = typeSubst (deleteAll bound binding)
      in MapType bound (map inner domains) (inner range)
  
-- | 'renameTypeVars' @tv newTV binding@ : @binding@ in which every occurrence of a variable from @tv@
-- is replaced by the variable at the same position in @newTV@, both in the keys and inside the bound types
renameTypeVars :: [Id] -> [Id] -> TypeBinding -> TypeBinding
renameTypeVars tv newTV binding = M.map (typeSubst asTypes) renamedKeys
  where
    renaming = M.fromList (zip tv newTV)
    rename name = M.findWithDefault name name renaming
    renamedKeys = M.mapKeys rename binding
    asTypes = fromTVNames tv newTV
    
-- | @x@ `isFreeIn` @t@ : does the type variable @x@ occur free in @t@?
-- A nullary 'IdType' named @x@ is an occurrence; for an applied 'IdType' only its arguments are inspected;
-- the bound variables of a 'MapType' hide @x@ inside it.
isFreeIn :: Id -> Type -> Bool
isFreeIn x ty = case ty of
    IdType y [] -> x == y
    IdType _ args -> any (isFreeIn x) args
    MapType bound domains range -> x `notElem` bound && any (isFreeIn x) (range : domains)
    _ -> False
  
-- | 'fromTVNames' @tvs tvs'@ : type binding that maps each variable of @tvs@
-- to the nullary type named by the variable at the same position in @tvs'@
fromTVNames :: [Id] -> [Id] -> TypeBinding
fromTVNames tvs tvs' = M.fromList [(oldName, nullaryType newName) | (oldName, newName) <- zip tvs tvs']

-- | 'freshTVName' @n@ : fresh type variable name built from the unique identifier @n@.
-- The name starts with 'nonIdChar', which is how 'isTypeVar' recognizes generated type variables
-- (presumably 'nonIdChar' cannot occur in source identifiers, so no clash with user names is possible).
freshTVName :: Show a => a -> Id
freshTVName n = nonIdChar : show n

-- | 'isTypeVar' @contextTypeVars v@ : is @v@ either one of @contextTypeVars@
-- or a fresh type variable generated by 'freshTVName' (i.e. does it start with 'nonIdChar')?
isTypeVar :: [Id] -> Id -> Bool
isTypeVar contextTypeVars v = isFresh v || v `elem` contextTypeVars
  where
    -- pattern match instead of 'head', so an empty name is handled instead of crashing
    isFresh (c : _) = c == nonIdChar
    isFresh [] = False
    
-- | 'unifier' @fv xs ys@ : most general unifier of @xs@ and @ys@ with shared free type variables of the context @fv@
--
-- The two lists are unified position by position; the result is 'Nothing' if they cannot be
-- unified (this includes lists of different lengths, see the last equation).
-- Names in @fv@ and names generated by 'freshTVName' are treated as type variables (see 'isTypeVar');
-- every other 'IdType' name is treated as a type constructor that only matches itself.
-- NOTE(review): the type stored for a variable is the type it was matched against at that point;
-- it is not rewritten with bindings found later in the lists, so the result need not be idempotent.
-- Confirm callers do not rely on that.
unifier :: [Id] -> [Type] -> [Type] -> Maybe TypeBinding
unifier _ [] [] = Just M.empty
-- base types only unify with themselves
unifier fv (IntType:xs) (IntType:ys) = unifier fv xs ys
unifier fv (BoolType:xs) (BoolType:ys) = unifier fv xs ys
-- same name on both sides: unify the arguments together with the remaining pairs
unifier fv ((IdType id1 args1):xs) ((IdType id2 args2):ys) | id1 == id2 = unifier fv (args1 ++ xs) (args2 ++ ys)
-- left side is a type variable: occurs check, then bind it and substitute it in the remaining pairs
unifier fv ((IdType id []):xs) (y:ys) | isTypeVar fv id = 
  if id `isFreeIn` y then Nothing 
  else M.insert id y <$> unifier fv (update xs) (update ys)
    where update = map (typeSubst (M.singleton id y))
-- right side is a type variable: symmetric to the previous case
unifier fv (x:xs) ((IdType id []):ys) | isTypeVar fv id = 
  if id `isFreeIn` x then Nothing 
  else M.insert id x <$> unifier fv (update xs) (update ys)
    where update = map (typeSubst (M.singleton id x))
-- map types: unify their (range : domains) modulo their own bound type variables,
-- then continue with the resulting substitution applied to the remaining pairs;
-- 'M.union' is left-biased, so bindings from the map types win on key collisions
unifier fv ((MapType bv1 domains1 range1):xs) ((MapType bv2 domains2 range2):ys) =
  case forallUnifier fv bv1 (range1:domains1) bv2 (range2:domains2) of
    Nothing -> Nothing
    Just u -> M.union u <$> (unifier fv (update u xs) (update u ys))
  where
    update u = map (typeSubst u)
-- any other combination (different constructors, different names, unequal lengths) does not unify
unifier _ _ _ = Nothing

-- | 'removeClashesWith' @tvs tvs'@ :
-- New names for type variables @tvs@ that are disjoint from @tvs'@
-- (if @tvs@ does not have duplicates, then result also does not have duplicates).
-- Names of @tvs@ that do not occur in @tvs'@ are kept unchanged.
removeClashesWith :: [Id] -> [Id] -> [Id]
removeClashesWith tvs tvs' = map changeName tvs
  where
    -- new name for tv that does not coincide with any of tvs'
    changeName tv = if tv `elem` tvs' then tv ++ replicate (level + 1) nonIdChar else tv
    -- length of the longest run of nonIdChar characters at the end of any of tvs or tvs';
    -- appending (level + 1) nonIdChar characters to tv makes it different from all of tvs' and from all unchanged tvs.
    -- 'level' is only evaluated when some tv occurs in tvs', so the list given to 'maximum' is then non-empty.
    level = maximum [trailingMarks name | name <- tvs ++ tvs']
    -- number of trailing nonIdChar characters; total, also for empty names or names made only of nonIdChar
    -- (the previous 'fromJust (findIndex ...)' crashed on those)
    trailingMarks = length . takeWhile (== nonIdChar) . reverse
    
-- | 'forallUnifier' @fv bv1 xs bv2 ys@ :
-- Most general unifier of @xs@ and @ys@,
-- where @bv1@ are universally quantified type variables in @xs@ and @bv2@ are universally quantified type variables in @ys@,
-- and @fv@ are free type variables of the enclosing context.
--
-- Returns 'Nothing' if the numbers of bound variables or of types differ.
-- Otherwise @bv2@ is renamed apart from @bv1@ and the lists are unified with @fv ++ bv1@ as variables.
-- The result is accepted only if every binding of a @bv1@ variable is a renamed @bv2@ variable
-- and no other binding mentions a renamed @bv2@ variable; only the bindings of variables outside @bv1@ are returned.
-- NOTE(review): the renamed @bv2@ variables are not in the variable list passed to 'unifier',
-- they are only made disjoint from @bv1@ (not from @fv@), and the check does not require distinct
-- @bv1@ variables to map to distinct @bv2@ variables. Confirm none of this matters for well-formed types.
forallUnifier :: [Id] -> [Id] -> [Type] -> [Id] -> [Type] -> Maybe TypeBinding
forallUnifier fv bv1 xs bv2 ys = if length bv1 /= length bv2 || length xs /= length ys 
  then Nothing
  else case unifier (fv ++ bv1) xs (map withFreshBV ys) of
    Nothing -> Nothing
    Just u -> let (boundU, freeU) = M.partitionWithKey (\k _ -> k `elem` bv1) u
      in if all isFreshBV (M.elems boundU) && not (any hasFreshBV (M.elems freeU))
        then Just freeU
        else Nothing
  where
    -- bv2 renamed so that none of the names coincides with a name in bv1
    freshBV = bv2 `removeClashesWith` bv1
    withFreshBV = typeSubst (fromTVNames bv2 freshBV)
    -- does a type correspond to one of the renamed bound variables?
    isFreshBV (IdType id []) = id `elem` freshBV
    isFreshBV _ = False
    -- does type t contain any of the renamed bound variables of the second list?
    hasFreshBV t = any (`isFreeIn` t) freshBV

-- | Internal tuple type constructor (used for representing procedure returns as a single type).
-- The constructor name @*Tuple@ is applied to the component types.
tupleType :: [Type] -> Type
tupleType ts = IdType "*Tuple" ts
  
{- Expressions -}

-- | Free variables in an expression: the first component holds the variables referred to
-- in the current state, the second those referred to inside 'Old'.
-- Variables bound by a quantifier are excluded from both components, and each component is free of duplicates.
freeVarsTwoState :: Expression -> ([Id], [Id])
freeVarsTwoState e = freeVarsTwoState' (node e)

-- | Worker for 'freeVarsTwoState' on a bare expression
freeVarsTwoState' :: BareExpression -> ([Id], [Id])
freeVarsTwoState' FF = ([], [])
freeVarsTwoState' TT = ([], [])
freeVarsTwoState' (Numeral _) = ([], [])
freeVarsTwoState' (Var x) = ([x], [])
freeVarsTwoState' (Application _ args) = freeVarsTwoStateAll args
freeVarsTwoState' (MapSelection m args) = freeVarsTwoStateAll (m : args)
freeVarsTwoState' (MapUpdate m args val) = freeVarsTwoStateAll (val : m : args)
-- inside old(...) every variable refers to the old state; deduplicate because a variable
-- may occur both directly and under a nested old(...)
freeVarsTwoState' (Old e) = let (state, old) = freeVarsTwoState e in ([], nub (state ++ old))
freeVarsTwoState' (IfExpr cond e1 e2) = freeVarsTwoStateAll [cond, e1, e2]
freeVarsTwoState' (Coercion e _) = freeVarsTwoState e
freeVarsTwoState' (UnaryExpression _ e) = freeVarsTwoState e
freeVarsTwoState' (BinaryExpression _ e1 e2) = freeVarsTwoStateAll [e1, e2]
-- quantified variables are not free in either state: an occurrence under old(...)
-- still refers to the quantified variable, so it must be removed from the old component too
freeVarsTwoState' (Quantified _ _ boundVars e) = let
    (state, old) = freeVarsTwoState e
    notBound = filter (`notElem` map fst boundVars)
  in (notBound state, notBound old)

-- | Free variables of several expressions, merged component-wise without duplicates
freeVarsTwoStateAll :: [Expression] -> ([Id], [Id])
freeVarsTwoStateAll es = over both (nub . concat) (unzip (map freeVarsTwoState es))

-- | Free variables in an expression, in current state (first component of 'freeVarsTwoState')
freeVars :: Expression -> [Id]
freeVars = fst . freeVarsTwoState
-- | Free variables in an expression, in old state (second component of 'freeVarsTwoState')
freeOldVars :: Expression -> [Id]
freeOldVars = snd . freeVarsTwoState

-- | Mapping from variables to expressions, as applied by 'exprSubst'
-- (the images are bare expressions; 'exprSubst' gives each one the position of the variable it replaces)
type VarBinding = Map Id BareExpression

-- | 'exprSubst' @binding e@ : replace every free variable of @e@ that is in the domain of @binding@ by its image.
-- Every key of @binding@ is treated as free unless a quantifier inside @e@ binds it; such bound variables
-- shadow the binding inside that quantifier. A substituted expression gets the source position of the variable it replaces.
-- NOTE(review): the substitution is not capture-avoiding: if an image mentions a variable that a quantifier
-- in @e@ binds, that occurrence becomes bound. Confirm callers never substitute such expressions.
exprSubst :: VarBinding -> Expression -> Expression
exprSubst binding (Pos pos e) = attachPos pos (exprSubst' binding e)

-- | Worker for 'exprSubst' on a bare expression
exprSubst' :: VarBinding -> BareExpression -> BareExpression
exprSubst' binding expr =
  case expr of
    Var name -> M.findWithDefault expr name binding
    Application name args -> Application name (recurAll args)
    MapSelection m args -> MapSelection (recur m) (recurAll args)
    MapUpdate m args val -> MapUpdate (recur m) (recurAll args) (recur val)
    Old e -> Old (recur e)
    IfExpr cond e1 e2 -> IfExpr (recur cond) (recur e1) (recur e2)
    Coercion e t -> Coercion (recur e) t
    UnaryExpression op e -> UnaryExpression op (recur e)
    BinaryExpression op e1 e2 -> BinaryExpression op (recur e1) (recur e2)
    Quantified qop tv boundVars body ->
      -- variables bound here are not free in the body, so they are dropped from the substitution
      Quantified qop tv boundVars (exprSubst (deleteAll (map fst boundVars) binding) body)
    _ -> expr
  where
    recur = exprSubst binding
    recurAll = map recur

-- | 'paramBinding' @sig def@ :
-- Binding that maps each parameter name of the procedure signature @sig@ (in-parameters, then out-parameters)
-- to a variable carrying the name at the same position in the procedure definition @def@
paramBinding :: PSig -> PDef -> VarBinding
paramBinding sig def = M.fromList (zip sigNames defVars)
  where
    sigNames = map itwId (psigArgs sig ++ psigRets sig)
    defVars = map Var (pdefIns def ++ pdefOuts def)
  
-- | 'paramSubst' @sig def@ :
-- Substitute parameter names from @sig@ in an expression with their equivalents from @def@;
-- the identity when @def@ reports that its parameters were not renamed ('pdefParamsRenamed')
paramSubst :: PSig -> PDef -> Expression -> Expression
paramSubst sig def
  | pdefParamsRenamed def = exprSubst (paramBinding sig def)
  | otherwise = id
  
-- | 'freeSelections' @expr@ : all map selections that occur in @expr@ where the selected map is a variable
-- not bound by an enclosing quantifier, as pairs of the map name and the index expressions.
-- Must not be applied to expressions containing 'Old' (raises an internal error).
freeSelections :: Expression -> [(Id, [Expression])]
freeSelections = freeSelections' . node

-- | Worker for 'freeSelections' on a bare expression
freeSelections' :: BareExpression -> [(Id, [Expression])]
freeSelections' expr =
  case expr of
    FF -> []
    TT -> []
    Numeral _ -> []
    Var _ -> []
    Application _ args -> gather args
    MapSelection m args ->
      case node m of
        Var name -> (name, args) : gather args
        _ -> gather (m : args)
    MapUpdate m args val -> gather (val : m : args)
    Old _ -> internalError "freeSelections should only be applied in single-state context"
    IfExpr cond e1 e2 -> gather [cond, e1, e2]
    Coercion e _ -> freeSelections e
    UnaryExpression _ e -> freeSelections e
    BinaryExpression _ e1 e2 -> gather [e1, e2]
    Quantified _ _ boundVars e ->
      -- selections from quantified map variables are not free
      let bound = map fst boundVars
      in filter (\(m, _) -> m `notElem` bound) (freeSelections e)
  where
    gather = nub . concatMap freeSelections
  
-- | 'applications' @expr@ : all function applications that occur in @expr@, as pairs of function name and arguments.
-- Must not be applied to expressions containing 'Old' (raises an internal error).
applications :: Expression -> [(Id, [Expression])]
applications = applications' . node

-- | Worker for 'applications' on a bare expression
applications' :: BareExpression -> [(Id, [Expression])]
applications' expr =
  case expr of
    FF -> []
    TT -> []
    Numeral _ -> []
    Var _ -> []
    Application name args -> (name, args) : collect args
    MapSelection m args -> collect (m : args)
    MapUpdate m args val -> collect (val : m : args)
    Old _ -> internalError "applications should only be applied in single-state context"
    IfExpr cond e1 e2 -> collect [cond, e1, e2]
    Coercion e _ -> applications e
    UnaryExpression _ e -> applications e
    BinaryExpression _ e1 e2 -> collect [e1, e2]
    Quantified _ _ _ e -> applications e
  where
    collect = nub . concatMap applications

{- Specs -}

-- | 'preconditions' @specs@ : all precondition clauses in @specs@, in their original order
-- (each 'Requires' contract becomes a 'Precondition' spec clause with the same free flag and expression)
preconditions :: [Contract] -> [SpecClause]
preconditions specs = [SpecClause Precondition free e | Requires free e <- specs]

-- | 'postconditions' @specs@ : all postcondition clauses in @specs@, in their original order
-- (each 'Ensures' contract becomes a 'Postcondition' spec clause with the same free flag and expression)
postconditions :: [Contract] -> [SpecClause]
postconditions specs = [SpecClause Postcondition free e | Ensures free e <- specs]
   
-- | 'modifies' @specs@ : names listed in all modifies clauses of @specs@, without duplicates
modifies :: [Contract] -> [Id]
modifies specs = nub (concat [ids | Modifies _ ids <- specs])
  
-- | Make all preconditions in the contracts of a procedure signature free (set their free flag to True);
-- other contracts are left as they are
assumePreconditions :: PSig -> PSig
assumePreconditions sig = sig { psigContracts = [makeFree c | c <- psigContracts sig] }
  where
    makeFree contract = case contract of
      Requires _ e -> Requires True e
      _ -> contract
    
-- | Make all postconditions in the contracts of a procedure signature free (set their free flag to True);
-- other contracts are left as they are
assumePostconditions :: PSig -> PSig
assumePostconditions sig = sig { psigContracts = [makeFree c | c <- psigContracts sig] }
  where
    makeFree contract = case contract of
      Ensures _ e -> Ensures True e
      _ -> contract

{- Functions and procedures -}

-- | Function signature (equality compares names only, see the 'Eq' instance)
data FSig = FSig
  { fsigName :: Id          -- ^ Function name
  , fsigTypeVars :: [Id]    -- ^ Type variables
  , fsigArgTypes :: [Type]  -- ^ Argument types
  , fsigRetType :: Type     -- ^ Return type
  }
  
-- | Function signature as a map type: type variables become the bound variables,
-- argument types the domains and the return type the range (the name is dropped)
fsigType :: FSig -> Type
fsigType sig = MapType (fsigTypeVars sig) (fsigArgTypes sig) (fsigRetType sig)

-- | Map type as a function signature with an empty name (inverse of 'fsigType');
-- only defined for map types, any other type is an internal error
fsigFromType :: Type -> FSig
fsigFromType (MapType tv domainTypes rangeType) = FSig "" tv domainTypes rangeType
fsigFromType _ = internalError "fsigFromType: argument is not a map type"

-- Function signatures are equal when their names are equal
instance Eq FSig where
  FSig { fsigName = name1 } == FSig { fsigName = name2 } = name1 == name2
  
-- | Function definition
data FDef = FDef
  { fdefName  :: Id          -- ^ Entity to which the definition belongs
  , fdefTV    :: [Id]        -- ^ Type variables
  , fdefArgs  :: [IdType]    -- ^ Arguments (types may be less general than in the corresponding signature)
  , fdefGuard :: Expression  -- ^ Condition under which the definition applies
  , fdefBody  :: Expression  -- ^ Body
  }

-- | Constraint set: a pair of a list of definitions (first component) and a list of constraints (second component)
type ConstraintSet = ([FDef], [FDef])

-- | Abstract store: maps names to their constraint sets (see 'asUnion' for merging two stores)
type AbstractStore = Map Id ConstraintSet

-- | Union of abstract stores: for a name present in both stores, the definitions and
-- the constraints are concatenated (those of the first store come first)
asUnion :: AbstractStore -> AbstractStore -> AbstractStore
asUnion = M.unionWith merge
  where
    merge (defs1, constraints1) (defs2, constraints2) = (defs1 ++ defs2, constraints1 ++ constraints2)
 
-- | Procedure signature (equality compares names only, see the 'Eq' instance)
data PSig = PSig
  { psigName :: Id               -- ^ Procedure name
  , psigTypeVars :: [Id]         -- ^ Type variables
  , psigArgs :: [IdTypeWhere]    -- ^ In-parameters
  , psigRets :: [IdTypeWhere]    -- ^ Out-parameters
  , psigContracts :: [Contract]  -- ^ Contracts
  }
  
-- Procedure signatures are equal when their names are equal
instance Eq PSig where
  PSig { psigName = name1 } == PSig { psigName = name2 } = name1 == name2
  
-- | All parameters of a procedure signature (in-parameters followed by out-parameters)
psigParams :: PSig -> [IdTypeWhere]
psigParams sig = psigArgs sig ++ psigRets sig
-- | Types of in-parameters of a procedure signature
psigArgTypes :: PSig -> [Type]
psigArgTypes = map itwType . psigArgs
-- | Types of out-parameters of a procedure signature
psigRetTypes :: PSig -> [Type]
psigRetTypes = map itwType . psigRets
-- | Procedure signature as a map type (the out-parameter types are packed into a single 'tupleType')
psigType :: PSig -> Type
psigType sig = MapType (psigTypeVars sig) (psigArgTypes sig) (tupleType $ psigRetTypes sig)
-- | Modifies clauses of a procedure signature
psigModifies :: PSig -> [Id]
psigModifies = modifies . psigContracts
-- | Preconditions of a procedure signature
psigRequires :: PSig -> [SpecClause]
psigRequires = preconditions . psigContracts
-- | Postconditions of a procedure signature
psigEnsures :: PSig -> [SpecClause]
psigEnsures = postconditions . psigContracts
  
-- | Procedure definition;
-- a single procedure might have multiple definitions (one per body)
data PDef = PDef
  { pdefIns :: [Id]                   -- ^ In-parameter names (in the same order as 'psigArgs' in the corresponding signature)
  , pdefOuts :: [Id]                  -- ^ Out-parameter names (in the same order as 'psigRets' in the corresponding signature)
  , pdefParamsRenamed :: Bool         -- ^ Are any parameter names in this definition different from the procedure signature? (used for optimizing parameter renaming, True is a safe default)
  , pdefBody :: BasicBody             -- ^ Body
  , pdefConstraints :: AbstractStore  -- ^ Constraints on local names
  , pdefPos :: SourcePos              -- ^ Location of the (first line of the) procedure definition in the source
  }
  
-- | All local names of a procedure definition: in-parameters, out-parameters,
-- then the names declared in the first component of its body
pdefLocals :: PDef -> [Id]
pdefLocals def = pdefIns def ++ pdefOuts def ++ map itwId (fst (pdefBody def))

{- Code generation -}

-- | Integer literal expression, wrapped with 'gen'
-- (NOTE(review): presumably 'gen' attaches no meaningful source position; confirm)
num i = gen $ Numeral i
-- | Arithmetic negation of @e@; the position presumably comes from @e@ via 'inheritPos'
eneg e = inheritPos (UnaryExpression Neg) e
-- | Logical negation of @e@; the position presumably comes from @e@ via 'inheritPos'
enot e = inheritPos (UnaryExpression Not) e
-- Binary operator builders: each wraps its two operands in the corresponding 'BinaryExpression';
-- the position presumably comes from the operands via 'inheritPos2'.
-- No fixity declarations are given for these operators in this module, so Haskell's default
-- (infixl 9) applies: a chain such as @a |+| b |*| c@ parses as @(a |+| b) |*| c@,
-- so mixed chains need explicit parentheses.
e1 |+|    e2 = inheritPos2 (BinaryExpression Plus) e1 e2
e1 |-|    e2 = inheritPos2 (BinaryExpression Minus) e1 e2
e1 |*|    e2 = inheritPos2 (BinaryExpression Times) e1 e2
e1 |/|    e2 = inheritPos2 (BinaryExpression Div) e1 e2
e1 |%|    e2 = inheritPos2 (BinaryExpression Mod) e1 e2
e1 |=|    e2 = inheritPos2 (BinaryExpression Eq) e1 e2
e1 |!=|   e2 = inheritPos2 (BinaryExpression Neq) e1 e2
e1 |<|    e2 = inheritPos2 (BinaryExpression Ls) e1 e2
e1 |<=|   e2 = inheritPos2 (BinaryExpression Leq) e1 e2
e1 |>|    e2 = inheritPos2 (BinaryExpression Gt) e1 e2
e1 |>=|   e2 = inheritPos2 (BinaryExpression Geq) e1 e2
e1 |&|    e2 = inheritPos2 (BinaryExpression And) e1 e2
e1 |||    e2 = inheritPos2 (BinaryExpression Or) e1 e2
e1 |=>|   e2 = inheritPos2 (BinaryExpression Implies) e1 e2
e1 |<=>|  e2 = inheritPos2 (BinaryExpression Equiv) e1 e2
-- | 'assume' @e@ : inline predicate statement on @e@ whose free flag is True (the flag that
-- 'assumePreconditions' sets to make a clause free), located at the position of @e@
assume e = attachPos (position e) (Predicate (SpecClause Inline True e))

-- | Conjunction of a list of expressions, combined left-associatively with '|&|';
-- the conjunction of no expressions is @true@
conjunction [] = gen TT
conjunction (first : rest) = foldl (|&|) first rest
  
{- Misc -}

-- | 'interval' @(lo, hi)@ : all values from @lo@ up to and including @hi@
-- (for integral types, empty when @lo > hi@)
interval :: Enum a => (a, a) -> [a]
interval (lo, hi) = [lo..hi]

-- | Extract the element out of a 'Right'; calls 'error' with a descriptive message if the argument is 'Left'
-- (instead of failing with an anonymous non-exhaustive pattern error)
fromRight :: Either a b -> b
fromRight (Right x) = x
fromRight (Left _) = error "fromRight: argument is Left"

-- | 'deleteAll' @keys m@ : @m@ without any of the @keys@ in its domain (keys absent from @m@ are ignored)
deleteAll :: Ord k => [k] -> Map k a -> Map k a
deleteAll keys m = foldl' (flip M.delete) m keys

-- | 'restrictDomain' @keys m@ : the entries of @m@ whose key belongs to @keys@
restrictDomain :: Ord k => Set k -> Map k a -> Map k a
restrictDomain keys = M.filterWithKey keep
  where
    keep k _ = S.member k keys

-- | 'removeDomain' @keys m@ : the entries of @m@ whose key does not belong to @keys@
removeDomain :: Ord k => Set k -> Map k a -> Map k a
removeDomain keys = M.filterWithKey keep
  where
    keep k _ = S.notMember k keys

-- | 'mapItwType' @f itw@ : @itw@ with @f@ applied to its type; the name and the where-clause are kept
mapItwType :: (Type -> Type) -> IdTypeWhere -> IdTypeWhere
mapItwType f (IdTypeWhere i t w) = IdTypeWhere i (f t) w

-- | Monadic version of 'any': runs the predicate on the elements from left to right
-- and stops at the first element for which it yields True
anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool
anyM p = foldr step (return False)
  where
    step x rest = p x >>= \found -> if found then return True else rest
  
-- | Monadic version of 'all': runs the predicate on the elements from left to right
-- and stops at the first element for which it yields False
allM :: Monad m => (a -> m Bool) -> [a] -> m Bool
allM p = foldr step (return True)
  where
    step x rest = p x >>= \ok -> if ok then rest else return False

-- | 'changeState' @getter modifier e@ : run @e@, a computation over state of type @t@, inside a
-- computation over state of type @s@. The inner state starts as @getter@ applied to the current state;
-- the inner state @e@ ends with is written back with @modifier@. Returns the result of @e@.
changeState :: Monad m => (s -> t) -> (t -> s -> s) -> StateT t m a -> StateT s m a
changeState getter modifier e = do
  (result, finalInner) <- gets getter >>= lift . runStateT e
  modify (modifier finalInner)
  return result

-- | 'withLocalState' @localState e@ :
-- Run @e@ on the state obtained by applying @localState@ to the current state; whatever state @e@
-- ends in is discarded, so the current state is left as it was. Returns the result of @e@.
withLocalState :: Monad m => (s -> t) -> StateT t m a -> StateT s m a
withLocalState localState = changeState localState (\_ outer -> outer)
      
-- | 'internalError' @msg@ : abort with an 'error' reporting an internal interpreter bug described by @msg@
internalError :: String -> a
internalError msg = error $ "Internal interpreter error (consider submitting a bug report):\n" ++ msg