module Ehs.Parser
( Ehs(Plain, Embed, Bind, Let, For, If)
, parseEhses
) where
import Control.Applicative   ((<$), (<$>), (<*))
import Control.Monad
import Data.Char             (isSpace)
import Data.Maybe            (isNothing)
import Data.List             (find)
import GHC.Exts
import Language.Haskell.Meta
import Language.Haskell.TH
import Text.Parsec
import Text.Parsec.String    (Parser)

-- | A parsed template element. @s@ is the type of literal text.
--
-- 'parseEhses' builds these from tags of the form @<% ... %>@; each
-- constructor's comment names the tag syntax this module's parsers accept
-- for it.
data Ehs s
  = Plain s
    -- ^ Literal text outside of tags (the escape @<%%@ stands for @<%@).
  | Embed Exp
    -- ^ @<%= exp %>@.
  | Bind Pat Exp
    -- ^ @<% pat <- exp %>@.
  | Let [Dec]
    -- ^ @<% let decls %>@.
  | For Pat Exp [Ehs s]
    -- ^ @<% for pat <- exp %> body <% end %>@.
  | If [(Exp, [Ehs s])]
    -- ^ @<% if cond %> body@, then zero or more @<%| cond %> body@ branches,
    -- closed by @<% end %>@; one @(condition, body)@ pair per branch.
  | Zip (Ehs s, Ehs s, Ehs s)
    -- ^ Internal to 'parseEhses': an element with its neighbours,
    -- @(previous, element, next)@, used to decide how to trim whitespace.
    -- Not exported; none remain in the result of 'parseEhses'.
  deriving (Eq, Show, Functor)

-- | True for statement elements, i.e. everything except literal text and
-- @<%= %>@ embeds. 'parseEhses' trims whitespace next to statements.
isStmt :: Ehs s -> Bool
isStmt = \case
  Plain _ -> False
  Embed _ -> False
  _       -> True

-- | A statement that does nothing: an empty @let@. 'isStmt' is 'True' for
-- it, which is why it serves as the sentinel neighbour at the edges of
-- @for@/@if@ bodies during whitespace trimming in 'parseEhses'.
nopStmt :: Ehs s
nopStmt = Let mempty

-- | Empty literal text. 'isStmt' is 'False' for it, so it is a neutral
-- neighbour at the edges of the top-level element list in 'parseEhses'.
nop :: IsString s => Ehs s
nop = Plain (fromString "")

-- | Remove every 'nop' and 'nopStmt' leaf, at any depth inside @for@/@if@
-- bodies. The @for@/@if@ nodes themselves are always kept.
removeNops :: (Eq s, IsString s) => [Ehs s] -> [Ehs s]
removeNops = nestedFilter isReal
  where
    isReal e = e /= nop && e /= nopStmt

-- | Apply @f@ to every leaf element, i.e. anything that is not a @for@ or
-- @if@. The traversal descends into @for@ bodies and every @if@ branch.
-- @for@/@if@ nodes are rebuilt around their mapped bodies; @f@ is never
-- applied to them directly.
nestedMap :: (Ehs a -> Ehs b) -> [Ehs a] -> [Ehs b]
nestedMap f = map visit
  where
    visit = \case
      For pat ex body -> For pat ex (map visit body)
      If branches     -> If (map (fmap (map visit)) branches)
      leaf            -> f leaf

-- | Keep only the leaf elements satisfying the predicate, at every nesting
-- depth. @for@/@if@ nodes are never tested or dropped; only their bodies
-- are filtered.
nestedFilter :: (Ehs a -> Bool) -> [Ehs a] -> [Ehs a]
nestedFilter keep = concatMap visit
  where
    visit = \case
      For pat ex body -> [For pat ex (nestedFilter keep body)]
      If branches     -> [If [(cond, nestedFilter keep body) | (cond, body) <- branches]]
      leaf            -> [leaf | keep leaf]

-- | Parse a whole template, up to end of input, into a list of elements.
--
-- After parsing, whitespace next to statement tags (everything except
-- literal text and @<%= %>@ embeds; see 'isStmt') is trimmed:
--
-- * text directly after a statement loses one leading newline
--   (@"\\n"@ or @"\\r\\n"@);
-- * text directly before a statement loses the whitespace after its last
--   newline, provided its trailing whitespace contains a newline.
--
-- The start and end of every @for@ body and @if@ branch count as
-- statements for this purpose. The start and end of the whole template
-- do not. Empty text elements and empty @let@ statements are then removed,
-- and text is converted from 'String' with 'fromString'.
parseEhses :: IsString s => Parser [Ehs s]
parseEhses = map (fmap fromString) . removeNops . nestedMap trim . rove . (nop :) <$> many parseEhs <* eof
  where
    -- Resolve a (previous, current, next) window to its trimmed middle element.
    trim :: Ehs String -> Ehs String
    trim (Zip (before, p@(Plain _), after))
      | isStmt before && isStmt after = fmap (lchomp . trimAfterLastEOL) p
      | isStmt before = fmap lchomp p
      | isStmt after = fmap trimAfterLastEOL p
    trim (Zip (_, x, _)) = trimBodies x
    trim x = x

    -- Trim inside the (already windowed) bodies of a for/if element.
    -- Any other element goes through 'trim', which leaves non-window
    -- elements unchanged. This is the one-element case of 'nestedMap trim'
    -- without the partial 'head'.
    trimBodies :: Ehs String -> Ehs String
    trimBodies (For pat ex body) = For pat ex (nestedMap trim body)
    trimBodies (If branches) = If [(cond, nestedMap trim body) | (cond, body) <- branches]
    trimBodies x = trim x

    -- Wrap each element after the first in a window with its neighbours.
    -- The first list element serves only as left context (callers prepend
    -- one), and the element past the end is 'nop'. For/if bodies are
    -- windowed recursively with a 'nopStmt' sentinel on each side. The
    -- trailing sentinel therefore also appears as a window's middle element
    -- and is dropped later by 'removeNops'. A one-element list (the bare
    -- context, i.e. empty input) windows that element itself.
    rove :: [Ehs String] -> [Ehs String]
    rove = \case
      (x:y:z:r) -> Zip (x, deep y, z) : rove (y : z : r)
      [x, y] -> [Zip (x, deep y, nop)]
      [x] -> [Zip (nop, deep x, nop)]
      [] -> [Zip (nop, nop, nop)]
      where
        deep (For pat ex body) = For pat ex $ rove $ nopStmt : body ++ [nopStmt]
        deep (If branches) = If [(cond, rove $ nopStmt : body ++ [nopStmt]) | (cond, body) <- branches]
        deep x = x

    -- Drop exactly one leading line break, if present.
    lchomp ('\r':'\n':r) = r
    lchomp ('\n':r) = r
    lchomp xs = xs

    -- If the trailing whitespace run contains a newline, keep it only up to
    -- and including the last newline (dropping the indentation after it).
    -- Otherwise leave the string unchanged.
    trimAfterLastEOL s = reverse rfront ++ reverse (if isNothing (find (== '\n') rrear) then rrear else dropWhile (/= '\n') rrear)
      where
        -- rrear: trailing whitespace, reversed; rfront: the rest, reversed.
        (rrear, rfront) = span isSpace (reverse s)

-- | Parse one template element.
--
-- Alternatives are tried in the listed order.
--
-- * 'parseBind' and 'parseLet' run under 'try', so a partial match
--   backtracks.
-- * 'parseEmbed', 'parseFor' and 'parseIf' guard only their own opening
--   sequence with 'try' and commit after it.
--
-- A failure that consumed no input is labelled @"parse error"@.
parseEhs :: Parser (Ehs String)
parseEhs = choice alternatives <?> "parse error"
  where
    alternatives =
      [ parsePlain
      , parseEmbed
      , parseFor
      , try parseBind
      , try parseLet
      , parseIf
      ]

-- | Parse a non-empty run of literal template text.
--
-- The text extends up to, but not including, either the next tag opener
-- (@<%@ not followed by @%@) or end of input. The escape @<%%@ is read as a
-- literal @<%@. If a tag opener or end of input comes immediately, the
-- parser fails without consuming input.
parsePlain :: Parser (Ehs String)
parsePlain = do
  pieces <- manyTill textPiece (lookAhead stop)
  guard $ not $ null pieces
  pure . Plain $ concat pieces
  where
    textPiece :: Parser String
    textPiece = try (string "<%%") *> pure "<%" <|> pure <$> anyChar
    stop = try tagOpener <|> eof
    tagOpener = string "<%" *> notFollowedBy (char '%')

-- | Parse an embed tag @<%= exp %>@ into 'Embed'.
--
-- Only the opening @<%=@ is guarded by 'try'; after it the parser commits.
-- The text up to @%>@ is handed to 'parseExp'. A syntax error there is
-- reported via 'fail' with the prefix @"<%= %>: "@.
parseEmbed :: Parser (Ehs String)
parseEmbed = do
  _ <- try (string "<%=")
  source <- innerText (string "%>")
  either (fail . ("<%= %>: " ++)) (pure . Embed) (parseExp source)

-- | Parse a bind statement @<% pat <- exp %>@ into 'Bind'.
--
-- The pattern is the text between @<%@ and the first @<-@, parsed with
-- 'parsePat'. The expression is the text from there up to the closing @%>@,
-- parsed with 'parseExp'. In both, @%%>@ stands for a literal @%>@. Syntax
-- errors are reported via 'fail' with the prefix @"<% %>: "@.
--
-- The parser commits after @<%@, so callers run it under 'try' (as
-- 'parseEhs' does).
--
-- The pattern scan stops at an unescaped @%>@: if the current tag closes
-- before any @<-@, the parser fails there. 'parseEhs' reaches this parser
-- for statement-like tags that are not @for@ (e.g. @<% let ... %>@,
-- @<% if ... %>@, @<% end %>@, @<%| ... %>@). Without this bound, each such
-- attempt read ahead across later tags to the next @<-@ (or to end of
-- input) before backtracking. That cost work proportional to the rest of
-- the input per tag, and the far-away failure position could leak into the
-- reported parse error. As with the expression part, an @%>@ meant
-- literally inside the pattern must be written @%%>@.
parseBind :: Parser (Ehs String)
parseBind = do
  _ <- string "<%"
  patSource <- tagLocalTextUntil (string "<-")
  expSource <- innerText (string "%>")
  pat <- case parsePat patSource of
    Right x -> return x
    Left  err -> fail $ "<% %>: " ++ err
  ex <- case parseExp expSource of
    Right x -> return x
    Left  err -> fail $ "<% %>: " ++ err
  return $ Bind pat ex
  where
    -- Like 'innerText', but refuses to step over an unescaped "%>", which
    -- keeps the search for the terminator inside the current tag.
    tagLocalTextUntil :: Parser a -> Parser String
    tagLocalTextUntil terminator =
      concat <$> manyTill (notFollowedBy (string "%>") *> chunk) (try terminator)

    -- One unit of tag text: the escape "%%>" (read as "%>") or one character.
    chunk :: Parser String
    chunk = "%>" <$ try (string "%%>") <|> (:[]) <$> anyChar

-- | Parse a @<% let decls %>@ statement into 'Let'.
--
-- The opening sequence is guarded by 'try': @<%@, optional whitespace,
-- @let@, and at least one whitespace character. After that the parser
-- commits. The text up to @%>@ is parsed with 'parseDecs', and errors are
-- reported via 'fail' with the prefix @"<%let %>: "@.
parseLet :: Parser (Ehs String)
parseLet = do
  try (string "<%" *> skipMany space *> string "let" *> skipMany1 space)
  source <- innerText (string "%>")
  either (fail . ("<%let %>: " ++)) (pure . Let) (parseDecs source)

-- | Parse a conditional into 'If'. The accepted syntax is
-- @<% if cond %> body (<%| cond %> body)* <% end %>@.
--
-- The opening sequence is guarded by 'try': @<%@, optional whitespace,
-- @if@, and at least one whitespace character. The first body is every
-- element 'parseEhs' accepts before it stops. Further branches come from
-- 'parseElsif' until an 'endTag' matches. A bad first condition is
-- reported via 'fail' with the prefix @"<%if %>: "@.
parseIf :: Parser (Ehs String)
parseIf = do
  try (string "<%" *> skipMany space *> string "if" *> skipMany1 space)
  condSource <- innerText (string "%>")
  firstCond <- either (fail . ("<%if %>: " ++)) pure (parseExp condSource)
  firstBody <- many parseEhs
  otherBranches <- manyTill parseElsif (try endTag)
  pure $ If ((firstCond, firstBody) : otherBranches)

-- | Parse one additional @if@ branch, @<%| cond %> body@, into a
-- @(condition, body)@ pair.
--
-- Only the opening @<%|@ is guarded by 'try'. A bad condition is reported
-- via 'fail' with the prefix @"<%| %>: "@.
parseElsif :: Parser (Exp, [Ehs String])
parseElsif = do
  _ <- try (string "<%|")
  condSource <- innerText (string "%>")
  cond <- either (fail . ("<%| %>: " ++)) pure (parseExp condSource)
  body <- many parseEhs
  pure (cond, body)

-- | Parse a loop @<% for pat <- exp %> body <% end %>@ into 'For'.
--
-- The opening sequence is guarded by 'try': @<%@, optional whitespace,
-- @for@, and at least one whitespace character. The pattern (text up to
-- @<-@) is parsed with 'parsePat' and the expression (text up to @%>@) with
-- 'parseExp'. Errors are reported via 'fail' with the prefix
-- @"<%for %>: "@. The body is read element by element until an 'endTag'
-- matches.
parseFor :: Parser (Ehs String)
parseFor = do
  try (string "<%" *> skipMany space *> string "for" *> skipMany1 space)
  patSource <- innerText (string "<-")
  expSource <- innerText (string "%>")
  pat <- orFail (parsePat patSource)
  source <- orFail (parseExp expSource)
  body <- manyTill parseEhs (try endTag)
  pure (For pat source body)
  where
    orFail :: Either String a -> Parser a
    orFail = either (fail . ("<%for %>: " ++)) pure

-- | Read raw tag text up to a terminator, consuming the terminator.
--
-- Returns the text before the first position where the terminator parser,
-- run under 'try', succeeds. The escape @%%>@ is read as a literal @%>@.
-- Fails if end of input is reached first.
innerText :: Parser a -> Parser String
innerText terminator = concat <$> manyTill chunk (try terminator)
  where
    chunk :: Parser String
    chunk = try (string "%%>") *> pure "%>" <|> pure <$> anyChar

-- | Match a closing @<%end%>@ tag. Whitespace is allowed after @<%@ and
-- before @%>@. It can consume input before failing, so callers use it
-- under 'try'.
endTag :: Parser ()
endTag = void $
  string "<%" *> skipMany space *> string "end" *> skipMany space *> string "%>"