{-# LANGUAGE FlexibleContexts
           , GADTs
           , ScopedTypeVariables
           , DataKinds
           , TypeOperators
           , OverloadedStrings
           , LambdaCase
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
---------------------------------------------------------------
module Language.Hakaru.Syntax.AST.Transforms where

import qualified Data.Sequence as S

import Language.Hakaru.Syntax.ANF      (normalize)
import Language.Hakaru.Syntax.CSE      (cse)
import Language.Hakaru.Syntax.Prune    (prune)
import Language.Hakaru.Syntax.Uniquify (uniquify)
import Language.Hakaru.Syntax.Hoist    (hoist)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.Prelude (lamWithVar, app)
import Language.Hakaru.Types.DataKind

import Language.Hakaru.Expect       (expectInCtx, determineExpect)
import Language.Hakaru.Disintegrate (determine, observeInCtx, disintegrateInCtx)
import Language.Hakaru.Inference    (mcmc', mh')
import Language.Hakaru.Maple        (sendToMaple, MapleOptions(..)
                                    ,defaultMapleOptions, MapleCommand(..)
                                    ,MapleException)

import Data.Ratio (numerator, denominator)
import Language.Hakaru.Types.Sing (sing, Sing(..), sUnFun)
import Language.Hakaru.Types.HClasses (HFractional(..))
import Language.Hakaru.Types.Coercion (findCoercion, Coerce(..))
import qualified Data.Sequence as Seq
import Control.Monad.Fix (MonadFix)
import Control.Monad (liftM)
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.State  (StateT(..), evalStateT, put, get, withStateT)
import Control.Applicative (Applicative(..), Alternative(..), (<$>), (<$))
import Data.Functor.Identity (Identity(..))

import Control.Exception (try)
import System.IO (stderr)
import Data.Text.Utf8 (hPutStrLn)
import Data.Text (pack)
import Data.Monoid (Monoid(..), (<>))

import Debug.Trace


-- | The standard optimisation pipeline for a closed term.
--
-- Passes run in this order: 'normalize', 'uniquify', 'hoist', 'cse',
-- 'prune', and finally 'uniquify' again.
optimizations
  :: (ABT Term abt)
  => abt '[] a
  -> abt '[] a
optimizations e =
  let normalized = normalize e
      -- The hoist pass needs globally unique identifiers, so every binder
      -- is made unique immediately before hoisting.
      hoisted    = hoist (uniquify normalized)
  in uniquify (prune (cse hoisted))

-- | Apply a monadic rewrite @f@ beneath the outer binder of a
-- function-typed term, returning a term of the same type.
--
-- Cases handled (see the @case@ below):
--
--   * a variable is returned unchanged;
--   * @Lam_@: @f@ is run on the lambda body and the lambda is rebuilt
--     around the result with the same binder;
--   * @Let_@: if the bound expression has the same type as the whole term,
--     recurse into the bound expression; otherwise recurse into the let body;
--   * any other term form calls 'error'.
--
-- NOTE(review): unlike 'underLam'p', the variable case does not apply @f@
-- at all (no eta-expansion) — confirm this is intended.
underLam
    :: (ABT Term abt, Monad m)
    => (abt '[] b -> m (abt '[] b))
    -> abt '[] (a ':-> b)
    -> m (abt '[] (a ':-> b))
underLam f e = caseVarSyn e (return . var) $ \t ->
                   case t of
                   -- A literal lambda: rewrite its body, keep its binder.
                   Lam_ :$ e1 :* End ->
                       caseBind e1 $ \x e1' -> do
                           e1'' <- f e1'
                           return . syn $
                                  Lam_  :$ (bind x e1'' :* End)

                   -- A let whose body is the function.
                   Let_ :$ e1 :* e2 :* End ->
                        case jmEq1 (typeOf e1) (typeOf e) of
                          -- The bound expression has the whole term's type:
                          -- recurse into it.  NOTE(review): the let body e2
                          -- is kept unchanged here; verify this rewrites the
                          -- function the caller expects.
                          Just Refl -> do
                               e1' <- underLam f e1
                               return . syn $
                                      Let_ :$ e1' :* e2 :* End
                          -- Otherwise recurse into the let body under its
                          -- binder.
                          Nothing   -> caseBind e2 $ \x e2' -> do
                                         e2'' <- underLam f e2'
                                         return . syn $
                                                Let_ :$ e1 :* (bind x e2'') :* End

                   -- Partial: any other term form is not supported yet.
                   _ -> error "TODO: underLam"

-- | Monadic variant of 'underLam'p': post-compose a function-typed term with
-- a rewrite @f@ that may run effects in @m@.
--
-- @f@ is run exactly once, under a fresh binder whose type is the function's
-- result type (obtained with 'binderM'), producing the lambda @\\y -> f y@.
-- 'underLam'p' is then given a pure rewrite that applies that lambda to
-- whatever result expression it finds, i.e. it builds a beta-redex rather
-- than re-running @f@.
--
-- Leftover 'Debug.Trace.trace' calls were removed: they wrote to stderr on
-- every invocation of this library function.
underLam'
    :: forall abt m a b b'
     . (ABT Term abt, MonadFix m)
    => (abt '[] b -> m (abt '[] b'))
    -> abt '[] (a ':-> b)
    -> m (abt '[] (a ':-> b'))
underLam' f e = do
  -- Body of @\\y -> f y@, with @y@ of type @b@ (the result type of @e@).
  fBody <- binderM "" (snd $ sUnFun $ typeOf e) f
  let applyF y = app (syn $ Lam_ :$ fBody :* End) y
  return $ underLam'p applyF e

-- | Post-compose a function-typed term with the pure rewrite @f@, so the
-- result computes @f@ of the original function's result.
--
-- Cases handled:
--
--   * a variable @v@ is eta-expanded to @\\x -> f (v x)@ (via 'lamWithVar');
--   * @Lam_@: @f@ is applied to the lambda body, keeping the binder;
--   * @Let_@: the bound expression is kept and the rewrite recurses into the
--     let body under its binder;
--   * any other term form calls 'error'.
--
-- Leftover 'Debug.Trace.trace' calls (several with mislabelled messages, e.g.
-- the @Let_@ branch reporting @syn/Lam_@) were removed: they wrote to stderr
-- on every call.
underLam'p
    :: forall abt a b b'
     . (ABT Term abt)
    => (abt '[] b -> abt '[] b')
    -> abt '[] (a ':-> b)
    -> abt '[] (a ':-> b')
underLam'p f e =
  let -- A bare function variable: eta-expand so there is a body to rewrite.
      var_ :: Variable (a ':-> b) -> abt '[] (a ':-> b')
      var_ v_ab =
        lamWithVar "" (fst $ sUnFun $ varType v_ab) $ \a ->
          f $ app (var v_ab) a

      syn_ t =
        case t of
        -- A literal lambda: rewrite its body in place.
        Lam_ :$ e1 :* End ->
          caseBind e1 $ \x e1' ->
            syn $ Lam_ :$ bind x (f e1') :* End

        -- A let whose body is the function: recurse into the body.
        Let_ :$ e1 :* e2 :* End ->
          caseBind e2 $ \x e2' ->
            syn $ Let_ :$ e1 :* bind x (go e2') :* End

        _ -> error "TODO: underLam'"

      go e' = caseVarSyn e' var_ syn_
  in go e

--------------------------------------------------------------------------------

-- | Expand all 'Transform_' nodes that have a pure Haskell implementation
-- ('haskellTransformations'), running the expansion in 'Identity'.
expandTransformations
    :: forall abt a
    . (ABT Term abt)
    => abt '[] a -> abt '[] a
expandTransformations term =
  runIdentity $ expandTransformationsWith haskellTransformations term

-- | Expand 'Transform_' nodes using every available table
-- ('allTransformations'), including the Maple-backed ones, hence 'IO'.
expandAllTransformations
    :: forall abt a
    . (ABT Term abt)
    => abt '[] a -> IO (abt '[] a)
expandAllTransformations term =
  expandTransformationsWith allTransformations term

-- | Pure form of 'expandTransformationsWith' for tables that run in
-- 'Identity'.
expandTransformationsWith'
    :: forall abt a
    . (ABT Term abt)
    => TransformTable abt Identity
    -> abt '[] a -> abt '[] a
expandTransformationsWith' table term =
  runIdentity (expandTransformationsWith table term)

-- | Monad used while expanding transformations: a 'StateT' layer that
-- threads a 'TransformCtx' through the traversal on top of the base monad.
type TransformM = StateT TransformCtx

-- | Expand the 'Transform_' nodes of a term using the given table.
--
-- The term is traversed with 'cataABTM' while a 'TransformCtx' is threaded as
-- state.  For each 'Transform_' node the table is consulted with the current
-- context.  If it returns 'Nothing', the node is kept as is.  Otherwise the
-- node is replaced by the result, and the result's context is merged into
-- the state.  All other nodes are rebuilt unchanged.
expandTransformationsWith
    :: forall abt a m
    . (ABT Term abt, Applicative m, Monad m)
    => TransformTable abt m
    -> abt '[] a -> m (abt '[] a)
expandTransformationsWith tbl t0 =
  evalStateT (cataABTM (pure . var) bind_ (>>= syn_) t0) startCtx
    where
    -- Initial state: an empty context whose 'nextFreeVar' is seeded with
    -- 'nextFreeOrBind' of the input term.
    startCtx :: TransformCtx
    startCtx = mempty { nextFreeVar = nextFreeOrBind t0 }

    -- Under a binder, extend the state with the bound variable's context
    -- before processing the body, then rebind the variable.  Note that
    -- 'withStateT' does not restore the previous state afterwards.
    bind_ :: forall x xs b
           . Variable x
          -> TransformM m (abt xs b)
          -> TransformM m (abt (x ': xs) b)
    bind_ v body = fmap (bind v) (withStateT (\c -> ctxOf v <> c) body)

    -- Only 'Transform_' nodes are looked up in the table.
    syn_ :: forall b. Term abt b -> TransformM m (abt '[] b)
    syn_ t@(Transform_ tr :$ as) = do
      ctx    <- get
      result <- lift (lookupTransform' tbl tr ctx as)
      case result of
        Nothing -> pure (syn t)
        Just r  -> do
          put (ctxOf r <> ctx)
          pure r
    syn_ t = pure (syn t)

-- | Transformations delegated to Maple, using the given base options.
--
-- Handles 'Simplify', 'Summarize', 'Reparam' and @'Disint' 'InMaple'@.  Each
-- sends its single argument to Maple with the transform as the command and
-- the current 'TransformCtx' as context.  A 'MapleException' is printed to
-- stderr and the transform yields 'Nothing'.  Every other transform is not
-- handled by this table.
mapleTransformationsWithOpts
  :: forall abt
   . ABT Term abt
  => MapleOptions ()
  -> TransformTable abt IO
mapleTransformationsWithOpts opts = TransformTable $ \tr ->
  let viaMaple :: Transform '[LC i] o -> TransformCtx
               -> abt '[] i -> IO (Maybe (abt '[] o))
      viaMaple t ctx expr = do
        outcome <- try $ sendToMaple opts{command=MapleCommand t
                                         ,context=ctx} expr
        case outcome of
          Left (err :: MapleException) -> do
            hPutStrLn stderr (pack (show err))
            pure Nothing
          Right r -> pure (Just r)
  in case tr of
       Simplify       ->
         Just $ \ctx -> \case { arg :* End -> viaMaple tr ctx arg }
       Summarize      ->
         Just $ \ctx -> \case { arg :* End -> viaMaple tr ctx arg }
       Reparam        ->
         Just $ \ctx -> \case { arg :* End -> viaMaple tr ctx arg }
       Disint InMaple ->
         Just $ \ctx -> \case { arg :* End -> viaMaple tr ctx arg }
       _              -> Nothing

-- | The Maple-backed transformation table with 'defaultMapleOptions'.
mapleTransformations
  :: ABT Term abt
  => TransformTable abt IO
mapleTransformations =
  let opts = defaultMapleOptions
  in mapleTransformationsWithOpts opts

-- | Transformations implemented in Haskell.
--
-- Handles 'Expect', 'Observe', 'MCMC', 'MH' and @'Disint' 'InHaskell'@; every
-- other transform yields 'Nothing' from this table.
haskellTransformations :: (Applicative m, ABT Term abt) => TransformTable abt m
haskellTransformations = simpleTable $ \case
    Expect -> Just $ \ctx args -> case args of
      lhs :* rhs :* End -> determineExpect $ expectInCtx ctx lhs rhs

    Observe -> Just $ \ctx args -> case args of
      lhs :* rhs :* End -> determine $ observeInCtx ctx lhs rhs

    MCMC -> Just $ \ctx args -> case args of
      lhs :* rhs :* End -> mcmc' ctx lhs rhs

    MH -> Just $ \ctx args -> case args of
      lhs :* rhs :* End -> mh' ctx lhs rhs

    Disint InHaskell -> Just $ \ctx args -> case args of
      arg :* End -> determine $ disintegrateInCtx ctx arg

    _ -> Nothing

-- | Union of the Maple-backed table (built with the given options) and the
-- Haskell table, with the Maple table as the first argument of 'unionTable'.
allTransformationsWithMOpts
   :: ABT Term abt
   => MapleOptions ()
   -> TransformTable abt IO
allTransformationsWithMOpts opts =
  mapleTransformationsWithOpts opts `unionTable` haskellTransformations

-- | Every transformation table, using 'defaultMapleOptions' for Maple.
allTransformations :: ABT Term abt => TransformTable abt IO
allTransformations = allTransformationsWithMOpts opts
  where opts = defaultMapleOptions

--------------------------------------------------------------------------------

-- | Flatten nested uses of the same n-ary operator at the root of a term.
--
-- If the term is an 'NaryOp_' node, its arguments are passed through
-- 'coalesceNaryOp'.  Variables and all other terms are returned unchanged.
coalesce
  :: forall abt a
  .  (ABT Term abt)
  => abt '[] a
  -> abt '[] a
coalesce abt = caseVarSyn abt var $ \term ->
  case term of
    NaryOp_ op args -> syn $ NaryOp_ op (coalesceNaryOp op args)
    _               -> syn term

-- | Splice the arguments of nested same-operator nodes into one sequence.
--
-- Each argument that is itself an 'NaryOp_' with the same operator is
-- replaced, recursively, by its own flattened arguments.  An 'NaryOp_' with a
-- different operator is passed to 'coalesce'.  Any other argument is kept
-- as is.
coalesceNaryOp
  :: ABT Term abt
  => NaryOp a
  -> S.Seq (abt '[] a)
  -> S.Seq (abt '[] a)
coalesceNaryOp typ args = args >>= \arg ->
  case viewABT arg of
    Syn (NaryOp_ typ' args')
      | typ == typ' -> coalesceNaryOp typ args'
      | otherwise   -> pure (coalesce arg)
    _ -> pure arg