{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.CPlusPlus.Convert.InlineSetAt
-- Description : does inline expansion of @set_at@ function. / @set_at@ 関数を inline 展開します。
-- Copyright   : (c) Kimiyuki Onaka, 2020
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.CPlusPlus.Convert.InlineSetAt
  ( run,
  )
where

import Control.Monad.Writer.Strict
import Jikka.CPlusPlus.Language.Expr
import Jikka.CPlusPlus.Language.Util
import Jikka.Common.Alpha
import Jikka.Common.Error

-- | Rewrites one expression node, expanding a @set_at@ call if it is one.
--
-- A node of the form @Call (SetAt t) [xs, i, x]@ is replaced by @Var y@, where
-- @y@ is a new local variable. Two statements are emitted through the writer,
-- in this order:
--
-- > vector<T> y = xs;
-- > y[i] = x;
--
-- If @xs@ is a plain variable, @y@ is obtained from its name with
-- 'renameVarName'. Otherwise @y@ comes from 'newFreshName'.
--
-- Any other expression is returned unchanged. Only this node is inspected;
-- descending into sub-expressions is left to the caller.
runExpr :: (MonadAlpha m, MonadWriter [Statement] m) => Expr -> m Expr
runExpr = \case
  Call (SetAt t) [xs, i, x] -> do
    -- Derive the copy's name from the source vector's name when it has one.
    -- The source name gets its own binder so it does not shadow @xs :: Expr@,
    -- which is still needed below for the copy.
    y <- case xs of
      Var (VarName name) -> renameVarName LocalNameKind name
      _ -> newFreshName LocalNameKind
    tell
      [ Declare (TyVector t) y (DeclareCopy xs),
        Assign (AssignExpr SimpleAssign (LeftAt (LeftVar y) i) x)
      ]
    return (Var y)
  e -> return e

-- | Expands every @set_at@ call that occurs in the expressions of one statement.
--
-- 'runExpr' is applied through 'mapSubExprM' to each expression that
-- 'mapDirectExprStatementM' exposes for the statement. The result lists the
-- statements emitted during the rewrite first, followed by the rewritten
-- statement. Each copied vector is therefore declared and updated before the
-- statement that refers to it.
--
-- NOTE(review): emitted statements are collected in the order 'runExpr' visits
-- nodes. Nested @set_at@ calls are only correct if 'mapSubExprM' rewrites
-- children before parents — confirm in "Jikka.CPlusPlus.Language.Util".
--
-- NOTE(review): the hoisted statements run once, before the whole statement.
-- If 'mapDirectExprStatementM' exposes an expression that is evaluated
-- repeatedly (e.g. a loop condition), the expansion would not be redone on
-- each evaluation — confirm which expressions it visits.
runStatement :: MonadAlpha m => Statement -> m [Statement]
runStatement stmt = do
  (stmt', decls) <- runWriterT (mapDirectExprStatementM (mapSubExprM runExpr) stmt)
  return $ decls ++ [stmt']

-- | Applies 'runStatement' to the statements of a program using
-- 'mapExprStatementProgramM'.
--
-- The expression callback is 'return', so this function does no separate
-- expression-level rewrite. All expansion happens per statement, because only
-- at the statement level can the emitted declarations be placed.
runProgram :: MonadAlpha m => Program -> m Program
runProgram = mapExprStatementProgramM return runStatement

-- | `run` does inline expansion of @jikka::set_at<T>(...)@ function.
--
-- Each call is replaced by a fresh local copy of the vector. The copy is
-- updated at the given index by statements inserted just before the statement
-- that contained the call.
--
-- == Examples
--
-- Before:
--
-- > func(jikka::set_at<T>(xs, i, x));
--
-- After:
--
-- > vector<T> ys = xs;
-- > ys[i] = x;
-- > func(ys);
--
-- The pass runs under @wrapError'@ with this module's name. Presumably this
-- makes any error raised inside it carry that context.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.CPlusPlus.Convert.InlineSetAt" $ runProgram prog