module Binops (binops) where

import Ast
import Control.Monad (liftM,guard)
import Control.Monad.Error
import Data.List (foldl',splitAt,elemIndices,group,groupBy,sortBy,find)
import qualified Data.Map as Map
import Data.Maybe (mapMaybe)

import Text.Parsec
import ParseLib

-- | Associativity of an infix operator.  The derived 'Show' output of these
-- constructors is embedded verbatim in the messages built by 'errorMessage'.
data Assoc
    = L  -- ^ left-associative: chains are split at the last occurrence
    | N  -- ^ non-associative: a second operator on the same level is rejected
    | R  -- ^ right-associative: chains are split at the first occurrence
    deriving (Eq, Show)

-- | Built-in infix operators as @(precedence level, associativity, name)@
-- triples.  Operators on a lower level end up closer to the root of the
-- resulting expression, i.e. they bind more loosely.  Every level listed
-- here uses a single associativity.
table = concat
    [ level 9 R ["."]
    , level 7 L ["*", "/", "mod"]
    , level 6 L ["+", "-"]
    , level 5 R [":", "++"]
    , level 4 N ["<=", ">=", "<", "==", "/=", ">"]
    , level 3 R ["&&"]
    , level 2 R ["||"]
    , level 0 R ["$"]
    ]
  where
    -- One table entry per operator name, all sharing a level and associativity.
    level prec assoc names = [ (prec, assoc, name) | name <- names ]

-- | Sort operator-table entries by ascending precedence level (the first
-- component of each triple).  The other two components are never compared.
sortOps entries = sortBy byLevel entries
  where
    byLevel (lvlA, _, _) (lvlB, _, _) = compare lvlA lvlB

-- | Parse a chain of operands separated by infix operators and combine it
-- into a single expression using the precedences and associativities in
-- 'table'.
--
-- * @term@ parses one operand.
-- * @anyOp@ parses one operator name.
--
-- The chain is one @term@ followed by zero or more repetitions of
-- "whitespace, operator, whitespace, term".  Each repetition is wrapped in
-- 'try', so an operator that is not followed by a term is not consumed.
-- If 'binopOf' reports a fixity conflict, the parser fails with that message.
binops term anyOp = do
  firstOperand <- term
  pairs <- many (try opThenOperand)
  let (ops, laterOperands) = unzip pairs
  either fail return $
      binopOf Map.empty (sortOps table) ops (firstOperand : laterOperands)
  where
    -- One "operator operand" step of the chain, with optional surrounding
    -- whitespace as defined by the imported @whitespace@ parser.
    opThenOperand = do
      whitespace
      op <- anyOp
      whitespace
      operand <- term
      return (op, operand)

-- | Split an operator chain at the operator with index @i@ and build a
-- 'Binop' node from the two halves.
--
-- The operators before index @i@, together with the first @i+1@ operands,
-- form the left sub-chain; the operators after index @i@ and the remaining
-- operands form the right sub-chain.  Both halves are resolved with
-- 'binopOf' using the same @seen@ map and operator table, left half first,
-- so an error in the left half is the one reported.
--
-- @i@ must be a valid index into @ops@; the pattern match below is
-- intentionally left partial, as in the callers' contract.
binopSplit seen opTable i ops es =
    let (leftOps, atAndAfter)       = splitAt i ops
        (leftOperands, rightOperands) = splitAt (i+1) es
    in case atAndAfter of
         op : rightOps -> do
             lhs <- binopOf seen opTable leftOps leftOperands
             rhs <- binopOf seen opTable rightOps rightOperands
             return (Binop op lhs rhs)

-- | Resolve a flat operator chain into one expression tree.
--
-- Arguments:
--
-- * @seen@    – for each precedence level already split on in an enclosing
--               call, the associativity and operator that were used there.
-- * the table – remaining @(level, associativity, operator)@ entries, sorted
--               by ascending level, so the loosest-binding operators are
--               split on first and end up nearest the root.
-- * @ops@     – the operators of the chain, in source order.
-- * @es@      – the operands; a well-formed chain has one more operand than
--               operators.
--
-- Returns @Right@ the combined expression, or @Left@ a message when two
-- operators on the same level have conflicting (or non-) associativity.
--
-- All operators that share a level /and/ associativity are treated as one
-- group: a left-associative group is split at its last occurrence in the
-- chain and a right-associative group at its first.  Picking the split point
-- per individual operator (as this function used to) mis-grouped mixed
-- chains such as @a * b / c@ or @a ++ b : c@.
--
-- Operators not present in the table are treated as level 9,
-- left-associative.
binopOf _ _ _ [e] = return e
-- No operators left but not exactly one operand: the chain is malformed.
-- Report it instead of crashing on an empty operator list.
binopOf _ [] [] _ =
    Left "Malformed operator chain: operands and operators do not line up"
-- Table exhausted: everything still in the chain is an unknown operator.
-- List them all so they are grouped together like any other level.
binopOf seen [] ops es = binopOf seen [ (9, L, o) | o <- ops ] ops es

binopOf seen (tbl@((lvl, assoc, op):rest)) ops es =
    case hits of
      -- Neither this operator nor any of its peers occurs: try the next entry.
      [] -> binopOf seen rest ops es
      firstHit@(_, firstOp) : laterHits ->
          let (splitIx, splitOp) =
                  if assoc == L then last (firstHit : laterHits) else firstHit
          in case Map.lookup lvl seen of
               Nothing ->
                   binopSplit (Map.insert lvl (assoc, splitOp) seen) tbl splitIx ops es
               Just (assoc', op')
                   -- Same level, same associativity: fine unless that
                   -- associativity is "none".
                   | assoc == assoc' && assoc /= N ->
                       binopSplit seen tbl splitIx ops es
                   | otherwise ->
                       Left $ errorMessage lvl firstOp assoc op' assoc'
  where
    -- Operators that must be resolved together with @op@.  Non-associative
    -- operators are handled one at a time; any second one on the same level
    -- is an error regardless of which is found first.
    peers
        | assoc == N = [op]
        | otherwise  = op : [ o | (l, a, o) <- rest, l == lvl, a == assoc ]
    -- Positions (and names) of the group's operators within the chain.
    hits = [ (i, o) | (i, o) <- zip [0..] ops, o `elem` peers ]

-- | Build the message reported when two operators on the same precedence
-- level cannot be combined.  Each operator is shown in parentheses followed
-- by its associativity and the shared level in square brackets.
errorMessage lvl op1 assoc1 op2 assoc2 =
    concat
      [ "Cannot have (", op1, ") with [", fixity assoc1
      , "] and (", op2, ") with [", fixity assoc2, "]"
      ]
  where
    -- Associativity and level, e.g. @N 4@.
    fixity assoc = show assoc ++ " " ++ show lvl