module Language.Pointwise.Matching where

import Language.Pointwise.Syntax
import Data.Generics hiding (Unit,(:*:),Inl,Inr)
import Control.Monad
import Control.Monad.State
import Generics.Pointless.Combinators
import Data.List

-- Pattern matching elimination

-- Does not detect repeated variables in patterns.
-- It can be improved in the pair case: it does not detect equalities up to
-- alpha conversion, nor does it rename variables appearing both in the
-- pattern and in the term subject to matching.

-- | Generate a fresh variable name: the current counter value is appended to
-- the given prefix, and the counter is then incremented by one.
getVar :: String -> StateT Int Maybe String
getVar prefix = state (\counter -> (prefix ++ show counter, counter + 1))
 
-- | Eliminate pattern matching ('Match' nodes) from a term.
--
-- Each @Match e alts@ is compiled away according to the shape of its
-- patterns:
--
--   * a single variable pattern substitutes the scrutinee into the body;
--   * a single 'Unit' pattern just yields the body;
--   * pair patterns are grouped by their first component, each group is
--     matched against @Snd e@, and the results are matched against @Fst e@;
--   * 'Inl'/'Inr' patterns become a 'Case' on the scrutinee whose branches
--     are lambdas over a fresh variable (see 'getVar');
--   * 'In' patterns are stripped and matched against @Out e@.
--
-- Any other term is traversed generically, so nested matches (including
-- those inside scrutinees) are eliminated too.  The result is 'Nothing'
-- (via 'fail' / 'guard') when the matching cannot be eliminated, e.g. an
-- empty alternative list, patterns of mixed shapes, pair patterns whose
-- components share free variables, or a scrutinee whose free variables
-- clash with those of the patterns.
nomatch :: Term -> StateT Int Maybe Term
nomatch (Match _ []) = fail "Should not happen given the premises"
nomatch (Match e [(Var x, rhs)]) =
    do rhs' <- nomatch rhs
       -- The scrutinee may itself contain matches; eliminate them as well.
       -- Done after the body so the body's fresh-name numbering is unchanged.
       e' <- nomatch e
       return (subst [(x, e')] rhs')
nomatch (Match _ [(Unit, rhs)]) =
    nomatch rhs
-- From here on the alternative list is non-empty (the [] case is handled
-- above), so 'head' is safe: the first pattern selects the strategy and a
-- 'guard' checks that every other pattern has the same shape.
nomatch (Match e l)
    | isPair (fst (head l)) =
        do guard (all isPair (map fst l))
           -- Reject pair patterns whose two components share free variables.
           guard (all (\(x :&: y, _) -> null (free x `intersect` free y)) l)
           -- No renaming is performed, so the scrutinee must not mention
           -- any free variable of the patterns.
           guard (null (free e `intersect` concatMap (free . fst) l))
           let groups  = mygroup (\x y -> pwFst (fst x) == pwFst (fst y)) l
               -- 'mygroup' never yields empty groups, so 'head' is safe.
               left    = map (pwFst . fst . head) groups
               sndRows = map (pwSnd . fst /\ snd)
           right <- mapM (nomatch . Match (Snd e) . sndRows) groups
           nomatch (Match (Fst e) (zip left right))
nomatch (Match e l)
    | isInlr (fst (head l)) =
        do guard (all isInlr (map fst l))
           let left'  = map (\(Inl t, p) -> (t, p)) (filter (isInl . fst) l)
               right' = map (\(Inr t, p) -> (t, p)) (filter (isInr . fst) l)
           var   <- getVar "_v"
           left  <- nomatch (Lam var (Match (Var var) left'))
           right <- nomatch (Lam var (Match (Var var) right'))
           -- Eliminate matches nested in the scrutinee too (done last so the
           -- branches' fresh-name numbering is unchanged).
           e'    <- nomatch e
           return (Case e' left right)
nomatch (Match e l)
    | isIn (fst (head l)) =
        do guard (all isIn (map fst l))
           let pats = map (\(In t, p) -> (t, p)) l
           nomatch (Match (Out e) pats)
nomatch (Match _ _) =
    fail "Not possible to eliminate pattern matching!"
nomatch e = gmapM (mkM nomatch) e


   
-- Auxiliary definitions

-- | Split a list into groups of elements related by @f@.
--
-- The first group is headed by the first element @h@ and contains every later
-- element @x@ with @f x h@, in original order.  The remaining elements are then
-- grouped recursively.  Unlike 'groupBy', related elements need not be
-- adjacent.  Groups are never empty.  Membership is tested only against each
-- group's head, so when @f@ is not an equivalence relation the result depends
-- on list order.
--
-- No 'Eq' constraint is needed: only @f@ is used to compare elements.
mygroup :: (a -> a -> Bool) -> [a] -> [[a]]
mygroup _ [] = []
mygroup f (h:t) = (h : same) : mygroup f rest
  where
    -- One pass instead of two complementary filters.
    (same, rest) = partition (\x -> f x h) t