module Text.HPaco.AST.AST where

import Text.HPaco.AST.Statement
import Text.HPaco.AST.Identifier (Identifier)
import Data.Maybe (fromMaybe)

-- | A single definition entry: an 'Identifier' paired with a 'Statement'.
-- Collected in 'astDefs'.
type Def = (Identifier, Statement)

-- | Top-level syntax tree: a root statement, a list of named definitions
-- and a list of dependency strings.
data AST = AST
    { astRootStatement :: Statement
      -- ^ The statement at the root of the tree.
    , astDefs :: [Def]
      -- ^ Named definitions, each an identifier paired with a statement.
    , astDeps :: [String]
      -- ^ Dependencies as plain strings.
      -- NOTE(review): presumably names/paths of other sources this one
      -- depends on; confirm against the producer of this field.
    } deriving Show

-- | An empty 'AST': the root is 'NullStatement' and there are no
-- definitions and no dependencies. Presumably meant as a default value to
-- be refined with record-update syntax.
defAST :: AST
defAST = AST { astRootStatement = NullStatement, astDefs = mempty, astDeps = mempty }

-- | Collect results from a statement tree, visiting nodes pre-order and
-- left to right.
--
-- The callback is tried on each node first. When it returns @Just x@, @x@
-- is emitted and that node's children are /not/ visited. When it returns
-- @Nothing@, the walk descends into the child statements of
-- 'StatementSequence' (each element), 'IfStatement' (both branches),
-- 'LetStatement', 'ForStatement' and 'SwitchStatement' (the second
-- component of every branch). Any other constructor contributes no results.
walkStatement :: (Statement -> Maybe a) -> Statement -> [a]
walkStatement f stmt =
    case f stmt of
        Just x -> [x]
        Nothing -> recurse
    where
        recurse =
            case stmt of
                StatementSequence stmts -> concatMap (walkStatement f) stmts
                IfStatement _ tb fb -> walkStatement f tb ++ walkStatement f fb
                LetStatement _ _ s -> walkStatement f s
                ForStatement _ _ _ s -> walkStatement f s
                -- Only the body half of each branch pair is walked.
                SwitchStatement _ branches -> concatMap (walkStatement f . snd) branches
                -- Wildcard instead of a variable named 'otherwise', which would
                -- shadow Prelude.otherwise (-Wname-shadowing) and read like a guard.
                _ -> []

-- | Rewrite a statement tree top-down.
--
-- The callback is tried on each node first. When it returns @Just s'@, @s'@
-- replaces the node as-is and is /not/ traversed further. When it returns
-- @Nothing@, the node is rebuilt with its child statements rewritten
-- recursively: each element of a 'StatementSequence', both branches of an
-- 'IfStatement', the body of a 'LetStatement' or 'ForStatement', and the
-- second component of every 'SwitchStatement' branch. All other fields of
-- the rebuilt node are passed through unchanged. Any other constructor is
-- returned unchanged.
mapStatement :: (Statement -> Maybe Statement) -> Statement -> Statement
mapStatement f stmt =
    -- 'recurse' is lazy, so the subtree is only rebuilt when 'f' declines.
    fromMaybe recurse (f stmt)
    where
        g = mapStatement f
        recurse =
            case stmt of
                StatementSequence stmts -> StatementSequence $ map g stmts
                IfStatement e tb fb -> IfStatement e (g tb) (g fb)
                LetStatement a b s -> LetStatement a b (g s)
                ForStatement a b c s -> ForStatement a b c (g s)
                SwitchStatement e branches -> SwitchStatement e [ (c, g s) | (c, s) <- branches ]
                -- Wildcard instead of a variable named 'otherwise', which would
                -- shadow Prelude.otherwise (-Wname-shadowing) and read like a guard.
                _ -> stmt