{-# OPTIONS -fno-warn-orphans #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE PatternGuards #-}

module Language.Fay.Compiler.Optimizer where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad.Error
import Control.Monad.Writer
import Control.Monad.State
import Data.List
import Data.Maybe
import Language.Fay.Print
import Language.Fay.Types
import Language.Haskell.Exts (QName(..),ModuleName(..),Name(..))
import Language.Haskell.Exts (SrcLoc(..))
import Prelude hiding (exp)

-- | A function name paired with its arity. Arity here is defined to
-- be the number of arguments that can be directly uncurried from a
-- curried lambda abstraction. So
-- @\\x y z -> if x then (\\a -> a) else (\\a -> a)@ has an arity of 3,
-- not 4: the lambdas hidden behind the @if@ cannot be merged into the
-- outer parameter list.
type FuncArity = (QName,Int)

-- | The optimizer monad: a plain 'State' computation threading an
-- 'OptState' through the optimization passes.
type Optimize = State OptState

-- | State carried through an optimizer run.
data OptState = OptState
  { optStmts   :: [JsStmt]
    -- ^ Statements the run was seeded with ('runOptimizer' stores its
    -- input here).
  , optUncurry :: [QName]
    -- ^ Names of functions for which an uncurried call has been
    -- emitted. 'walkAndStripForces' prepends to this list on every
    -- rewrite, so it may contain duplicates.
  }

-- | Run an optimizer over the given statements. The result is the
-- optimized statements followed by a tail-call-optimized uncurried
-- copy of every function the optimizer asked to have uncurried (each
-- function is copied at most once).
runOptimizer :: ([JsStmt] -> Optimize [JsStmt]) -> [JsStmt] -> [JsStmt]
runOptimizer optimizer stmts = optimized ++ tco uncurriedCopies
  where
    -- The optimizer starts from the input statements and an empty
    -- uncurry request list.
    (optimized, finalState) = runState (optimizer stmts) (OptState stmts [])
    -- Bindings are looked up in the *optimized* statements; requests
    -- without a matching binding are dropped.
    uncurriedCopies =
      mapMaybe (uncurryBinding optimized) (nub (optUncurry finalState))

-- | Entry point for the top-level, cross-module optimizations. At
-- present this is exactly the force-stripping / uncurrying pass, which
-- itself recurses into nested expressions.
optimizeToplevel :: [JsStmt] -> Optimize [JsStmt]
optimizeToplevel toplevel = stripAndUncurry toplevel

-- | Perform tail-call optimization.
--
-- Only variable bindings whose body has the exact shape
-- @JsFun params [] (Just (JsNew JsThunk [JsFun [] stmts ret]))@ are
-- touched; every other statement is returned unchanged. For such a
-- binding, each @return f(args)@ in tail position (directly in the
-- thunk body or nested inside 'JsIf' branches) where @f@ is the bound
-- name itself is replaced by updates of the parameters followed by
-- 'JsContinue', and the whole thunk body is wrapped in a
-- @while (true)@ loop.
--
-- A self call is only rewritten when it passes exactly as many
-- arguments as the function has parameters; anything else is left as
-- an ordinary return, because updating only some of the parameters
-- would change the meaning of the call.
--
-- NOTE(review): parameters are updated one after another, so an
-- argument expression that mentions a parameter updated earlier in the
-- same call would observe the new value. Confirm the code generator
-- never produces such arguments.
tco :: [JsStmt] -> [JsStmt]
tco = map inStmt where
  inStmt stmt = case stmt of
    JsMappedVar srcloc name exp -> JsMappedVar srcloc name (inject name exp)
    JsVar name exp -> JsVar name (inject name exp)
    e -> e

  -- Rewrite the thunk body of a function binding. The optional return
  -- expression is turned into a trailing early-return statement so
  -- that it is inspected like every other return.
  inject name exp = case exp of
    JsFun params [] (Just (JsNew JsThunk [JsFun [] stmts ret])) ->
      JsFun params
            []
            (Just
              (JsNew JsThunk
                     [JsFun []
                            (optimize params name (stmts ++ [ JsEarlyReturn e | Just e <- [ret] ]))
                            Nothing]))
    _ -> exp

  -- The writer log records one unit per rewritten tail call; when it
  -- is empty nothing was rewritten and no loop is introduced.
  optimize params name stmts =
    case runWriter (fmap concat (mapM swap stmts)) of
      (_, [])     -> stmts
      (looped, _) -> [JsWhile (JsLit (JsBool True)) looped]
    where
    swap stmt = case stmt of
      JsEarlyReturn e
        | Just args <- selfCallArgs e -> do
            tell [()]
            return (zipWith JsUpdate params args ++ [JsContinue])
      JsIf p ithen ielse -> do
        newithen <- fmap concat (mapM swap ithen)
        newielse <- fmap concat (mapM swap ielse)
        return [JsIf p newithen newielse]
      other -> return [other]

    -- The arguments of a saturated call to the function being
    -- optimized, or Nothing for any other expression.
    selfCallArgs (JsApp (JsName callee) args)
      | callee == name
      , length args == length params = Just args
    selfCallArgs _ = Nothing

-- | Remove redundant forcing throughout the generated code and turn
-- saturated applications of known functions into calls to their
-- uncurried copies.
stripAndUncurry :: [JsStmt] -> Optimize [JsStmt]
stripAndUncurry = applyToExpsInStmts rewrite
  where
    rewrite known = descend
      where
        -- Statements nested inside function bodies get the same
        -- treatment as top-level ones.
        inStmt = applyToExpsInStmt known rewrite

        descend expr = case expr of
          -- Forcing a bare reference to a known function is a no-op.
          JsApp (JsName JsForce) [JsName (JsNameVar fn)]
            | isJust (lookup fn known) -> return (JsName (JsNameVar fn))
          JsFun params body ret -> do
            body' <- mapM inStmt body
            ret'  <- case ret of
              Nothing -> return Nothing
              Just r  -> Just <$> descend r
            return (JsFun params body' ret')
          JsApp fn args -> do
            uncurried <- walkAndStripForces known expr
            case uncurried of
              -- The rewritten call is walked again so that its
              -- arguments are optimized too.
              Just rewritten -> descend rewritten
              Nothing -> do
                fn'   <- descend fn
                args' <- mapM descend args
                return (JsApp fn' args')
          JsNegApp e -> JsNegApp <$> descend e
          JsTernaryIf c t e -> do
            c' <- descend c
            t' <- descend t
            e' <- descend e
            return (JsTernaryIf c' t' e')
          JsParen e -> JsParen <$> descend e
          JsUpdateProp obj prop val -> do
            obj' <- descend obj
            val' <- descend val
            return (JsUpdateProp obj' prop val')
          JsList items -> JsList <$> mapM descend items
          JsEq l r -> JsEq <$> descend l <*> descend r
          JsInfix op l r -> JsInfix op <$> descend l <*> descend r
          JsObj fields ->
            JsObj <$> mapM (\(key, val) -> (,) key <$> descend val) fields
          JsNew ctor args -> JsNew ctor <$> mapM descend args
          -- Anything else is left untouched.
          other -> return other

-- | Try to rewrite a chain of single-argument applications into one
-- call of the uncurried copy of a known function.
--
-- The chain is unwound from the outside in, collecting arguments. If
-- its head is a function from the arity table and exactly that many
-- arguments were collected, the function is recorded in 'optUncurry'
-- and the uncurried call is returned; otherwise the result is
-- 'Nothing'. Forces met along the way are dropped, except that a force
-- wrapping the whole expression is re-applied to the rewritten call.
walkAndStripForces :: [FuncArity] -> JsExp -> Optimize (Maybe JsExp)
walkAndStripForces arities = descend True []
  where
    descend outermost collected expr = case expr of
      JsApp (JsName JsForce) [inner]
        | outermost -> fmap (fmap reforce) (descend False collected inner)
        | otherwise -> descend False collected inner
      JsApp fn [arg] -> descend False (arg : collected) fn
      JsName (JsNameVar fn)
        | lookup fn arities == Just (length collected) -> do
            modify (\st -> st { optUncurry = fn : optUncurry st })
            return (Just (JsApp (JsName (JsNameVar (renameUncurried fn))) collected))
      _ -> return Nothing

    reforce e = JsApp (JsName JsForce) [e]

-- | Run the given expression transformer over every statement in the
-- list. The arity table handed to the transformer is collected once
-- from the full statement list.
applyToExpsInStmts :: ([FuncArity] -> JsExp -> Optimize JsExp) -> [JsStmt] -> Optimize [JsStmt]
applyToExpsInStmts f stmts = mapM rewriteStmt stmts
  where
    arities = collectFuncs stmts
    rewriteStmt = applyToExpsInStmt arities f

-- | Run the given expression transformer over the expressions held
-- directly by a statement: the right-hand side of a variable binding,
-- the value of an early return, and the condition of an @if@ (whose
-- branches are processed recursively). Any other statement is
-- returned unchanged.
applyToExpsInStmt :: [FuncArity] -> ([FuncArity] -> JsExp -> Optimize JsExp) -> JsStmt -> Optimize JsStmt
applyToExpsInStmt funcs f = walk
  where
    onExp = f funcs

    walk (JsMappedVar loc name rhs) = JsMappedVar loc name <$> onExp rhs
    walk (JsVar name rhs)           = JsVar name <$> onExp rhs
    walk (JsEarlyReturn val)        = JsEarlyReturn <$> onExp val
    walk (JsIf cond yes no)         = do
      cond' <- onExp cond
      yes'  <- mapM walk yes
      no'   <- mapM walk no
      return (JsIf cond' yes' no')
    walk other                      = return other

-- | Collect functions and their arity from the whole codeset.
--
-- Every variable binding whose right-hand side has a positive
-- 'expArity' contributes an entry. A fixed table of primitives,
-- qualified with the module name @Fay$@, is appended after the
-- collected entries, so a collected binding wins over a primitive of
-- the same name under 'lookup'.
collectFuncs :: [JsStmt] -> [FuncArity]
collectFuncs stmts = concatMap collectFunc stmts ++ prim where
  -- A source-mapped binding is treated exactly like a plain one.
  collectFunc (JsMappedVar _ name exp) = collectFunc (JsVar name exp)
  collectFunc (JsVar (JsNameVar name) exp) | arity > 0 = [(name,arity)]
    where arity = expArity exp
  collectFunc _ = []
  prim = map (first (Qual (ModuleName "Fay$"))) (unary ++ binary)
  unary = map (,1) [Ident "return"]
  -- Each primitive is listed once.
  binary = map ((,2) . Ident)
               ["then","bind","mult","add","sub","div"
               ,"eq","neq","gt","lt","gte","lte","and","or"]

-- | Get the arity of an expression: the number of nested
-- single-parameter lambdas that consist of nothing but a returned
-- expression.
--
-- This is deliberately the same shape that 'uncurryBinding' peels off
-- when it builds an uncurried copy. Counting any other kind of
-- function layer (several parameters, body statements, or no returned
-- expression) would report an arity that the uncurried copy does not
-- actually accept, and calls rewritten on the strength of it would
-- pass the wrong number of arguments.
expArity :: JsExp -> Int
expArity (JsFun [_] [] (Just body)) = 1 + expArity body
expArity _ = 0

-- | Manual smoke test: run 'optimizeToplevel' over a hard-coded sample
-- and pretty-print first the optimized statements and then the
-- uncurried copies that were requested.
--
-- Unlike 'runOptimizer', this neither de-duplicates the requested
-- names nor applies 'tco' to the uncurried copies.
test :: IO ()
test = do
  let (newstmts,OptState _ uncurried) = flip runState st $ optimizeToplevel stmts
  putStrLn $ printJSPretty newstmts
  putStrLn $ printJSPretty (catMaybes (map (uncurryBinding newstmts) uncurried))

    where
      -- Initial state: the sample statements and no uncurry requests.
      st = OptState stmts []
      -- Sample input: a single source-mapped binding of
      -- @Main.sum$uncurried@, a two-parameter function whose thunk
      -- body ends in a call to itself.
      stmts = [JsMappedVar (SrcLoc {srcFilename = "", srcLine = 1, srcColumn = 1}) (JsNameVar (Qual (ModuleName "Main") (Ident "sum$uncurried"))) (JsFun [JsParam 1,JsParam 2] [] (Just (JsNew JsThunk [JsFun [] [JsVar (JsNameVar (UnQual (Ident "acc"))) (JsName (JsParam 2)),JsIf (JsEq (JsApp (JsName JsForce) [JsName (JsParam 1)]) (JsLit (JsInt 0))) [JsEarlyReturn (JsName (JsNameVar (UnQual (Ident "acc"))))] [],JsVar (JsNameVar (UnQual (Ident "acc"))) (JsName (JsParam 2)),JsVar (JsNameVar (UnQual (Ident "n"))) (JsName (JsParam 1)),JsEarlyReturn (JsApp (JsName (JsNameVar (Qual (ModuleName "Main") (Ident "sum$uncurried")))) [JsApp (JsName (JsNameVar (Qual (ModuleName "Fay$") (Ident "sub$uncurried")))) [JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "n")))],JsLit (JsInt 1)],JsApp (JsName (JsNameVar (Qual (ModuleName "Fay$") (Ident "add$uncurried")))) [JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "acc")))],JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "n")))]]])] Nothing])))]

-- | Build the uncurried copy of the function bound to the given name.
--
-- The first variable binding of that name found in the statements is
-- used; the copy is bound under the name produced by
-- 'renameUncurried'. Leading single-parameter lambdas that consist of
-- nothing but a returned expression are merged into one parameter
-- list; if there are none, the body ends up wrapped in a
-- zero-parameter function. Returns 'Nothing' when no binding of that
-- name exists.
uncurryBinding :: [JsStmt] -> QName -> Maybe JsStmt
uncurryBinding stmts qname = listToMaybe [ copy | Just copy <- map rebound stmts ]
  where
    uncurriedName = JsNameVar (renameUncurried qname)

    rebound (JsMappedVar loc (JsNameVar n) body)
      | n == qname = Just (JsMappedVar loc uncurriedName (flatten [] body))
    rebound (JsVar (JsNameVar n) body)
      | n == qname = Just (JsVar uncurriedName (flatten [] body))
    rebound _ = Nothing

    -- Parameters are accumulated innermost-first and reversed at the
    -- end to restore declaration order.
    flatten acc (JsFun [p] [] (Just inner)) = flatten (p : acc) inner
    flatten acc other = JsFun (reverse acc) [] (Just other)

-- | Derive the name of the uncurried copy of a curried function by
-- appending @$uncurried@ to the identifier or symbol. The module
-- qualifier, if any, is kept; names that are neither qualified nor
-- unqualified are returned unchanged.
renameUncurried :: QName -> QName
renameUncurried q
  | Qual m n <- q = Qual m (tag n)
  | UnQual n <- q = UnQual (tag n)
  | otherwise     = q
  where
    tag (Ident base)  = Ident (base ++ suffix)
    tag (Symbol base) = Symbol (base ++ suffix)
    suffix = "$uncurried"