module Agda.Compiler.Treeless.Builtin (translateBuiltins) where
import qualified Agda.Syntax.Internal as I
import Agda.Syntax.Abstract.Name (QName)
import Agda.Syntax.Position
import Agda.Syntax.Treeless
import Agda.Syntax.Literal
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Builtin
import Agda.Compiler.Treeless.Subst () 
import Agda.Utils.Impossible
-- | A bundle of name tests, one per builtin or primitive that 'transform'
--   treats specially.  The tests are built by 'builtinKit'; a test for
--   something that could not be resolved to a name rejects every 'QName'.
--   The field order matters: 'builtinKit' fills the record positionally.
data BuiltinKit = BuiltinKit
  { isZero   :: QName -> Bool  -- ^ Constructor bound to @builtinZero@.
  , isSuc    :: QName -> Bool  -- ^ Constructor bound to @builtinSuc@.
  , isPos    :: QName -> Bool  -- ^ Constructor bound to @builtinIntegerPos@.
  , isNegSuc :: QName -> Bool  -- ^ Constructor bound to @builtinIntegerNegSuc@.
  , isPlus   :: QName -> Bool  -- ^ Definition bound to @builtinNatPlus@.
  , isTimes  :: QName -> Bool  -- ^ Definition bound to @builtinNatTimes@.
  , isLess   :: QName -> Bool  -- ^ Definition bound to @builtinNatLess@.
  , isEqual  :: QName -> Bool  -- ^ Definition bound to @builtinNatEquals@.
  , isForce  :: QName -> Bool  -- ^ Name of the primitive @primForce@.
  , isWord64FromNat :: QName -> Bool  -- ^ Name of the primitive @primWord64FromNat@.
  , isWord64ToNat   :: QName -> Bool  -- ^ Name of the primitive @primWord64ToNat@.
  }
-- | Look up every builtin and primitive that 'transform' cares about and
--   package the resulting name tests in a 'BuiltinKit'.
--
--   Builtins are queried with 'getBuiltin'' and primitives with
--   'getPrimitive''.  When a lookup yields nothing, or the bound term is not
--   of the expected form (constructor or definition), the corresponding test
--   rejects every name.
builtinKit :: TCM BuiltinKit
builtinKit = do
  zeroTest    <- builtinTest conNameOf builtinZero
  sucTest     <- builtinTest conNameOf builtinSuc
  posTest     <- builtinTest conNameOf builtinIntegerPos
  negSucTest  <- builtinTest conNameOf builtinIntegerNegSuc
  plusTest    <- builtinTest defNameOf builtinNatPlus
  timesTest   <- builtinTest defNameOf builtinNatTimes
  lessTest    <- builtinTest defNameOf builtinNatLess
  equalTest   <- builtinTest defNameOf builtinNatEquals
  forceTest   <- primitiveTest primNameOf "primForce"
  fromNatTest <- primitiveTest primNameOf "primWord64FromNat"
  toNatTest   <- primitiveTest primNameOf "primWord64ToNat"
  return BuiltinKit
    { isZero          = zeroTest
    , isSuc           = sucTest
    , isPos           = posTest
    , isNegSuc        = negSucTest
    , isPlus          = plusTest
    , isTimes         = timesTest
    , isLess          = lessTest
    , isEqual         = equalTest
    , isForce         = forceTest
    , isWord64FromNat = fromNatTest
    , isWord64ToNat   = toNatTest
    }
  where
    -- Name of a constructor term, if the term is a constructor.
    conNameOf term = case term of
      I.Con c _ _ -> Just (I.conName c)
      _           -> Nothing

    -- Name of a definition term, if the term is a definition.
    defNameOf term = case term of
      I.Def d _ -> Just d
      _         -> Nothing

    -- Name of a primitive function; always available.
    primNameOf prim = Just (primFunName prim)

    -- Turn an optional lookup result into an equality test on the extracted
    -- name; no result or no name means the test is constantly False.
    nameTest extract found = case found >>= extract of
      Just name -> (name ==)
      Nothing   -> const False

    builtinTest   extract b = nameTest extract <$> getBuiltin' b
    primitiveTest extract p = nameTest extract <$> getPrimitive' p
-- | Rewrite the recognised builtins in a treeless term: resolve the builtin
--   names once via 'builtinKit', then run the pure 'transform' pass with them.
translateBuiltins :: TTerm -> TCM TTerm
translateBuiltins term = (`transform` term) <$> builtinKit
-- | Rewrite a treeless term so that everything recognised by the given
--   'BuiltinKit' is expressed with integer literals, primitives and guarded
--   case alternatives:
--
--   * the zero/suc/pos/negsuc constructors, both applied and unapplied;
--   * the builtin plus, times, less-than and equality functions;
--   * @primForce@ and the two Word64 conversion primitives;
--   * case alternatives matching on those constructors.
--
--   All other terms are traversed structurally.
transform :: BuiltinKit -> TTerm -> TTerm
transform BuiltinKit{..} = tr
  where
    tr t = case t of

      -- Unapplied constructors: a literal for zero, a lambda for the rest.
      TCon c | isZero c   -> tInt 0
             | isSuc c    -> TLam (tPlusK 1 (TVar 0))
             | isPos c    -> TLam (TVar 0)
             | isNegSuc c -> TLam $ tNegPlusK 1 (TVar 0)

      -- Recognised functions map directly onto primitives.  Any other
      -- definition falls through to the generic 'TDef' case further down.
      TDef f | isPlus f   -> TPrim PAdd
             | isTimes f  -> TPrim PMul
             | isLess f   -> TPrim PLt
             | isEqual f  -> TPrim PEqI
             | isWord64ToNat f   -> TPrim P64ToI
             | isWord64FromNat f -> TPrim PITo64

      -- primForce applied to at least six arguments.  The first four are
      -- dropped, the fifth (@e@) is let-bound and sequenced with 'PSeq' before
      -- the sixth (@f@) is applied to it.  Both @f@ and the surplus arguments
      -- @es@ end up under the new let binder, so both have to be raised.
      TApp (TDef q) (_ : _ : _ : _ : e : f : es)
        | isForce q ->
            tr $ TLet e $
              mkTApp (tOp PSeq (TVar 0) $ mkTApp (raise 1 f) [TVar 0])
                     (map (raise 1) es)

      -- @suc e@: fold into a literal or into an existing @_ + k@ when possible.
      TApp (TCon s) [e] | isSuc s ->
        case tr e of
          TLit (LitNat _ n) -> tInt (n + 1)
          e' | Just (i, e'') <- plusKView e' -> tPlusK (i + 1) e''
             | otherwise                     -> tPlusK 1 e'

      -- @pos e@ is just @e@; @negsuc e@ is built with 'tNegPlusK', folding
      -- literals and existing @_ + k@ terms as for @suc@.
      TApp (TCon c) [e]
        | isPos c    -> tr e
        | isNegSuc c ->
        case tr e of
          TLit (LitNat _ n) -> tInt (-n - 1)
          e' | Just (i, e'') <- plusKView e' -> tNegPlusK (i + 1) e''
             | otherwise                     -> tNegPlusK 1 e'

      -- Case: constructor alternatives on the recognised constructors become
      -- literal or guard alternatives on the scrutinee variable @e@.
      TCase e t d bs -> TCase e (inferCaseType t bs) (tr d) $ concatMap trAlt bs
        where
          trAlt b = case b of
            TACon c 0 b | isZero c -> [TALit (LitNat noRange 0) (tr b)]

            TACon c 1 b | isSuc c  ->
              case tr b of
                -- The body immediately cases on the predecessor (index 0):
                -- collapse the nested alternatives into this case.
                TCase 0 _ d bs' -> map sucBranch bs' ++ [nPlusKAlt 1 d]
                b -> [nPlusKAlt 1 b]
              where
                -- Nested literal @i@ on the predecessor is literal @i + 1@ on
                -- the scrutinee; the let re-binds the predecessor for the body.
                sucBranch (TALit (LitNat r i) b) = TALit (LitNat r (i + 1)) $ TLet (tInt i) b
                -- Nested @m + k@ on the predecessor is @m + (k + 1)@ on the
                -- scrutinee.  The old body expects @m@ at index 0 and the
                -- predecessor at index 1.  'nPlusKAlt' binds @m@, the extra
                -- let rebuilds the predecessor as @m + k@, and the
                -- substitution swaps the two indices to fit the new binder
                -- order.
                sucBranch alt | Just (k, b) <- nPlusKView alt =
                  nPlusKAlt (k + 1) $ TLet (tOp PAdd (TVar 0) (tInt k)) $
                    applySubst ([TVar 1, TVar 0] ++# wkS 2 idS) b
                sucBranch _ = __IMPOSSIBLE__

                -- Guard @e >= k@ with the body run under @let (e - k)@.
                nPlusKAlt k b = TAGuard (tOp PGeq (TVar e) (tInt k)) $
                                TLet (tOp PSub (TVar e) (tInt k)) b

            TACon c 1 b | isPos c ->
              case tr b of
                -- The body immediately cases on the pos argument: collapse.
                TCase 0 _ d bs -> map sub bs ++ [posAlt d]
                b -> [posAlt  b]
              where
                -- Substitute the scrutinee variable for the pos argument.
                sub :: Subst TTerm a => a -> a
                sub = applySubst (TVar e :# IdS)

                posAlt b = TAGuard (tOp PGeq (TVar e) (tInt 0)) $ sub b

            TACon c 1 b | isNegSuc c ->
              case tr b of
                -- The body immediately cases on the negsuc argument: collapse.
                TCase 0 _ d bs -> map negsucBranch bs ++ [negAlt d]
                b -> [negAlt b]
              where
                -- Bind the negsuc argument, computed from the scrutinee.
                body b   = TLet (tNegPlusK 1 (TVar e)) b
                negAlt b = TAGuard (tOp PLt (TVar e) (tInt 0)) $ body b

                negsucBranch (TALit (LitNat r i) b) = TALit (LitNat r (-i - 1)) $ body b
                -- The scrutinee sits one binder further out inside 'body',
                -- hence @e + 1@.
                negsucBranch alt | Just (k, b) <- nPlusKView alt =
                  TAGuard (tOp PLt (TVar e) (tInt (-k))) $
                  body $ TLet (tNegPlusK (k + 1) (TVar $ e + 1)) b
                negsucBranch _ = __IMPOSSIBLE__

            -- Everything else is only traversed.
            TACon c a b -> [TACon c a (tr b)]
            TALit l b   -> [TALit l (tr b)]
            TAGuard g b -> [TAGuard (tr g) (tr b)]

      -- Leaves are returned unchanged.
      TVar{}    -> t
      TDef{}    -> t
      TCon{}    -> t
      TPrim{}   -> t
      TLit{}    -> t
      TUnit{}   -> t
      TSort{}   -> t
      TErased{} -> t
      TError{}  -> t

      -- Structural recursion for the remaining term formers.
      TCoerce a -> TCoerce (tr a)
      TLam b                  -> TLam (tr b)
      TApp a bs               -> TApp (tr a) (map tr bs)
      TLet e b                -> TLet (tr e) (tr b)

    -- Set the case type from the constructor of the first alternative:
    -- 'CTNat' for zero/suc, 'CTInt' for pos/negsuc, otherwise unchanged.
    inferCaseType t (TACon c _ _ : _)
      | isZero c   = t { caseType = CTNat }
      | isSuc c    = t { caseType = CTNat }
      | isPos c    = t { caseType = CTInt }
      | isNegSuc c = t { caseType = CTInt }
    inferCaseType t _ = t

    -- Recognise an alternative of the shape @guard (v0 >= k) (let v0 - k in b)@
    -- on variable 0, returning @k@ and the body @b@.
    nPlusKView (TAGuard (TApp (TPrim PGeq) [TVar 0, (TLit (LitNat _ k))])
                        (TLet (TApp (TPrim PSub) [TVar 0, (TLit (LitNat _ j))]) b))
      | k == j = Just (k, b)
    nPlusKView _ = Nothing