module Generate.Cases (caseToMatch, Match (..), Clause (..), matchSubst) where

import Control.Arrow (first)
import Control.Monad (liftM,foldM)
import Data.List (groupBy,sortBy,lookup)
import Data.Maybe (fromMaybe)

import Unique
import SourceSyntax.Location
import SourceSyntax.Literal
import SourceSyntax.Pattern
import SourceSyntax.Expression
import Transform.Substitute

-- | Entry point: compile the branches of a @case@ expression.
--
-- Each branch pairs a single pattern with its right-hand side. A fresh
-- variable is allocated to stand for the scrutinee, and the branches are
-- compiled against it with 'Fail' as the outermost fallback. Returns the
-- scrutinee variable name together with the resulting 'Match'.
caseToMatch :: [(Pattern, LExpr t v)] -> Unique (String, Match t v)
caseToMatch patterns = do
  root <- newVar
  tree <- match [root] (map (first (\pat -> [pat])) patterns) Fail
  return (root, tree)

-- | Allocate a fresh variable name of the form @case<n>@, where @n@ is
-- obtained from 'guid' and rendered with 'show'.
newVar :: Unique String
newVar = liftM (("case" ++) . show) guid

-- | Decision tree produced by compiling the branches of a @case@ expression.
data Match t v
    = Match String [Clause t v] (Match t v)
      -- ^ Test the named variable against each 'Clause'; the last field is the
      -- fallback 'Match' for the whole test (see 'matchCon').
    | Break
      -- ^ Default given to the sub-match of every constructor/literal clause
      -- (see 'matchCon'). NOTE(review): presumably means "fall through to the
      -- enclosing fallback" — confirm against the code generator.
    | Fail
      -- ^ Outermost fallback supplied by 'caseToMatch'. NOTE(review):
      -- presumably signals that no branch matched — confirm in the generator.
    | Other (LExpr t v)
      -- ^ Leaf: the right-hand side expression of a branch.
    | Seq [Match t v]
      -- ^ Several matches in order; built by 'match' when rows remain but no
      -- variables are left to test (the fallback is appended last).
      deriving Show

-- | One branch of a 'Match'.
--
-- * The tag is either a constructor name ('Left') or a literal ('Right').
-- * The name list holds fresh variables bound to the constructor's arguments
--   (empty for literals; see 'matchClause').
-- * The inner 'Match' is compiled over those variables plus the remaining
--   columns.
data Clause t v =
    Clause (Either String Literal) [String] (Match t v)
    deriving Show

-- | Rename variables throughout a 'Match'.
--
-- Each @(old, new)@ pair maps @old@ to @new@:
--
-- * Scrutinee names and clause-bound names are looked up in the association
--   list (the first matching pair wins; unknown names are kept).
-- * Leaf expressions are rewritten by applying 'subst' for every pair.
matchSubst :: [(String,String)] -> Match t v -> Match t v
matchSubst pairs = go
  where
    rename name = fromMaybe name (lookup name pairs)

    renameExpr expr =
      foldr (\(old, new) acc -> subst old (Var new) acc) expr pairs

    renameClause (Clause tag bound sub) =
      Clause tag (map rename bound) (go sub)

    go tree =
      case tree of
        Break              -> Break
        Fail               -> Fail
        Seq parts          -> Seq (map go parts)
        Other (L s e)      -> Other (L s (renameExpr e))
        Match n cs rest    -> Match (rename n) (map renameClause cs) (go rest)

-- | True when the first pattern of a row is a constructor or literal
-- pattern, i.e. one that 'matchCon' must test.
isCon :: ([Pattern], a) -> Bool
isCon (pat:_, _) =
    case pat of
      PData {}    -> True
      PLiteral {} -> True
      _           -> False

-- | Complement of 'isCon': the row's first pattern is neither a constructor
-- nor a literal pattern.
isVar :: ([Pattern], a) -> Bool
isVar = not . isCon

-- | Compile a pattern-matching problem into a 'Match'.
--
-- * The @[String]@ argument names the variables under test, one per pattern
--   column.
-- * Each row pairs its remaining patterns with its right-hand side.
-- * The 'Match' argument is the fallback to use after these rows.
--
-- When no variables remain, the rows become leaves. Otherwise, top-level
-- aliases in the first column are stripped, and that column decides the
-- strategy:
--
-- * only variable-like patterns go to 'matchVar';
-- * only constructor/literal patterns go to 'matchCon';
-- * a mixture goes to 'matchMix'.
match :: [String] -> [([Pattern],LExpr t v)] -> Match t v -> Unique (Match t v)
match [] rows fallback =
    case (rows, fallback) of
      ([], _)              -> return fallback
      ([([], rhs)], Fail)  -> return (Other rhs)
      ([([], rhs)], Break) -> return (Other rhs)
      _                    -> return (Seq (map (Other . snd) rows ++ [fallback]))
match vars@(scrutinee:_) rows fallback
    | all isVar rows' = matchVar vars rows' fallback
    | all isCon rows' = matchCon vars rows' fallback
    | otherwise       = matchMix vars rows' fallback
  where
    rows' = map (dealias scrutinee) rows

-- | Strip @as@-aliases from the head pattern of a row.
--
-- Every alias name is replaced by the scrutinee variable @v@ in the row's
-- right-hand side.
--
-- Aliases may be nested (@PAlias x (PAlias y p)@), and every layer is
-- removed. This matters because 'isCon' does not recognise 'PAlias': a
-- leftover alias would be classified as variable-like, and 'matchVar' has
-- no case for it. A row with no patterns is returned unchanged.
dealias :: String -> ([Pattern], LExpr t v) -> ([Pattern], LExpr t v)
dealias v row@(pats, L s e) =
    case pats of
      -- peel one layer and keep going: the inner pattern may itself be an alias
      PAlias x inner : rest -> dealias v (inner : rest, L s (subst x (Var v) e))
      _                     -> row

-- | Compile rows whose first column holds only variable-like patterns.
--
-- The first variable is consumed. Each row drops its head pattern, and the
-- names that pattern binds are substituted straight into the right-hand side:
--
-- * a variable becomes the scrutinee;
-- * @_@ binds nothing;
-- * each record field @f@ becomes an 'Access' of @f@ on the scrutinee.
matchVar :: [String] -> [([Pattern],LExpr t v)] -> Match t v
         -> Unique (Match t v)
matchVar (v:vs) cs def = match vs (map dropHead cs) def
  where
    dropHead (p:ps, L s body) = (ps, L s (bindHead s p body))

    bindHead s pat body =
      case pat of
        PVar x     -> subst x (Var v) body
        PAnything  -> body
        PRecord fs -> foldr (bindField s) body fs

    bindField s field acc = subst field (Access (L s (Var v)) field) acc

-- | Compile rows whose first column holds only constructor or literal
-- patterns into a single 'Match' on the first variable.
--
-- Rows are sorted and grouped by their head pattern: constructors by name,
-- anything else by the pattern itself. Each group becomes one 'Clause' whose
-- sub-match uses 'Break' as its default, while @def@ is the fallback of the
-- whole 'Match'.
matchCon :: [String] -> [([Pattern],LExpr t v)] -> Match t v
         -> Unique (Match t v)
matchCon (v:vs) cs def = do
    clauses <- mapM compileGroup groups
    return (Match v clauses def)
  where
    -- sortBy is stable, so rows keep their relative order inside each group
    groups = groupBy sameHead (sortBy orderHead cs)

    orderHead (a:_, _) (b:_, _) =
      case (a, b) of
        (PData x _, PData y _) -> compare x y
        _                      -> compare a b

    sameHead (a:_, _) (b:_, _) =
      case (a, b) of
        (PData x _, PData y _) -> x == y
        _                      -> a == b

    -- groupBy never yields an empty group, so the first row always exists
    compileGroup grp@((pat:_, _) : _) =
      matchClause (clauseTag pat) (v:vs) grp Break

    clauseTag (PData name _) = Left name
    clauseTag (PLiteral lit) = Right lit

-- | Build the 'Clause' for one group of rows sharing the same head
-- constructor or literal. The first variable of the input list is consumed.
--
-- * Constructor: one fresh variable is allocated per argument of the first
--   row's pattern, and each row's sub-patterns are spliced in front of its
--   remaining columns.
-- * Literal: nothing is bound and the head pattern is simply dropped.
--
-- The rows are then compiled with 'match' against the fresh variables
-- followed by the remaining ones.
matchClause :: Either String Literal
            -> [String]
            -> [([Pattern],LExpr t v)]
            -> Match t v
            -> Unique (Clause t v)
matchClause c (_:rest) cs def = do
    fresh <- freshVars
    body  <- match (fresh ++ rest) (map expand cs) def
    return (Clause c fresh body)
  where
    expand (PData _ args : ps, e) = (args ++ ps, e)
    expand (PLiteral _   : ps, e) = (ps, e)

    freshVars =
      case cs of
        (PData _ args : _, _) : _ -> mapM (const newVar) args
        (PLiteral _ : _, _)   : _ -> return []

-- | Compile a first column that mixes variable-like rows with
-- constructor/literal rows.
--
-- Consecutive rows of the same kind (per 'isCon') form a run. Runs are
-- compiled from last to first, and each run uses the match built for the
-- runs after it as its fallback.
matchMix :: [String] -> [([Pattern],LExpr t v)] -> Match t v
         -> Unique (Match t v)
matchMix vs cs def = foldr chain (return def) runs
  where
    runs = groupBy (\a b -> isCon a == isCon b) cs

    -- build the later runs first, then compile this run on top of them
    chain run later = later >>= match vs run