{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE ViewPatterns  #-}

-- | Utility functions used by the normalisation transformations
module CLaSH.Normalize.Util where

import           Control.Lens            ((%=), (.=))
import qualified Control.Lens            as Lens
import qualified Data.Either             as Either
import qualified Data.Graph              as Graph
import           Data.HashMap.Lazy       (HashMap)
import qualified Data.HashMap.Lazy       as HashMap
import qualified Data.List               as List
import qualified Data.Maybe              as Maybe
import qualified Data.Set                as Set
import           Unbound.LocallyNameless (Fresh, unembed)

import           CLaSH.Core.FreeVars     (termFreeIds)
import           CLaSH.Core.Term         (Term (..), TmName)
import           CLaSH.Core.Type         (Type (..), splitFunForallTy)
import           CLaSH.Core.Util         (collectArgs, termType)
import           CLaSH.Core.Var          (Id, Var (..))
import           CLaSH.Netlist.Util      (splitNormalized)
import           CLaSH.Normalize.Types
import           CLaSH.Rewrite.Types
import           CLaSH.Rewrite.Util

-- | Determine if a function is already inlined in the context of the
-- 'NormalizeMonad'.
--
-- Looks up the 'inlined' cache entry for the current function ('curFun') and
-- checks whether @f@ is listed in it. Returns 'False' when the current function
-- has no entry in the cache yet.
alreadyInlined :: TmName
               -> NormalizeMonad Bool
alreadyInlined f = do
  cf        <- Lens.use curFun
  inlinedHM <- Lens.use inlined
  return $ maybe False (f `elem`) (HashMap.lookup cf inlinedHM)

-- | Move the names of inlined functions collected during a traversal into the
-- permanent inlined function cache.
--
-- The names currently held in 'newInlined' are merged into the 'inlined' entry
-- of the current function ('curFun'):
--
--   * if there is no entry yet, the new names become the entry;
--   * otherwise the new names are appended with 'List.union', keeping the
--     existing names first and skipping names that are already present.
--
-- Afterwards 'newInlined' is reset to the empty list. The term is returned
-- unchanged.
commitNewInlined :: NormRewrite
commitNewInlined _ e = R $ liftR $ do
  freshNames <- Lens.use newInlined
  current    <- Lens.use curFun
  -- insertWith passes (new value, old value); keep the old names in front
  inlined    %= HashMap.insertWith (\new old -> old `List.union` new) current freshNames
  newInlined .= []
  return e

-- | Determine if a term is closed.
--
-- The check is made on the term's type: the type is split with
-- 'splitFunForallTy', and the term counts as closed when none of the split-off
-- components is a 'Left'. A type with a 'Left' component is treated as a
-- polymorphic function type. NOTE(review): 'Left' entries are presumably the
-- forall-bound type variables; confirm against 'splitFunForallTy'.
isClosed :: (Functor m, Fresh m)
         => Term
         -> m Bool
isClosed tm = fmap hasNoTyBinders (termType tm)
  where
    -- True when the split type carries no 'Left' (type-binder) components
    hasNoTyBinders ty = case splitFunForallTy ty of
      (parts, _) -> null (Either.lefts parts)

-- | Determine if a term represents a constant.
--
-- The term is split into a head and its arguments with 'collectArgs'. It is a
-- constant when the head is
--
--   * a 'Literal' (its arguments are not inspected), or
--   * a data constructor ('Data') or a primitive ('Prim') whose term arguments
--     are all, recursively, constants. Type arguments never disqualify it.
--
-- Any other head makes the term non-constant.
isConstant :: Term -> Bool
isConstant term =
    let (hd, args)   = collectArgs term
        argsConstant = all (either isConstant (const True)) args
    in  case hd of
          Literal _ -> True
          Data _    -> argsConstant
          Prim _ _  -> argsConstant
          _         -> False

-- | Get the \"Wrapped\" function out of a normalized Term. Returns 'Nothing' if
-- the normalized term is not actually a wrapper.
--
-- A term is considered a wrapper when 'splitNormalized' succeeds with exactly
-- one binding, and that binding's expression is an application whose term
-- arguments end with the wrapper's own arguments, in order.
--
-- Any leading term arguments stay applied to the returned function, as do the
-- type arguments. Leading arguments that mention one of the wrapper's own
-- arguments are rejected. Otherwise the returned term would refer to a
-- variable bound only by the wrapper, making it an open term.
getWrappedF :: (Fresh m,Functor m) => Term -> m (Maybe Term)
getWrappedF body = do
    normalizedM <- splitNormalized body
    case normalizedM of
      Right (funArgs,[(_,bExpr)],_) -> return $! uncurry (reduceArgs True funArgs) (collectArgs $ unembed bExpr)
      _                             -> return Nothing
  where
    -- Walk the collected arguments left to right, rebuilding the wrapped head.
    --   * The Bool is True while still in the leading prefix, i.e. before any
    --     wrapper argument has been matched.
    --   * Type arguments are always re-applied to the head.
    --   * A term argument that is exactly the next wrapper argument is dropped
    --     and ends the prefix.
    --   * In the prefix, other term arguments are re-applied to the head, but
    --     only if they do not mention any wrapper argument.
    -- Succeeds only when all wrapper arguments and all application arguments
    -- have been consumed.
    reduceArgs :: Bool -> [Id] -> Term -> [Either Term Type] -> Maybe Term
    reduceArgs _    []    appE []                         = Just appE
    reduceArgs _    (_:_) _ []                            = Nothing
    reduceArgs b    ids       appE (Right ty:args)        = reduceArgs b ids (TyApp appE ty) args
    reduceArgs _    (id1:ids) appE (Left (Var _ nm):args) | varName id1 == nm = reduceArgs False ids appE args
    reduceArgs True ids@(_:_) appE (Left arg:args)
      | not (mentionsAny ids arg)                         = reduceArgs True ids (App appE arg) args
    reduceArgs _ _ _ _                                    = Nothing

    -- Does the term have any of the given binders among its free term ids?
    -- While in the prefix nothing has been matched yet, so 'ids' is still the
    -- full list of wrapper arguments.
    mentionsAny :: [Id] -> Term -> Bool
    mentionsAny ids tm = any ((`Set.member` termFreeIds tm) . varName) ids

-- | Create a call graph for a set of global binders, given a root.
--
-- The result is an adjacency list with one @(binder, [called binders])@ entry
-- for the root and for every binder reachable from it, in depth-first
-- pre-order (the root comes first). The called binders of an entry are the
-- free term ids of that binder's body ('termFreeIds').
--
-- Binders in the @visited@ list still appear as callees, but they are never
-- expanded. The root is always expanded.
--
-- A single seen-set is threaded through the whole traversal, so each binder is
-- expanded at most once. A binder reachable along several paths, or a
-- self-recursive root, therefore produces exactly one entry, and shared
-- sub-graphs are not re-traversed once per path.
--
-- Calls 'error' when a binder that has to be expanded is missing from the
-- binder map.
callGraph :: [TmName] -- ^ List of functions that should not be inspected
          -> HashMap TmName Term -- ^ Global binders
          -> TmName -- ^ Root of the call graph
          -> [(TmName,[TmName])]
callGraph visited bindingMap root = snd (visit (Set.fromList visited) root)
  where
    -- Body of a global binder; a missing binder is a caller error.
    bodyOf nm = Maybe.fromMaybe (error $ show nm ++ " is not a global binder")
                                (HashMap.lookup nm bindingMap)

    -- Emit the entry for a binder, then expand the callees not seen so far.
    -- Returns the updated seen-set together with the emitted entries.
    visit seen nm = (seen', (nm,used) : nodes)
      where
        used           = Set.toList (termFreeIds (bodyOf nm))
        (seen', nodes) = visitAll (Set.insert nm seen) used

    -- Expand callees left to right, threading the seen-set so that later
    -- siblings skip binders that earlier siblings have already expanded.
    visitAll seen []     = (seen, [])
    visitAll seen (c:cs)
      | c `Set.member` seen = visitAll seen cs
      | otherwise           = let (seen1, ns1) = visit seen c
                                  (seen2, ns2) = visitAll seen1 cs
                              in  (seen2, ns1 ++ ns2)

-- | Determine the sets of recursive components given the edges of a callgraph.
--
-- Every edge-list entry becomes a graph node keyed by the calling function.
-- The result contains the strongly connected components that form a cycle:
-- either a group of mutually recursive functions, or a single function that
-- calls itself. Functions that are not part of any cycle are omitted.
recursiveComponents :: [(TmName,[TmName])] -- ^ [(calling function,[called function])]
                    -> [[TmName]]
recursiveComponents edges =
  -- The pattern in the generator drops every 'Graph.AcyclicSCC'
  [ vs
  | Graph.CyclicSCC vs <- Graph.stronglyConnComp [ (n,n,es) | (n,es) <- edges ]
  ]