module Text.CTPL where

import Control.Applicative (Applicative (pure, (<*>)))
import Control.Monad
import Data.Char
import Data.Monoid

import Text.Chatty.Parser
import Text.Chatty.Parser.Carrier

import qualified Text.CTPL0 as Null

-- | A named procedure, its body, and the bytecode address assigned to it.
-- The parsers create procedures with address 0; 'allocProcs' fills in the
-- real address.
data Procedure = Procedure
  { procName  :: String
  , procInstr :: Instruction
  , procAddr  :: Int
  }

-- | Compiler state threaded through 'CTPL': the procedures that calls may
-- refer to.
data CTPLState = CTPLState
  { definedProcs :: [Procedure]
  }

-- | Outcome of a compilation step: a value, or an abort because a called
-- procedure is unknown or the source could not be parsed.
data Exec a
  = Succ a
  | NoSuchProc String
  | SyntaxFault
  deriving Show

-- | Code-generation monad: threads a 'CTPLState', may abort with an 'Exec'
-- failure, and accumulates emitted code as a difference list
-- (@String -> String@), which is applied to @[]@ to obtain the text.
data CTPL a = CTPL
  { runCTPL :: CTPLState -> Exec (a, CTPLState, String -> String)
  }

-- | 'Exec' short-circuits: 'Succ' feeds its value to the continuation,
-- while 'NoSuchProc' and 'SyntaxFault' abort and are passed through
-- unchanged.
--
-- Functor and Applicative are defined via the Monad instance. They have
-- been mandatory superclasses of Monad since GHC 7.10, and without them
-- this module does not compile.
instance Functor Exec where
  fmap = liftM

instance Applicative Exec where
  pure = Succ
  (<*>) = ap

instance Monad Exec where
  return = pure
  Succ a >>= f = f a
  NoSuchProc s >>= _ = NoSuchProc s
  SyntaxFault >>= _ = SyntaxFault

-- | Sequencing threads the state from the first action into the second,
-- aborts as soon as either one fails (via the 'Exec' monad), and
-- concatenates their emitted code in order.
--
-- Functor and Applicative are defined via the Monad instance. They are
-- mandatory superclasses of Monad on modern GHC.
instance Functor CTPL where
  fmap = liftM

instance Applicative CTPL where
  pure a = CTPL $ \s -> Succ (a, s, id)
  (<*>) = ap

instance Monad CTPL where
  return = pure
  m >>= f = CTPL $ \s -> do
    (a', s', out') <- runCTPL m s
    (b, s'', out'') <- runCTPL (f a') s'
    -- Difference lists: first action's output, then the second's.
    return (b, s'', out' . out'')

-- | Read a projection of the compiler state without modifying it or
-- emitting any code.
getState :: (CTPLState -> a) -> CTPL a
getState proj = CTPL (\st -> Succ (proj st, st, id))

-- | Apply a transformation to the compiler state; emits nothing.
modState :: (CTPLState -> CTPLState) -> CTPL ()
modState change = CTPL (\st -> Succ ((), change st, id))

-- | Append a piece of generated code to the output; the state is untouched.
emit :: String -> CTPL ()
emit code = CTPL (\st -> Succ ((), st, showString code))

-- | Look up a procedure by name and project a field out of it with @f@.
--
-- Aborts the whole computation with 'NoSuchProc' when no procedure of that
-- name is defined. If several procedures share the name, the first one in
-- 'definedProcs' is used (like 'lookup'). This avoids crashing with a
-- non-exhaustive pattern match.
getProc :: String -> (Procedure -> a) -> CTPL a
getProc nm f = do
  ps <- getState definedProcs
  case filter ((== nm) . procName) ps of
    [] -> CTPL $ \_ -> NoSuchProc nm
    p : _ -> return (f p)

-- | Run an action for its state changes but capture the code it emits and
-- return it as a string instead of appending it to the output.
catchEmission :: CTPL () -> CTPL String
catchEmission action = CTPL $ \st ->
  runCTPL action st >>= \(_, st', out) -> return (out [], st', id)

-- | Run an action, keeping its result and state changes but throwing away
-- any code it emits.
discardEmission :: CTPL a -> CTPL a
discardEmission action = CTPL $ \st ->
  runCTPL action st >>= \(result, st', _) -> return (result, st', id)

-- | A numeric operand. The surface syntax (see 'parseNumSource') is
-- @AX±n@, @CK0±n@, @CP±n@, @LEN±n@, @[pos]±n@, an integer literal, or a
-- backslash followed by a character (its code point).
data NumSource
  = AX Int              -- ^ @AX@ plus an offset
  | CK0 Int             -- ^ @CK0@ plus an offset
  | Buf NumSource Int   -- ^ @[pos]@ plus an offset
  | Const Int           -- ^ integer literal or character code
  | Len Int             -- ^ @LEN@ plus an offset
  | CP Int              -- ^ @CP@ plus an offset
  deriving Show

-- | CTPL statements. The register, stack and tape primitives ('PopAX'
-- through 'PopTape') have no surface syntax of their own; they are
-- produced by 'parseFor' and 'parseProc'.
data Instruction
  = SetAX NumSource                              -- ^ @AX = src;@
  | SetCK0 NumSource                             -- ^ @CK0 = src;@
  | SetBuf NumSource NumSource                   -- ^ @[pos] = src;@
  | Call String                                  -- ^ @call name;@
  | PopAX
  | PopCK
  | PushAX
  | PushCK
  | AXToTape
  | CK0ToTape
  | AXFromTape
  | CK0FromTape
  | PushTape
  | PopTape
  | Walk NumSource                               -- ^ @[pos];@
  | Remove NumSource                             -- ^ @remove [pos];@
  | Insert NumSource String                      -- ^ @insert "text" at [pos];@
  | Return                                       -- ^ @return;@
  | IfThenElse Condition Instruction Instruction -- ^ @if c { .. } else { .. }@
  | UntilDo Condition Instruction                -- ^ @until c { .. }@; @while@ wraps c in 'Negation'
  | Compound [Instruction]                       -- ^ a sequence of statements
  deriving Show

-- | Conditions used by @if@, @until@ and @while@ (see 'parseCond').
data Condition
  = EqAX NumSource          -- ^ @AX = src@
  | EqCK0 NumSource         -- ^ @CK0 = src@
  | LtAX NumSource          -- ^ @AX < src@
  | LtCK0 NumSource         -- ^ @CK0 < src@
  | GtAX NumSource          -- ^ @AX > src@
  | GtCK0 NumSource         -- ^ @CK0 > src@
  | IsUpper NumSource       -- ^ @uppercase? [pos]@
  | IsLower NumSource       -- ^ @lowercase? [pos]@
  | IsDigit NumSource       -- ^ @digit? [pos]@
  | IsEob NumSource         -- ^ @end? [pos]@
  | Negation Condition      -- ^ @not c@
  | EqCh NumSource String   -- ^ @[pos] in "chars"@
  deriving Show

-- | Assign consecutive bytecode addresses to procedures, starting at
-- @start@. Each procedure's size is measured by rendering it with
-- 'predictSpace', and the next procedure is placed right after it.
allocProcs :: Int -> [Procedure] -> CTPL [Procedure]
allocProcs start = place start
  where
    place _ [] = return []
    place addr (proc : rest) = do
      size <- predictSpace (procInstr proc)
      placedRest <- place (addr + size) rest
      return (proc { procAddr = addr } : placedRest)

-- | Emit @prefix@ followed by a signed offset. The sign is always written
-- explicitly (@+3@, @-3@), and nothing at all is emitted for an offset
-- of 0.
dumpIncop :: String -> Int -> CTPL ()
dumpIncop prefix n = case compare n 0 of
  EQ -> return ()
  LT -> emit (prefix ++ show n)
  GT -> emit (prefix ++ "+" ++ show n)

-- | Emit an address zero-padded to four digits, so code that embeds an
-- address has a predictable width. Numbers with more than four digits are
-- emitted unpadded and are therefore wider; this is fine unless a script
-- is extremely large.
dumpAddr :: Int -> CTPL ()
dumpAddr n = emit (padded (show n))
  where
    padded digits = replicate (4 - length digits) '0' ++ digits

-- | Number of characters of bytecode an instruction compiles to, obtained
-- by rendering it with 'dumpInstr' and capturing (not emitting) the
-- output.
predictSpace :: Instruction -> CTPL Int
predictSpace instr = do
  code <- catchEmission (dumpInstr instr)
  return (length code)

-- | Split a string around every @$@, keeping each @$@ as its own element:
--
-- > splitStr "ab$cd" == ["ab", "$", "cd"]
-- > splitStr "$$"    == ["$", "$"]
--
-- Empty chunks are never produced. Uses a single 'break' per chunk instead
-- of a takeWhile/dropWhile pair plus a partial 'tail'.
splitStr :: String -> [String]
splitStr str = case break (== '$') str of
  ("", "")          -> []
  (chunk, "")       -> [chunk]
  ("", _ : rest)    -> "$" : splitStr rest
  (chunk, _ : rest) -> chunk : "$" : splitStr rest

-- | Cursor-move code for a relative offset, used for @CP@ operands:
-- @n@ copies of @>@ for @n >= 0@, or @|n|@ copies of @<@ for a negative
-- offset.
arepl :: Int -> String
arepl n
  | n >= 0 = replicate n '>'
  | otherwise = replicate (negate n) '<'

-- | Emit the CTPL0 code for a single instruction.
--
-- The comment above each alternative sketches the emitted text:
--
-- * @#@ is a signed offset written by 'dumpIncop' (nothing for 0).
-- * @[ldax]@, @[ldck0]@ and @[ldbuf]@ stand for a nested
--   @dumpInstr (SetAX ...)@ or @dumpInstr (SetCK0 ...)@ expansion.
-- * Runs of @>@ or @<@ come from 'arepl'.
--
-- Buffer operands given directly as a cursor offset (@CP j@) are
-- special-cased into such runs instead of the general "load the address,
-- then move" path.
--
-- NOTE(review): the meaning of the individual opcodes is defined in
-- "Text.CTPL0", which is not visible here. These comments describe the
-- emitted text rather than its runtime effect.
dumpInstr :: Instruction -> CTPL ()
dumpInstr instr = case instr of
  -- +#, -#
  SetAX (AX i) -> dumpIncop [] i
  -- Dd
  SetAX (CK0 i) -> emit "Dd" >> dumpIncop [] i
  -- <cursor moves> l
  SetAX (Buf (CP j) i) -> emit (arepl j) >> emit "l" >> dumpIncop [] i
  -- General address: [ldax] ml  (compute the address into AX first)
  SetAX (Buf d i) -> dumpInstr (SetAX d) >> emit "ml" >> dumpIncop [] i
  -- Q0+#
  SetAX (Const i) -> emit "Q0" >> dumpIncop [] i
  -- Q0+7Je>0-7J!eQ
  SetAX (Len i) -> emit "Q0+7Je>0-7J!eQ" >> dumpIncop [] i
  -- Q
  SetAX (CP i) -> emit "Q" >> dumpIncop [] i
  -- C+#, C-#
  SetCK0 (CK0 i) -> dumpIncop "C" i
  -- kd
  SetCK0 (AX i) -> emit "kd" >> dumpIncop "C" i
  -- <cursor moves> Cl
  SetCK0 (Buf (CP j) i) -> emit (arepl j) >> emit "Cl" >> dumpIncop "C" i
  -- General address: [ldck0] CmCl  (compute the address into CK0 first)
  SetCK0 (Buf d i) -> dumpInstr (SetCK0 d) >> emit "CmCl" >> dumpIncop "C" i
  -- C0C+#
  SetCK0 (Const i) -> emit "C0" >> dumpIncop "C" i
  -- C0C+11CJe>C0C-11CJ!eCQ  (the CK0 counterpart of the SetAX (Len i) code)
  SetCK0 (Len i) -> emit "C0C+11CJe>C0C-11CJ!eCQ" >> dumpIncop "C" i
  -- CQ
  SetCK0 (CP i) -> emit "CQ" >> dumpIncop "C" i
  -- Cd [ldck0] <cursor moves> Csk
  SetBuf (CP i) f -> emit "Cd" >> dumpInstr (SetCK0 f) >> emit (arepl i) >> emit "Csk"
  -- Cd [ldck0] i/<CsyxkCd [ldck0] CmkPpx
  SetBuf d f -> emit "Cd" >> dumpInstr (SetCK0 f) >> emit "i/<CsyxkCd" >> dumpInstr (SetCK0 d) >> emit "CmkPpx"
  -- d0+<4-digit address>ct
  Call str -> do
    a <- getProc str procAddr
    emit "d0+"
    dumpAddr a
    emit "ct"
  -- D
  PopAX -> emit "D"
  -- k
  PopCK -> emit "k"
  -- d
  PushAX -> emit "d"
  -- Cd
  PushCK -> emit "Cd"
  -- f
  Return -> emit "f"
  -- i/<s
  AXToTape -> emit "i/<s"
  -- lx
  AXFromTape -> emit "lx"
  -- i/<Cs
  CK0ToTape -> emit "i/<Cs"
  -- Clx
  CK0FromTape -> emit "Clx"
  -- <cursor moves> x
  Remove (CP i) -> emit (arepl i) >> emit "x"
  -- Cd[ldbuf]Cmkx
  Remove d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmkx"
  -- Only the cursor moves; no further opcode.
  Walk (CP i) -> emit (arepl i)
  -- Cd[ldbuf]Cmk
  Walk d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk"
  -- yx
  PushTape -> emit "yx"
  -- Pp<
  PopTape -> emit "Pp<"
  -- <cursor moves> I...$  The text is split by splitStr, and each '$' is
  -- emitted as "i$" rather than inside an I...$ run (presumably because
  -- '$' terminates such a run).
  Insert (CP i) str -> emit (arepl i) >> forM_ (splitStr str) (\s -> if s=="$" then emit "i$" else emit ('I':s++"$"))
  -- Cd[ldbuf]CmkI...$  (same '$' handling as above)
  Insert d str -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> forM_ (splitStr str) (\s -> if s=="$" then emit "i$" else emit ('I':s++"$"))
  -- condInit CdC0C+#CJCk cond condClean elseBranch d0+#Jt condClean ifBranch dD
  IfThenElse c y n -> do
    -- Render the condition once with its output discarded, only to obtain
    -- the clean-up code it returns.
    condClean <- discardEmission (dumpCond c "")
    -- Pieces are rendered to strings first so the jump offsets (the
    -- lengths of the code each jump skips) are known before emission.
    elseBranch <- catchEmission (emit condClean >> dumpInstr n)
    ifBranch <- catchEmission (emit condClean >> dumpInstr y >> emit "d")
    elseBranch2 <- catchEmission (emit "d0" >> dumpIncop [] (length ifBranch) >> emit "Jt")
    jumper <- catchEmission (emit "CdC0" >> dumpIncop "C" (length elseBranch+length elseBranch2) >> emit "CJCk")
    dumpCond c jumper
    emit elseBranch
    emit elseBranch2
    emit ifBranch
    emit "D"
  -- dD condInit CdC0C+#CJCk cond condClean body d0-#Jt condClean
  UntilDo c b -> do
    condClean <- discardEmission (dumpCond c "")
    -- Render the body once and reuse the text. Rendering it a second time
    -- for emission made nested loops (including desugared for-loops)
    -- exponential in nesting depth.
    body <- catchEmission (dumpInstr b)
    -- "XXXX" stands in for the 4-digit back-jump address written below via
    -- dumpAddr, so this length equals the real code's length.
    -- NOTE(review): only holds while that address fits in four digits.
    let bodyPrev = condClean ++ body ++ "d0-XXXXJt"
    jumper <- catchEmission (emit "CdC0" >> dumpIncop "C" (length bodyPrev) >> emit "CJCk")
    cond <- catchEmission (dumpCond c jumper >> return ())
    emit "dD"
    emit cond
    emit condClean
    emit body
    emit "d0-"
    -- NOTE(review): the +1 depends on CTPL0's relative-jump convention;
    -- confirm against Text.CTPL0.
    dumpAddr (length cond + length bodyPrev + 1)
    emit "Jt"
    emit condClean
  -- Dispatch Compound
  Compound is -> mapM_ dumpInstr is

-- | Emit the code that evaluates a condition and performs the conditional
-- jump, and return clean-up code.
--
-- @jumper@ is emitted immediately before the condition-specific opcode
-- (e.g. @=@, @<@, @z@, @t@), so the caller chooses the jump; 'Negation'
-- just appends @!@ to it.
--
-- The returned string is clean-up code that the callers in 'dumpInstr'
-- emit at the start of each outcome path. It undoes temporary changes
-- made to perform the test: for example, @EqAX (CK0 i)@ subtracts @i@ from
-- AX before the test and returns the code that adds it back.
--
-- Register comparisons are written so that @<@ and @>@ compare AX against
-- CK0. CK0-side comparisons that load their operand into AX therefore flip
-- the operator.
--
-- NOTE(review): opcode meanings live in "Text.CTPL0". The reading above is
-- inferred from how the cases below are written and should be confirmed.
dumpCond :: Condition -> String -> CTPL String
dumpCond cond jumper = case cond of
  -- jt: register compared with itself plus a constant, decided at compile
  -- time; "t" serves as the always-true condition.
  EqAX (AX 0) -> emit jumper >> emit "t" >> return ""
  LtAX (AX i) | i > 0 -> emit jumper >> emit "t" >> return ""
  GtAX (AX i) | i < 0 -> emit jumper >> emit "t" >> return ""
  EqCK0 (CK0 0) -> emit jumper >> emit "t" >> return ""
  LtCK0 (CK0 i) | i > 0 -> emit jumper >> emit "t" >> return ""
  GtCK0 (CK0 i) | i < 0 -> emit jumper >> emit "t" >> return ""
  -- j!t: the corresponding always-false cases.
  EqAX (AX _) -> emit jumper >> emit "!t" >> return ""
  LtAX (AX i) | i <= 0 -> emit jumper >> emit "!t" >> return ""
  GtAX (AX i) | i >= 0 -> emit jumper >> emit "!t" >> return ""
  EqCK0 (CK0 _) -> emit jumper >> emit "!t" >> return ""
  LtCK0 (CK0 i) | i <= 0 -> emit jumper >> emit "!t" >> return ""
  GtCK0 (CK0 i) | i >= 0 -> emit jumper >> emit "!t" >> return ""
  -- -#j=   +#
  EqAX (CK0 i) -> dumpIncop [] (-i) >> emit jumper >> emit "=" >> catchEmission (dumpIncop [] i)
  LtAX (CK0 i) -> dumpIncop [] (-i) >> emit jumper >> emit "<" >> catchEmission (dumpIncop [] i)
  GtAX (CK0 i) -> dumpIncop [] (-i) >> emit jumper >> emit ">" >> catchEmission (dumpIncop [] i)
  EqCK0 (AX i) -> dumpIncop [] i >> emit jumper >> emit "=" >> catchEmission (dumpIncop [] (-i))
  LtCK0 (AX i) -> dumpIncop [] i >> emit jumper >> emit ">" >> catchEmission (dumpIncop [] (-i))
  GtCK0 (AX i) -> dumpIncop [] i >> emit jumper >> emit "<" >> catchEmission (dumpIncop [] (-i))
  -- -#jz   +#
  EqAX (Const i) -> dumpIncop [] (-i) >> emit jumper >> emit "z" >> catchEmission (dumpIncop [] i)
  -- CdC0C+#j<  k
  LtAX (Const i) -> emit "CdC0" >> dumpIncop "C" i >> emit jumper >> emit "<" >> return "k"
  GtAX (Const i) -> emit "CdC0" >> dumpIncop "C" i >> emit jumper >> emit ">" >> return "k"
  -- C-#jCz  C+#
  EqCK0 (Const i) -> dumpIncop "C" (-i) >> emit jumper >> emit "Cz" >> catchEmission (dumpIncop "C" i)
  -- i/<s0+#j>  lx  (the constant is loaded into AX, so the operator flips)
  LtCK0 (Const i) -> emit "i/<s0" >> dumpIncop [] i >> emit jumper >> emit ">" >> return "lx"
  GtCK0 (Const i) -> emit "i/<s0" >> dumpIncop [] i >> emit jumper >> emit "<" >> return "lx"
  -- Cd[ldbuf]CmClC+#j=  k
  EqAX (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "CmCl" >> dumpIncop "C" i >> emit jumper >> emit "=" >> return "k"
  LtAX (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "CmCl" >> dumpIncop "C" i >> emit jumper >> emit "<" >> return "k"
  GtAX (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "CmCl" >> dumpIncop "C" i >> emit jumper >> emit ">" >> return "k"
  -- Cd[ldbuf]Cmi/<ks>l+#j>  <lx
  -- The buffer value is loaded into AX ("l") and CK0 is restored ("k").
  -- As in the other CK0-side comparisons, Lt therefore uses ">" and Gt
  -- uses "<".
  EqCK0 (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmi/<ks>l" >> dumpIncop [] i >> emit jumper >> emit "=" >> return "<lx"
  LtCK0 (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmi/<ks>l" >> dumpIncop [] i >> emit jumper >> emit ">" >> return "<lx"
  GtCK0 (Buf d i) -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmi/<ks>l" >> dumpIncop [] i >> emit jumper >> emit "<" >> return "<lx"
  -- d0+7Je>0-7J!eD-#jE  +#
  EqAX (Len i) -> emit "d0+7Je>0-7J!eD" >> dumpIncop [] (-i) >> emit jumper >> emit "E" >> catchEmission (dumpIncop [] i)
  LtAX (Len i) -> emit "d0+7Je>0-7J!eD" >> dumpIncop [] (-i) >> emit jumper >> emit "g" >> catchEmission (dumpIncop [] i)
  GtAX (Len i) -> emit "d0+7Je>0-7J!eD" >> dumpIncop [] (-i) >> emit jumper >> emit "l" >> catchEmission (dumpIncop [] i)
  -- CdC0C+11CJe>C0C-11CJ!ekC-#jCE  C+#
  EqCK0 (Len i) -> emit "CdC0C+11CJe>C0C-11CJ!ek" >> dumpIncop "C" (-i) >> emit jumper >> emit "CE" >> catchEmission (dumpIncop "C" i)
  LtCK0 (Len i) -> emit "CdC0C+11CJe>C0C-11CJ!ek" >> dumpIncop "C" (-i) >> emit jumper >> emit "Cg" >> catchEmission (dumpIncop "C" i)
  GtCK0 (Len i) -> emit "CdC0C+11CJe>C0C-11CJ!ek" >> dumpIncop "C" (-i) >> emit jumper >> emit "Cl" >> catchEmission (dumpIncop "C" i)
  -- -#jE  +#   and   C-#jCE  C+#
  -- Comparisons against CP. The LEN cases above first walk the cursor to
  -- the end of the buffer (restoring the register afterwards) and then
  -- compare with E, g, l (CE, Cg, Cl). Comparing against the current
  -- position is that same code without the walk. Without these cases a
  -- parser-accepted condition such as "AX = CP" hit a non-exhaustive
  -- pattern.
  -- NOTE(review): confirm in Text.CTPL0 that E, g, l compare against CP.
  EqAX (CP i) -> dumpIncop [] (-i) >> emit jumper >> emit "E" >> catchEmission (dumpIncop [] i)
  LtAX (CP i) -> dumpIncop [] (-i) >> emit jumper >> emit "g" >> catchEmission (dumpIncop [] i)
  GtAX (CP i) -> dumpIncop [] (-i) >> emit jumper >> emit "l" >> catchEmission (dumpIncop [] i)
  EqCK0 (CP i) -> dumpIncop "C" (-i) >> emit jumper >> emit "CE" >> catchEmission (dumpIncop "C" i)
  LtCK0 (CP i) -> dumpIncop "C" (-i) >> emit jumper >> emit "Cg" >> catchEmission (dumpIncop "C" i)
  GtCK0 (CP i) -> dumpIncop "C" (-i) >> emit jumper >> emit "Cl" >> catchEmission (dumpIncop "C" i)
  -- Cd[ldbuf]CmkjU, Cd[ldbuf]CmkjL, Cd[ldbuf]CmkjN
  IsUpper (CP i) -> emit (arepl i) >> emit jumper >> emit "U" >> return ""
  IsUpper d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> emit jumper >> emit "U" >> return ""
  IsLower (CP i) -> emit (arepl i) >> emit jumper >> emit "L" >> return ""
  IsLower d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> emit jumper >> emit "L" >> return ""
  IsDigit (CP i) -> emit (arepl i) >> emit jumper >> emit "N" >> return ""
  IsDigit d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> emit jumper >> emit "N" >> return ""
  IsEob (CP i) -> emit (arepl i) >> emit jumper >> emit "e" >> return ""
  IsEob d -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> emit jumper >> emit "e" >> return ""
  -- j!  (negation appends "!" to the jump prefix)
  Negation c -> dumpCond c (jumper++"!")
  -- j|q#!t  (one "|q<char>" alternative per candidate character)
  EqCh (CP i) chs -> emit (arepl i) >> emit jumper >> emit (foldr q "!t" chs) >> return "" where q a b = "|q"++[a]++b
  -- Cd[ldbuf]Cmkj|q#!t
  EqCh d chs -> emit "Cd" >> dumpInstr (SetCK0 d) >> emit "Cmk" >> emit jumper >> emit (foldr q "!t" chs) >> return ""  where q a b = "|q"++[a]++b

-- | Compile a list of procedures plus a main instruction into bytecode.
--
-- The main program is appended as a nameless procedure. Addresses are
-- allocated starting at 7, the width of the entry stub: @+@, a 4-digit
-- address and @jt@. The entry stub is followed by every procedure body in
-- allocation order.
compile :: [Procedure] -> Instruction -> Exec String
compile ps main = do
  (_, _, out) <- runCTPL program initial
  return (out "")
  where
    initial = CTPLState (ps ++ [Procedure [] main 0])
    program = do
      declared <- getState definedProcs
      placed <- allocProcs 7 declared
      modState (\st -> st { definedProcs = placed })
      mainAddr <- getProc [] procAddr
      emit "+"
      dumpAddr mainAddr
      emit "jt"
      mapM_ (dumpInstr . procInstr) placed

-- | Parse a whole program: any interleaving of procedure definitions and
-- top-level statements, collected into two lists that preserve source
-- order.
multiParse :: ChParser m => m ([Procedure], [Instruction])
multiParse = liftM mconcat (many (asProc ??? asInstr))
  where
    asProc = liftM (\p -> ([p], [])) parseProc
    asInstr = liftM (\i -> ([], [i])) parseInstr

-- | Parse a single statement. The candidate parsers are combined with
-- '???' in exactly this order.
parseInstr :: ChParser m => m Instruction
parseInstr =
  parseAssignment ??? parseReturn ??? parseInsert
    ??? parseCall ??? parseIf ??? parseLoop
    ??? parseFor ??? parseWalk ??? parseRemove

-- | Parse a double-quoted string literal, allowing leading whitespace. A
-- backslash makes the following character literal, so an escaped double
-- quote does not end the string.
parseString :: ChParser m => m String
parseString = do
  many white
  match '"'
  contents <- many literalChar
  match '"'
  return contents
  where
    literalChar = request >>= \c -> case c of
      '"' -> pabort
      '\\' -> request
      _ -> return c

-- | Parse @return;@.
parseReturn :: ChParser m => m Instruction
parseReturn =
  many white >> matchs "return" >> many white >> match ';' >> return Return

-- | Parse @insert "text" at [pos];@.
parseInsert :: ChParser m => m Instruction
parseInsert = do
  many white
  matchs "insert"
  text <- parseString
  many white
  matchs "at"
  token '['
  pos <- parseNumSource
  token ']'
  token ';'
  return (Insert pos text)
  where
    token c = many white >> match c

-- | Parse @remove [pos];@.
parseRemove :: ChParser m => m Instruction
parseRemove = do
  many white
  matchs "remove"
  token '['
  target <- parseNumSource
  token ']'
  token ';'
  return (Remove target)
  where
    token c = many white >> match c

-- | Parse a bare cursor move @[pos];@.
parseWalk :: ChParser m => m Instruction
parseWalk = do
  token '['
  target <- parseNumSource
  token ']'
  token ';'
  return (Walk target)
  where
    token c = many white >> match c

-- | Parse @call name;@, where the name is an 'alpha' character followed by
-- any number of 'anum' characters.
parseCall :: ChParser m => m Instruction
parseCall = do
  many white
  matchs "call"
  many white
  first <- alpha
  rest <- many anum
  many white
  match ';'
  return (Call (first : rest))

-- | Parse an assignment @target = src;@, where the target is @AX@, @CK0@
-- or a buffer cell @[pos]@.
parseAssignment :: ChParser m => m Instruction
parseAssignment = do
  many white
  setter <- toAX ??? toCK0 ??? toBuf
  many white
  match '='
  src <- parseNumSource
  many white
  match ';'
  return (setter src)
  where
    toAX = matchs "AX" >> return SetAX
    toCK0 = matchs "CK0" >> return SetCK0
    toBuf = do
      match '['
      pos <- parseNumSource
      many white
      match ']'
      return (SetBuf pos)

-- | Parse a numeric operand. Accepted forms:
--
-- * @AX@, @CK0@, @CP@, @LEN@, each with an optional @+n@ or @-n@ offset
-- * an optionally negative integer literal
-- * @\\c@, the code point of character @c@
-- * @[pos]@ with an optional offset
parseNumSource :: ChParser m => m NumSource
parseNumSource = ax ??? ck0 ??? cp ??? cst ??? clt ??? buf ??? len
  where
    negative = many white >> match '-' >> liftM negate number
    num = negative ?? number
    summand = negative ?? (many white >> match '+' >> number) ?? return 0
    register kw ctor = many white >> matchs kw >> liftM ctor summand
    ax = register "AX" AX
    ck0 = register "CK0" CK0
    cp = register "CP" CP
    len = register "LEN" Len
    cst = liftM Const num
    clt = many white >> match '\\' >> liftM (Const . ord) request
    buf = do
      many white
      match '['
      inner <- parseNumSource
      many white
      match ']'
      liftM (Buf inner) summand

-- | Parse @proc name { statements }@.
--
-- The body is wrapped as @PopAX : statements ++ [Return]@. The leading
-- 'PopAX' presumably pairs with the @d@ (push AX) that 'dumpInstr' emits
-- for a 'Call'. The address is left at 0 for 'allocProcs' to fill in.
parseProc :: ChParser m => m Procedure
parseProc = do
  many white
  matchs "proc"
  some white
  name <- liftM2 (:) alpha (many anum)
  many white
  match '{'
  body <- many parseInstr
  many white
  match '}'
  return (Procedure
    { procName = name
    , procInstr = Compound (PopAX : body ++ [Return])
    , procAddr = 0
    })

-- | Parse @if cond { ... }@ with an optional @else { ... }@. A missing
-- else branch becomes an empty 'Compound'.
parseIf :: ChParser m => m Instruction
parseIf = do
  many white
  matchs "if"
  cond <- parseCond
  thenPart <- block
  elsePart <- elseClause ??? return []
  return (IfThenElse cond (Compound thenPart) (Compound elsePart))
  where
    block = do
      many white
      match '{'
      instrs <- many parseInstr
      many white
      match '}'
      return instrs
    elseClause = many white >> matchs "else" >> block

-- | Parse @until cond { ... }@ or @while cond { ... }@. A while loop is
-- represented as 'UntilDo' on the negated condition.
parseLoop :: ChParser m => m Instruction
parseLoop = do
  many white
  wrap <- untilKw ??? whileKw
  cond <- parseCond
  many white
  match '{'
  body <- many parseInstr
  many white
  match '}'
  return (wrap cond (Compound body))
  where
    untilKw = matchs "until" >> return UntilDo
    whileKw = matchs "while" >> return (UntilDo . Negation)

-- | Parse @for start to end { ... }@ (or @downto@) and desugar it.
--
-- The result is a 'Compound' of register, stack and tape primitives around
-- an 'UntilDo' loop that stops on @GtAX (CK0 0)@ for @to@ (or
-- @LtAX (CK0 0)@ for @downto@). Each iteration applies
-- @SetCK0 (CK0 step)@ with a step of +1 or -1.
--
-- NOTE(review): inside the body the counter appears to be exposed through
-- CK0 while the user's AX is parked on the tape; confirm against
-- "Text.CTPL0".
parseFor :: ChParser m => m Instruction
parseFor = do
  many white
  matchs "for"
  start <- parseNumSource
  many white
  (step, stop) <- ascending ??? descending
  end <- parseNumSource
  many white
  match '{'
  body <- many parseInstr
  many white
  match '}'
  return (desugar start end step stop body)
  where
    ascending = matchs "to" >> return (1, GtAX)
    descending = matchs "downto" >> return (-1, LtAX)
    desugar start end step stop body =
      Compound (setup ++ [loop] ++ [AXFromTape, PopCK])
      where
        setup =
          [ PushCK, SetCK0 end, AXToTape, PushTape, CK0ToTape, PushTape
          , PopCK, PushCK, SetAX start, PopTape, CK0FromTape, PopTape
          ]
        loop =
          UntilDo (stop (CK0 0)) $
            Compound ([PushAX, AXFromTape] ++ body ++ [SetCK0 (CK0 step), AXToTape, PopAX])

-- | Parse a condition. The alternatives, combined with '???' in this order:
--
-- * @not c@
-- * @lowercase?@, @uppercase?@, @digit?@ or @end?@ followed by @[pos]@
-- * @AX = src@ or @CK0 = src@
-- * @[pos] in "chars"@
-- * @AX < src@ or @CK0 < src@
-- * @AX > src@ or @CK0 > src@
parseCond :: ChParser m => m Condition
parseCond = negated ??? classTest ??? equality ??? membership ??? lessThan ??? greaterThan
  where
    negated = many white >> matchs "not" >> liftM Negation parseCond
    classTest = do
      many white
      test <- (matchs "lowercase?" >> return IsLower)
              ??? (matchs "uppercase?" >> return IsUpper)
              ??? (matchs "digit?" >> return IsDigit)
              ??? (matchs "end?" >> return IsEob)
      many white
      match '['
      pos <- parseNumSource
      many white
      match ']'
      return (test pos)
    -- Shared shape of the three register comparisons: REG <sym> src.
    comparison sym onAX onCK0 = do
      many white
      ctor <- (matchs "AX" >> return onAX) ??? (matchs "CK0" >> return onCK0)
      many white
      match sym
      rhs <- parseNumSource
      return (ctor rhs)
    equality = comparison '=' EqAX EqCK0
    lessThan = comparison '<' LtAX LtCK0
    greaterThan = comparison '>' GtAX GtCK0
    membership = do
      many white
      match '['
      pos <- parseNumSource
      many white
      match ']'
      many white
      matchs "in"
      chars <- parseString
      return (EqCh pos chars)

-- | Run the top-level program parser over source text, returning the list
-- of results produced by 'runCarrierT'.
parse :: String -> [([Procedure], [Instruction])]
parse src = runCarrierT src multiParse

-- | Parse and compile a CTPL program using the first parse result. Yields
-- 'SyntaxFault' when the parser produces no result.
compileCTPL :: String -> Exec String
compileCTPL s =
  case take 1 (parse s) of
    [(procs, body)] -> compile procs (Compound body)
    _ -> SyntaxFault

-- | Compile a CTPL program and run the resulting bytecode on @buffer@ with
-- the given @limit@ via "Text.CTPL0". Any compilation failure (including
-- an unknown procedure) is reported as 'Null.SynViol'.
evalCTPL :: String -> String -> Int -> Null.Exec String
evalCTPL program buffer limit = run (compileCTPL program)
  where
    run (Succ bytecode) = Null.evalCTPL0 bytecode buffer limit
    run _ = Null.SynViol