module Camfort.Specification.Units.InferenceFrontend
  ( initInference, runCriticalVariables, runInferVariables, runInconsistentConstraints, getConstraint )
where
import Data.Data (Data)
import Data.List (nub)
import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.Set as S
import Data.Maybe (isJust, fromMaybe, catMaybes)
import Data.Generics.Uniplate.Operations
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Control.Monad.Trans.Except
import Control.Monad.RWS.Strict
import qualified Language.Fortran.AST as F
import qualified Language.Fortran.Analysis as FA
import Camfort.Analysis.CommentAnnotator (annotateComments)
import Camfort.Analysis.Annotations
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.Monad
import Camfort.Specification.Units.InferenceBackend
import qualified Camfort.Specification.Units.Parser as P
import qualified Debug.Trace as D
import qualified Numeric.LinearAlgebra as H 
initInference :: UnitSolver ()
initInference = do
  -- Parse unit specifications out of comments and link each one to the AST
  -- block it annotates; forward the parser's messages to the log.
  pf <- gets usProgramFile
  let (linkedPF, parserReport) = runWriter (annotateComments P.unitParser pf)
  modifyProgramFile (const linkedPF)
  forM_ parserReport tell

  -- Populate the variable/unit map. User-given units go in first; the two
  -- following passes insert with 'M.insertWith' keeping the existing value,
  -- so they never replace an entry that is already present.
  insertGivenUnits
  insertParametricUnits
  insertUndeterminedUnits

  -- Attach unit information to variable and literal expressions, then
  -- propagate it through expressions, statements and program units.
  annotateAllVariables
  annotateLiterals
  propagateUnits

  -- Collect constraints and instantiate the parametric templates at each use.
  cons <- extractConstraints >>= applyTemplates

  -- The comment links are no longer needed; record the final constraints.
  modifyProgramFile cleanLinks
  modify (\ s -> s { usConstraints = cons })
  debugLogging
cleanLinks :: F.ProgramFile UA -> F.ProgramFile UA
cleanLinks = transformBi unlink
  where
    -- Forget the comment-block link and the parsed spec stored on every
    -- annotation; they are only needed while gathering given units.
    unlink :: UnitAnnotation A -> UnitAnnotation A
    unlink ua = ua { unitBlock = Nothing, unitSpec = Nothing }
runCriticalVariables :: UnitSolver [UnitInfo]
-- | Run 'criticalVariables' over the constraints stored in the solver state.
runCriticalVariables = gets (criticalVariables . usConstraints)
runInferVariables :: UnitSolver [(String, UnitInfo)]
-- | Run 'inferVariables' over the constraints stored in the solver state.
runInferVariables = gets (inferVariables . usConstraints)
runInconsistentConstraints :: UnitSolver (Maybe Constraints)
-- | Run 'inconsistentConstraints' over the constraints stored in the solver state.
runInconsistentConstraints = gets (inconsistentConstraints . usConstraints)
insertParametricUnits :: UnitSolver ()
-- | Give every formal parameter (and function result) of every program unit
-- a positional parametric unit, unless the variable already has a unit.
insertParametricUnits = do
  pf <- gets usProgramFile
  let params = [ (puName pu, i, param) | pu <- universeBi pf, (i, param) <- indexedParams pu ]
  forM_ params $ \ (fname, i, param) ->
    -- The combining function keeps the value already in the map.
    modifyVarUnitMap (M.insertWith (\ _new old -> old) param (UnitParamPosAbs (fname, i)))
indexedParams :: F.ProgramUnit UA -> [(Int, String)]
-- | Number the parameters of a function or subroutine.
--
-- For a function, position 0 is the result (the explicit result variable if
-- present, otherwise the function's own name) and the arguments follow from 1.
-- For a subroutine the arguments are numbered from 1. Units without an
-- argument list, and all other kinds of unit, yield no parameters.
indexedParams pu = case pu of
  F.PUFunction _ _ _ _ _ (Just paList) mResult _ _ ->
    zip [0..] (resultName mResult : map varName (F.aStrip paList))
  F.PUSubroutine _ _ _ _ (Just paList) _ _ ->
    zip [1..] (map varName (F.aStrip paList))
  _ -> []
  where
    resultName = maybe (puName pu) varName
insertUndeterminedUnits :: UnitSolver ()
-- | Give every variable that appears in a program unit's body a placeholder
-- unit, unless it already has one. Inside functions and subroutines the
-- placeholder is a parametric variable unit tied to that unit's name; in
-- any other program unit it is a plain unit variable.
insertUndeterminedUnits = do
  pf <- gets usProgramFile
  forM_ (universeBi pf) $ \ pu ->
    -- Only the side effect on the var/unit map matters; the traversal
    -- returns the blocks unchanged and the result is discarded.
    modifyPUBlocksM (transformBiM (registerVar (placeholderFor pu))) pu
  where
    placeholderFor :: F.ProgramUnit UA -> String -> UnitInfo
    placeholderFor pu = case pu of
      F.PUFunction {}   -> \ v -> UnitParamVarAbs (puName pu, v)
      F.PUSubroutine {} -> \ v -> UnitParamVarAbs (puName pu, v)
      _                 -> UnitVar

    registerVar :: (String -> UnitInfo) -> F.Expression UA -> UnitSolver (F.Expression UA)
    registerVar mkUnit v@(F.ExpValue _ _ (F.ValVariable _)) = do
      let vname = varName v
      modifyVarUnitMap (M.insertWith (\ _new old -> old) vname (mkUnit vname))
      return v
    registerVar _ other = return other
insertGivenUnits :: UnitSolver ()
-- | Record the units that the programmer wrote in unit-annotation comments.
--
-- A unit assignment whose comment is linked to a declaration statement sets
-- the unit of each named variable declared there and marks it as "given".
-- A unit alias definition is added to the alias map. Everything else,
-- including specs linked to blocks that are not declarations, is ignored
-- instead of failing with a pattern-match error.
insertGivenUnits = do
  pf <- gets usProgramFile
  mapM_ checkComment [ b | b@(F.BlComment {}) <- universeBi pf ]
  where
    -- Inspect the unit spec (if any) that the comment annotator attached to
    -- a comment block.
    checkComment :: F.Block UA -> UnitSolver ()
    checkComment (F.BlComment a _ _)
      -- A unit assignment naming explicit variables, linked to a block.
      | Just (P.UnitAssignment (Just vars) unitsAST) <- mSpec
      , Just b                                       <- mBlock = insertUnitAssignments (toUnitInfo unitsAST) b vars
      -- A unit alias definition.
      | Just (P.UnitAlias name unitsAST)             <- mSpec  = modifyUnitAliasMap (M.insert name (toUnitInfo unitsAST))
      | otherwise                                              = return ()
      where
        mSpec  = unitSpec (FA.prevAnnotation a)
        mBlock = unitBlock (FA.prevAnnotation a)
    checkComment _ = return ()

    -- Assign 'info' to the declared variables whose source names appear in
    -- 'varRealNames'. Declared variables carry unique (renamed) names, so the
    -- option's name map is used to translate back to source names.
    insertUnitAssignments info (F.BlStatement _ _ _ (F.StDeclaration _ _ _ _ decls)) varRealNames = do
      nameMap <- uoNameMap `fmap` ask
      let m = M.fromList [ (varUniqueName, info) | e@(F.ExpValue _ _ (F.ValVariable _)) <- universeBi decls
                                                 , varRealName <- varRealNames
                                                 , let varUniqueName = varName e
                                                 , M.lookup varUniqueName nameMap == Just varRealName ]
      -- Given units take precedence over anything already in the map.
      modifyVarUnitMap $ M.unionWith const m
      modifyGivenVarSet . S.union . S.fromList . M.keys $ m
    -- The spec was linked to something other than a declaration: nothing to assign.
    insertUnitAssignments _ _ _ = return ()
annotateAllVariables :: UnitSolver ()
-- | Copy each variable's unit from the var/unit map onto every occurrence of
-- that variable in the program file. Variables absent from the map are left
-- untouched.
annotateAllVariables = modifyProgramFileM $ \ pf -> do
  varUnitMap <- gets usVarUnitMap
  let annotateExp e = case e of
        F.ExpValue _ _ (F.ValVariable _) ->
          maybe e (`setUnitInfo` e) (M.lookup (varName e) varUnitMap)
        _ -> e
  return (transformBi annotateExp pf)
annotateLiterals :: UnitSolver ()
-- | Annotate the numeric literals of every program unit, according to the
-- literal mode chosen in the options (see 'annotateLiteralsPU').
annotateLiterals = modifyProgramFileM (\ pf -> transformBiM annotateLiteralsPU pf)
annotateLiteralsPU :: F.ProgramUnit UA -> UnitSolver (F.ProgramUnit UA)
-- | Give the numeric literals in a program unit's blocks a unit, depending
-- on the 'uoLiterals' option:
--
--   * 'LitUnitless': every literal is 'UnitlessLit'.
--   * 'LitPoly':     every literal gets a fresh parametric literal unit.
--   * 'LitMixed':    a literal zero gets a fresh parametric literal unit;
--                    any other literal gets a fresh 'UnitLiteral'.
annotateLiteralsPU pu = do
  mode <- asks uoLiterals
  case mode of
    LitUnitless -> modifyPUBlocksM (transformBiM expUnitless) pu
    LitPoly     -> modifyPUBlocksM (transformBiM (withLiterals genParamLit)) pu
    LitMixed    -> modifyPUBlocksM (transformBiM expMixed) pu
  where
    expMixed e = case e of
      F.ExpValue _ _ (F.ValInteger i) | isZeroLiteral i -> withLiterals genParamLit e
                                      | otherwise       -> withLiterals genUnitLiteral e
      F.ExpValue _ _ (F.ValReal r)    | isZeroLiteral r -> withLiterals genParamLit e
                                      | otherwise       -> withLiterals genUnitLiteral e
      _                                                 -> return e

    -- Total zero test on the literal's source text. Fortran literal syntax
    -- includes forms that Haskell's 'read' rejects ("1.0d0", ".5", "1.",
    -- "0_8"), which made the previous @read i == 0@ crash. Instead of
    -- parsing, check that the mantissa (the text before any exponent letter
    -- or kind suffix) has at least one '0' and contains only '0' and '.'.
    isZeroLiteral :: String -> Bool
    isZeroLiteral s = any (== '0') mantissa && all (`elem` "0.") mantissa
      where
        mantissa = takeWhile (`notElem` "eEdDqQ_") s

    expUnitless e
      | isLiteral e = return $ setUnitInfo UnitlessLit e
      | otherwise   = return e

    -- Annotate a numeric literal with a unit drawn from generator 'm'.
    withLiterals m e
      | isLiteral e = flip setUnitInfo e `fmap` m
      | otherwise   = return e
applyTemplates :: Constraints -> UnitSolver Constraints
-- | Make the constraint set concrete.
--
-- Each distinct (function, call id) use found in the constraints is expanded
-- with 'substInstance'. Constraints defining the unit aliases are appended,
-- and constraints that still mention abstract parametric units are dropped.
-- Finally, any 'UnitName' that names a known alias is rewritten to
-- 'UnitAlias'.
applyTemplates cons = do
  let uses = nub [ (fn, cid) | UnitParamPosUse (fn, _, cid) <- universeBi cons ]
  concreteCons <- foldM (substInstance []) [] uses

  aliasMap <- gets usUnitAliasMap
  let aliasCons = map (\ (n, def) -> ConEq (UnitAlias n) def) (M.toList aliasMap)
      resolveAlias u = case u of
        UnitName a | M.member a aliasMap -> UnitAlias a
        _                                -> u
      concrete = filter (not . isParametric) (cons ++ concreteCons ++ aliasCons)
  return (transformBi resolveAlias concrete)
substInstance :: [F.Name] -> Constraints -> (F.Name, Int) -> UnitSolver Constraints
-- | Append an instantiated copy of function @name@'s constraint template,
-- tagged with @callId@, to @output@.
--
-- The template's own call ids are remapped via 'callIdRemap' (with the remap
-- table reset afterwards). Calls made from inside the template are expanded
-- recursively. A function already on @callStack@ is not expanded again,
-- which stops infinite recursion.
substInstance callStack output (name, callId)
  | name `elem` callStack = return output
  | otherwise = do
      tmap <- gets usTemplateMap
      template <- transformBiM callIdRemap (M.findWithDefault [] name tmap)
      modify (\ s -> s { usCallIdRemap = IM.empty })
      let nested = nub [ (callee, i) | UnitParamPosUse (callee, _, i) <- universeBi template ]
      template' <- foldM (substInstance (name : callStack)) [] nested
      return (instantiate (name, callId) (output ++ template ++ template'))
callIdRemap :: UnitInfo -> UnitSolver UnitInfo
-- | Replace the call id inside a parametric "use" unit with a fresh one.
--
-- The old-to-new mapping is kept in the solver's call-id remap table, so
-- every occurrence of one old id maps to the same new id. Other units are
-- returned unchanged.
callIdRemap info = modifyCallIdRemapM $ \ idMap -> case info of
    UnitParamPosUse (n, p, i) -> remap idMap i (\ j -> UnitParamPosUse (n, p, j))
    UnitParamVarUse (n, v, i) -> remap idMap i (\ j -> UnitParamVarUse (n, v, j))
    UnitParamLitUse (l, i)    -> remap idMap i (\ j -> UnitParamLitUse (l, j))
    _                         -> return (info, idMap)
  where
    -- Look up (or allocate and remember) the new id for old id 'i', then
    -- rebuild the unit with it.
    remap :: IM.IntMap Int -> Int -> (Int -> UnitInfo) -> UnitSolver (UnitInfo, IM.IntMap Int)
    remap idMap i rebuild = case IM.lookup i idMap of
      Just i' -> return (rebuild i', idMap)
      Nothing -> do
        i' <- genCallId
        return (rebuild i', IM.insert i i' idMap)
-- | Turn every abstract parametric unit in a constraint set into a "use"
-- tagged with the given call id. The function name in the pair is not
-- consulted: abstract units of any function are instantiated.
instantiate :: (F.Name, Int) -> Constraints -> Constraints
instantiate (_, callId) = transformBi toUse
  where
    toUse :: UnitInfo -> UnitInfo
    toUse info = case info of
      UnitParamPosAbs (fname, pos)   -> UnitParamPosUse (fname, pos, callId)
      UnitParamLitAbs litId          -> UnitParamLitUse (litId, callId)
      UnitParamVarAbs (fname, vname) -> UnitParamVarUse (fname, vname, callId)
      _                              -> info
extractConstraints :: UnitSolver Constraints
-- | Gather the equality constraints found in the top-level blocks of main
-- programs and modules, followed by one constraint per entry in the
-- var/unit map equating the variable's unit variable with its unit.
extractConstraints = do
  pf  <- gets usProgramFile
  vum <- gets usVarUnitMap
  let blockCons = [ c | c@(ConEq {}) <- concatMap universeBi (mainBlocks pf) ]
      varCons   = map (\ (v, u) -> ConEq (UnitVar v) u) (M.toList vum)
  return (blockCons ++ varCons)
mainBlocks :: F.ProgramFile UA -> [F.Block UA]
-- | The top-level blocks of every main program and module in the file.
mainBlocks pf = [ b | pu <- universeBi pf, b <- topBlocks pu ]
  where
    topBlocks :: F.ProgramUnit UA -> [F.Block UA]
    topBlocks pu = case pu of
      F.PUMain _ _ _ bs _   -> bs
      F.PUModule _ _ _ bs _ -> bs
      _                     -> []
isParametric :: Constraint -> Bool
-- | Does the constraint mention any abstract (template) parametric unit?
isParametric c = any isAbstract (universeBi c)
  where
    isAbstract :: UnitInfo -> Bool
    isAbstract u = case u of
      UnitParamPosAbs _ -> True
      UnitParamVarAbs _ -> True
      UnitParamLitAbs _ -> True
      _                 -> False
propagateUnits :: UnitSolver ()
-- | Propagate units bottom-up in three passes over the program file: first
-- expressions, then statements, then program units.
propagateUnits = modifyProgramFileM $
  transformBiM propagateExp >=> transformBiM propagateStatement >=> transformBiM propagatePU
propagateExp :: F.Expression UA -> UnitSolver (F.Expression UA)
-- | Compute the unit of a compound expression from its operands' units and
-- record any constraint the operator imposes.
--
--   * @a * b@ has unit @u(a) * u(b)@; @a / b@ has unit @u(a) * u(b)^-1@
--     (multiplicative operands go through 'getUnitInfoMul').
--   * @a ** k@ has unit @u(a)^k@ when @k@ is a numeric literal.
--   * Additive and relational operators constrain both operands to the same
--     unit; the expression takes the left operand's unit.
--   * Function calls are handled by 'propagateFunctionCall'.
--
-- Variables and literals were already annotated by earlier passes.
propagateExp e = fmap uoLiterals ask >>= \ lm -> case e of
  F.ExpValue _ _ (F.ValVariable _)       -> return e
  F.ExpValue _ _ (F.ValInteger _)        -> return e
  F.ExpValue _ _ (F.ValReal _)           -> return e
  F.ExpBinary _ _ F.Multiplication e1 e2 -> setF2 UnitMul (getUnitInfoMul lm e1) (getUnitInfoMul lm e2)
  -- Division multiplies by the reciprocal, so the divisor's unit is raised
  -- to -1. (Raising it to 1 would make division behave like multiplication.)
  F.ExpBinary _ _ F.Division e1 e2       -> setF2 UnitMul (getUnitInfoMul lm e1) (flip UnitPow (-1) `fmap` getUnitInfoMul lm e2)
  F.ExpBinary _ _ F.Exponentiation e1 e2 -> setF2 UnitPow (getUnitInfo e1) (constantExpression e2)
  F.ExpBinary _ _ o e1 e2 | isOp AddOp o -> setF2C ConEq  (getUnitInfo e1) (getUnitInfo e2)
                          | isOp RelOp o -> setF2C ConEq  (getUnitInfo e1) (getUnitInfo e2)
  F.ExpFunctionCall {}                   -> propagateFunctionCall e
  _                                      -> whenDebug (tell ("propagateExp: unhandled: " ++ show e)) >> return e
  where
    -- Set the expression's unit to f u1 u2 when both operand units are known.
    setF2 f u1 u2  = return $ maybeSetUnitInfoF2 f u1 u2 e
    -- Record constraint f u1 u2 (when both are known) and take u1 as the
    -- expression's unit (when known).
    setF2C f u1 u2 = return . maybeSetUnitInfo u1 $ maybeSetUnitConstraintF2 f u1 u2 e
propagateFunctionCall :: F.Expression UA -> UnitSolver (F.Expression UA)
-- | Annotate a function call with the unit of the callee's result at this
-- call site, and constrain each actual argument (see 'callHelper').
-- Only 'F.ExpFunctionCall' expressions are accepted.
propagateFunctionCall (F.ExpFunctionCall a s f mArgs) = do
  let actuals = case mArgs of
        Nothing               -> []
        Just (F.AList _ _ xs) -> xs
  (info, actuals') <- callHelper f actuals
  -- Put the constrained arguments back, keeping the list's own annotation.
  let relink (F.AList a' s' _) = F.AList a' s' actuals'
  return . setUnitInfo info $ F.ExpFunctionCall a s f (fmap relink mArgs)
propagateStatement :: F.Statement UA -> UnitSolver (F.Statement UA)
-- | Attach unit constraints to statements:
--
--   * an assignment equates the units of its two sides;
--   * a subroutine call with an argument list constrains its arguments;
--   * a declaration constrains its initialised declarators.
--
-- Any other statement is returned unchanged.
propagateStatement stmt = case stmt of
  F.StExpressionAssign _ _ lhs rhs ->
    return (maybeSetUnitConstraintF2 ConEq (getUnitInfo lhs) (getUnitInfo rhs) stmt)
  F.StCall a s sub (Just (F.AList a' s' args)) ->
    fmap (\ (_, args') -> F.StCall a s sub (Just (F.AList a' s' args'))) (callHelper sub args)
  F.StDeclaration {} ->
    transformBiM propagateDeclarator stmt
  _ ->
    return stmt
propagateDeclarator :: F.Declarator UA -> UnitSolver (F.Declarator UA)
-- | A declarator with an initialiser equates the declared entity's unit with
-- the initialiser's unit. Other declarators are returned unchanged.
propagateDeclarator decl = return $ case decl of
  F.DeclVariable _ _ var _ (Just ini) -> equate var ini
  F.DeclArray _ _ var _ _ (Just ini)  -> equate var ini
  _                                   -> decl
  where
    equate :: F.Expression UA -> F.Expression UA -> F.Declarator UA
    equate var ini = maybeSetUnitConstraintF2 ConEq (getUnitInfo var) (getUnitInfo ini) decl
propagatePU :: F.ProgramUnit UA -> UnitSolver (F.ProgramUnit UA)
-- | Build a program unit's constraint template.
--
-- The template contains every equality constraint inside the unit. Before
-- them go constraints tying each parameter that already has a
-- non-positional unit in the var/unit map to its positional parametric unit.
-- The template is stored in the template map under the unit's name and is
-- also attached to the unit as a conjunction.
propagatePU pu = do
  varMap <- gets usVarUnitMap
  let name      = puName pu
      bodyCons  = [ c | c@(ConEq {}) <- universeBi pu ]
      givenCons = [ ConEq u (UnitParamPosAbs (name, i))
                  | (i, param) <- indexedParams pu
                  , Just u <- [M.lookup param varMap]
                  , not (isPosAbs u) ]
      cons      = givenCons ++ bodyCons
  modifyTemplateMap (M.insert name cons)
  return (setConstraint (ConConj cons) pu)
  where
    isPosAbs :: UnitInfo -> Bool
    isPosAbs (UnitParamPosAbs {}) = True
    isPosAbs _                    = False
containsParametric :: Data from => String -> from -> Bool
-- | Does @x@ mention an abstract positional or variable parametric unit
-- belonging to the function called @name@?
containsParametric name x = any mentions (universeBi x)
  where
    mentions :: UnitInfo -> Bool
    mentions u = case u of
      UnitParamPosAbs (fname, _) -> fname == name
      UnitParamVarAbs (fname, _) -> fname == name
      _                          -> False
callHelper :: F.Expression UA -> [F.Argument UA] -> UnitSolver (UnitInfo, [F.Argument UA])
-- | Set up one call of the procedure named by @nexp@.
--
-- A fresh call id is allocated. Each actual argument whose unit is known
-- gets a constraint equating that unit with the callee's parameter at the
-- same position (arguments are numbered from 1). The returned unit is
-- position 0, i.e. the callee's result at this call.
callHelper nexp args = do
  callId <- genCallId
  let fname = varName nexp
      useAt p = UnitParamPosUse (fname, p, callId)
      bind p arg@(F.Argument _ _ _ e) =
        maybe arg (\ u -> setConstraint (ConEq u (useAt p)) arg) (getUnitInfo e)
  return (useAt 0, zipWith bind [1 ..] args)
genCallId :: UnitSolver Int
-- | Return the current call-id counter and advance it by one.
genCallId = do
  next <- gets usCallIds
  modify (\ st -> st { usCallIds = next + 1 })
  return next
genUnitLiteral :: UnitSolver UnitInfo
-- | A fresh 'UnitLiteral', numbered from the shared literal counter.
genUnitLiteral = do
  n <- gets usLitNums
  modify (\ st -> st { usLitNums = n + 1 })
  return (UnitLiteral n)
genParamLit :: UnitSolver UnitInfo
-- | A fresh 'UnitParamLitAbs', numbered from the same literal counter that
-- 'genUnitLiteral' uses.
genParamLit = do
  n <- gets usLitNums
  modify (\ st -> st { usLitNums = n + 1 })
  return (UnitParamLitAbs n)
getUnitInfo :: F.Annotated f => f UA -> Maybe UnitInfo
-- | The unit recorded in a node's annotation, if any.
getUnitInfo x = unitInfo (FA.prevAnnotation (F.getAnnotation x))
getConstraint :: F.Annotated f => f UA -> Maybe Constraint
-- | The unit constraint recorded in a node's annotation, if any.
getConstraint x = unitConstraint (FA.prevAnnotation (F.getAnnotation x))
getUnitInfoMul :: LiteralsOpt -> F.Expression UA -> Maybe UnitInfo
-- | Unit of an operand of multiplication or division. Except in 'LitPoly'
-- mode, a numeric literal operand counts as 'UnitlessLit'; otherwise this is
-- the expression's annotated unit.
getUnitInfoMul mode e = case mode of
  LitPoly -> getUnitInfo e
  _ | isJust (constantExpression e) -> Just UnitlessLit
    | otherwise                     -> getUnitInfo e
setUnitInfo :: F.Annotated f => UnitInfo -> f UA -> f UA
-- | Overwrite the unit stored in a node's annotation.
setUnitInfo info x = modifyAnnotation (onPrev store) x
  where
    store ua = ua { unitInfo = Just info }
setConstraint :: F.Annotated f => Constraint -> f UA -> f UA
-- | Overwrite the unit constraint stored in a node's annotation.
setConstraint c x = modifyAnnotation (onPrev store) x
  where
    store ua = ua { unitConstraint = Just c }
maybeSetUnitInfo :: F.Annotated f => Maybe UnitInfo -> f UA -> f UA
-- | 'setUnitInfo' when a unit is given; otherwise leave the node alone.
maybeSetUnitInfo mu e = maybe e (\ u -> setUnitInfo u e) mu
maybeSetUnitInfoF2 :: F.Annotated f => (a -> b -> UnitInfo) -> Maybe a -> Maybe b -> f UA -> f UA
-- | Set the node's unit to @f x y@ when both inputs are present; otherwise
-- leave the node alone.
maybeSetUnitInfoF2 f mx my e = case (mx, my) of
  (Just x, Just y) -> setUnitInfo (f x y) e
  _                -> e
maybeSetUnitConstraintF2 :: F.Annotated f => (a -> b -> Constraint) -> Maybe a -> Maybe b -> f UA -> f UA
-- | Set the node's constraint to @f x y@ when both inputs are present;
-- otherwise leave the node alone.
maybeSetUnitConstraintF2 f mx my e = case (mx, my) of
  (Just x, Just y) -> setConstraint (f x y) e
  _                -> e
fmapUnitInfo :: F.Annotated f => (UnitInfo -> UnitInfo) -> f UA -> f UA
-- | Apply @f@ to the node's unit, if it has one.
fmapUnitInfo f x = maybe x (\ u -> setUnitInfo (f u) x) (getUnitInfo x)
modifyPUBlocksM :: Monad m => ([F.Block a] -> m [F.Block a]) -> F.ProgramUnit a -> m (F.ProgramUnit a)
-- | Apply a monadic transformation to the body blocks of a program unit,
-- keeping every other field (including contained subprograms) unchanged.
modifyPUBlocksM f pu = case pu of
  F.PUMain a s n bs pus                    -> fmap (\ bs' -> F.PUMain a s n bs' pus) (f bs)
  F.PUModule a s n bs pus                  -> fmap (\ bs' -> F.PUModule a s n bs' pus) (f bs)
  F.PUSubroutine a s r n p bs subs         -> fmap (\ bs' -> F.PUSubroutine a s r n p bs' subs) (f bs)
  F.PUFunction   a s r rec n p res bs subs -> fmap (\ bs' -> F.PUFunction a s r rec n p res bs' subs) (f bs)
  F.PUBlockData  a s n bs                  -> fmap (\ bs' -> F.PUBlockData a s n bs') (f bs)
-- | Is the expression an integer or real literal?
isLiteral :: F.Expression a -> Bool
isLiteral e = case e of
  F.ExpValue _ _ (F.ValInteger _) -> True
  F.ExpValue _ _ (F.ValReal _)    -> True
  _                               -> False
constantExpression :: F.Expression a -> Maybe Double
-- | The numeric value of an integer or real literal. Returns 'Nothing' for
-- any other expression, and for literal text that cannot be understood.
--
-- Fortran literal syntax is wider than Haskell's. Reals may use a D or Q
-- exponent letter ("1.0d0"). They may omit digits on either side of the
-- point (".5", "1.", "1.e5"). Any literal may carry a kind suffix ("1_8",
-- "2.0_dp"). The text is therefore normalised before parsing, and parsing
-- is total ('reads' rather than 'read'): malformed text gives 'Nothing'
-- instead of a runtime error when the value is later forced.
constantExpression e = case e of
  F.ExpValue _ _ (F.ValInteger i) -> readDouble (stripKind i)
  F.ExpValue _ _ (F.ValReal r)    -> readDouble (normaliseReal (stripKind r))
  _                               -> Nothing
  where
    -- Drop a trailing "_kind" suffix.
    stripKind :: String -> String
    stripKind = takeWhile (/= '_')

    readDouble :: String -> Maybe Double
    readDouble s = case reads s of
      [(x, "")] -> Just x
      _         -> Nothing

    -- Rewrite a Fortran real literal into Haskell floating-point syntax:
    -- exponent letters become 'e', and a bare leading or trailing '.' in
    -- the mantissa gets a '0'.
    normaliseReal :: String -> String
    normaliseReal s =
      let (mantissa, expo) = break (`elem` "eE") (map toE s)
      in padTrailing (padLeading mantissa) ++ expo

    toE :: Char -> Char
    toE c | c `elem` "dDqQ" = 'e'
          | otherwise       = c

    padLeading :: String -> String
    padLeading ('.' : ds) = '0' : '.' : ds
    padLeading ds         = ds

    padTrailing :: String -> String
    padTrailing ds
      | take 1 (reverse ds) == "." = ds ++ "0"
      | otherwise                  = ds
isOp :: BinOpKind -> F.BinaryOp -> Bool
-- | Does the operator belong to the given category?
isOp cat op = binOpKind op == cat
-- | Coarse categories of Fortran binary operators, used when deciding which
-- unit rule an operator follows (see 'binOpKind').
data BinOpKind
  = AddOp    -- ^ additive-like operators
  | MulOp    -- ^ multiplication
  | DivOp    -- ^ division
  | PowerOp  -- ^ exponentiation
  | LogicOp  -- ^ logical connectives
  | RelOp    -- ^ comparisons
  deriving Eq
binOpKind :: F.BinaryOp -> BinOpKind
-- | Classify a binary operator.
--
-- NOTE(review): Equivalent/NotEquivalent and custom operators are classed as
-- relational, so 'propagateExp' equates their operands' units — confirm that
-- this is intended for logical equivalence.
binOpKind op = case op of
  F.Addition       -> AddOp
  F.Subtraction    -> AddOp
  F.Concatenation  -> AddOp
  F.Multiplication -> MulOp
  F.Division       -> DivOp
  F.Exponentiation -> PowerOp
  F.GT             -> RelOp
  F.GTE            -> RelOp
  F.LT             -> RelOp
  F.LTE            -> RelOp
  F.EQ             -> RelOp
  F.NE             -> RelOp
  F.Or             -> LogicOp
  F.And            -> LogicOp
  F.Equivalent     -> RelOp
  F.NotEquivalent  -> RelOp
  F.BinCustom _    -> RelOp
debugLogging :: UnitSolver ()
-- | In debug mode, write to the log:
--
--   * the abstract constraints;
--   * the var/unit map and the unit aliases;
--   * the final constraints;
--   * each function or subroutine's constraint template;
--   * the constraint matrix, its RREF, its rank, and the solver's
--     consistency, critical-variable and inference results.
--
-- Constraints are formatted with a total helper. The previous lambdas
-- matched only 'ConEq' and would crash on any other constraint form.
debugLogging = whenDebug $ do
    absCons <- extractConstraints
    tell . unlines $ map (\ c -> "  ***AbsConstraint: " ++ showCon c ++ "\n") absCons
    pf <- gets usProgramFile
    cons <- gets usConstraints
    vum <- gets usVarUnitMap
    tell . unlines $ [ "  " ++ show info ++ " :: " ++ n | (n, info) <- M.toList vum ]
    tell "\n\n"
    uam <- gets usUnitAliasMap
    tell . unlines $ [ "  " ++ n ++ " = " ++ show info | (n, info) <- M.toList uam ]
    tell . unlines $ map (\ c -> "  ***Constraint: " ++ showCon c ++ "\n") cons
    tell $ show cons ++ "\n\n"
    forM_ (universeBi pf) $ \ pu -> case pu of
      F.PUFunction {}   -> logTemplate pu
      F.PUSubroutine {} -> logTemplate pu
      _                 -> return ()
    let (unsolvedM, _, colA) = constraintsToMatrix cons
    let solvedM = rref unsolvedM
    tell "\n--------------------------------------------------\n"
    tell $ show colA
    tell "\n--------------------------------------------------\n"
    tell $ show unsolvedM
    tell "\n--------------------------------------------------\n"
    tell . show $ (H.takeRows (H.rank solvedM) solvedM)
    tell "\n--------------------------------------------------\n"
    tell $ "Rank: " ++ show (H.rank solvedM) ++ "\n"
    tell $ "Is inconsistent RREF? " ++ show (isInconsistentRREF solvedM) ++ "\n"
    tell $ "Inconsistent rows: " ++ show (inconsistentConstraints cons) ++ "\n"
    tell "--------------------------------------------------\n"
    tell $ "Critical Variables: " ++ show (criticalVariables cons) ++ "\n"
    tell $ "Infer Variables: " ++ show (inferVariables cons) ++ "\n"
  where
    -- Equalities print as flattened units; any other form falls back to 'show'.
    showCon :: Constraint -> String
    showCon (ConEq u1 u2) = show (flattenUnits u1) ++ " === " ++ show (flattenUnits u2)
    showCon c             = show c

    -- Print a program unit's constraint template if one is attached. This
    -- already runs inside 'whenDebug', so it needs no guard of its own.
    logTemplate :: F.ProgramUnit UA -> UnitSolver ()
    logTemplate pu
      | Just (ConConj puCons) <- getConstraint pu =
          tell . unlines $ (puName pu ++ ":") : map (\ c -> "    constraint: " ++ showCon c) puCons
      | otherwise = return ()
puName :: F.ProgramUnit UA -> F.Name
-- | The program unit's name, or "_nameless" for an unnamed unit.
puName pu = case FA.puName pu of
  F.Named n -> n
  _         -> "_nameless"
varName :: F.Expression UA -> F.Name
-- | The variable name of an expression, as computed by the Fortran analysis
-- module ('FA.varName').
varName e = FA.varName e