{-# LANGUAGE TemplateHaskell, QuasiQuotes, PatternGuards, DoAndIfThenElse #-}

module Data.TrieMap.Representation.TH (genRepr, genOrdRepr) where

import Data.TrieMap.Modifiers
import Data.TrieMap.Rep
import Data.TrieMap.Rep.Instances ()
import Language.Haskell.TH
import Language.Haskell.TH.ExpandSyns

-- | One clause of a generated 'toRep': the argument patterns it matches (empty for a
-- nullary constructor) and the representation expression it produces.
data ToRepCase = ToRepCase [Pat] Exp
-- | One clause of a generated 'fromRep': a pattern on the representation value and the
-- expressions it reconstructs.  'prod' concatenates these expression lists and 'conify'
-- applies a constructor to them; 'decsForRepr' only emits cases whose list holds exactly
-- one expression.
data FromRepCase = FromRepCase Pat [Exp]
type ToRep = [ToRepCase]
type FromRep = [FromRepCase]

-- | A representation under construction: the representation type, the 'toRep' clauses,
-- and the 'fromRep' clauses.
type Representation = (Type, ToRep, FromRep)

-- | Given a type with an associated 'Ord' instance, generates a representation that will cause its 'TMap'
-- implementation to be essentially equivalent to "Data.Map".
--
-- Only @data@ and @newtype@ declarations are accepted; the generated instance reuses the
-- declaration's context and applies the type constructor to all of its type variables.
-- Any other declaration makes the splice fail.
genOrdRepr :: Name -> Q [Dec]
genOrdRepr tycon = do
  TyConI dec <- reify tycon
  -- Both supported declaration forms contribute exactly the same information.
  (ctx, binders) <- case dec of
    DataD c _ tvs _ _ -> return (c, tvs)
    NewtypeD c _ tvs _ _ -> return (c, tvs)
    _ -> fail ("Cannot generate Repr instance for " ++ pprint dec)
  let fullType = foldl AppT (ConT tycon) (map tyVarBndrType binders)
  repr <- ordRepr fullType
  return (decsForRepr ctx fullType repr)

-- | Representation that wraps a value of the given type in the 'Ord' constructor
-- (representation type: 'Ordered' applied to the type), unwrapping it again in 'fromRep'.
ordRepr :: Type -> Q Representation
ordRepr t0 = do
  v <- newName "x"
  let wrappedType = ConT ''Ordered `AppT` t0
      wrapCase = ToRepCase [VarP v] (ConE 'Ord `AppE` VarE v)
      unwrapCase = FromRepCase (ConP 'Ord [VarP v]) [VarE v]
  return (wrappedType, [wrapCase], [unwrapCase])
	

-- | Given the name of a type constructor, automatically generates an efficient 'Repr' instance.
--
-- Each constructor is represented by the product of its fields' representations, and the
-- constructors of a @data@ type are combined with nested 'Either's.  The splice fails for
-- declarations other than @data@/@newtype@, for constructors 'conRepr' cannot handle, and
-- for @data@ types without any constructors (which have nothing to combine).
genRepr :: Name -> Q [Dec]
genRepr tycon = do
  TyConI dec <- reify tycon
  let theTyp = foldl AppT (ConT tycon) . map tyVarBndrType
  case dec of
    DataD cxt _ tyvars cons _
      -- 'foldr1 union' below is partial; report empty types explicitly instead of
      -- crashing the splice with "Prelude.foldr1: empty list".
      | null cons -> fail ("Cannot generate Repr instance for empty data type " ++ pprint tycon)
      | otherwise -> do
          conReprs <- mapM conRepr cons
          return (decsForRepr cxt (theTyp tyvars) (foldr1 union conReprs))
    NewtypeD cxt _ tyvars con _ -> do
      theConRepr <- conRepr con
      return (decsForRepr cxt (theTyp tyvars) theConRepr)
    _ -> fail ("Cannot generate Repr instance for " ++ pprint dec)

-- | The type variable bound by a binder, discarding any kind annotation.
tyVarBndrType :: TyVarBndr -> Type
tyVarBndrType bndr = VarT $ case bndr of
  PlainTV name -> name
  KindedTV name _ -> name

-- | Builds the single @instance Repr t@ declaration: the 'Rep' type instance plus one
-- 'toRep' clause per 'ToRepCase' and one 'fromRep' clause per 'FromRepCase' that carries
-- exactly one output expression (any other 'FromRepCase' is skipped).
decsForRepr :: Cxt -> Type -> Representation -> [Dec]
decsForRepr ctx t (tRep, toR, fromR) = [InstanceD ctx (ConT ''Repr `AppT` t) members]
  where
    members = [repInstance, toRepDef, fromRepDef]
    repInstance = TySynInstD ''Rep [t] tRep
    toRepDef = FunD 'toRep (map toClause toR)
    fromRepDef = FunD 'fromRep (concatMap fromClause fromR)
    toClause (ToRepCase pats body) = Clause pats (NormalB body) []
    fromClause (FromRepCase pat [body]) = [Clause [pat] (NormalB body) []]
    fromClause _ = []

-- | Splits a type application into its head and its arguments, left to right:
-- @f a b@ becomes @(f, [a, b])@; a type that is not an application has no arguments.
decompose :: Type -> (Type, [Type])
decompose = peel []
  where
    -- Walk down the left spine, collecting arguments in front of those already seen.
    peel args (fun `AppT` arg) = peel (arg : args) fun
    peel args headTy = (headTy, args)

-- | Monad in which representations are computed; currently just 'Q'.
type ReprM = Q

-- | Representation of a single constructor.  Record and infix constructors are treated
-- as ordinary prefix constructors; a nullary constructor maps to 'unit', and otherwise
-- the field representations are combined with 'prod'.  Any other constructor form
-- (e.g. an existential) makes the splice fail.
conRepr :: Con -> ReprM Representation
conRepr con = case con of
  RecC name fields -> conRepr (NormalC name [(s, ty) | (_, s, ty) <- fields])
  InfixC lhs name rhs -> conRepr (NormalC name [lhs, rhs])
  NormalC name [] -> return (conify name unit)
  NormalC name args -> fmap (conify name . foldr1 prod) (mapM (typeRepr . snd) args)
  _ -> fail ("Cannot generate representation for existential constructor " ++ pprint con)

-- | Computes the representation of a single field type.
--
-- Type synonyms are expanded first.  Lists, tuples, @()@ and 'Either' are translated
-- structurally; a nullary type constructor with an existing 'Repr' instance delegates to
-- 'toRep' / 'fromRep'; anything else is wrapped in 'Key' by 'recursiveRepr'.
--
-- Every case produced here consumes exactly one value (one pattern) and yields exactly
-- one output expression, which is what 'prod', 'conify' and the list/tuple/'Either'
-- cases below rely on.  A field of type @()@ must therefore still consume its
-- constructor argument, so it uses 'unitValue' rather than 'unit' (which has no
-- patterns and is only correct for nullary constructors).
typeRepr :: Type -> ReprM Representation
typeRepr t00 = expandSyns t00 >>= \ t0 -> case decompose t0 of
    (ListT, [t]) -> do
        (tRep, toR, fromR) <- typeRepr t
        xs <- newName "elems"
        x <- newName "el"
        xsRep <- newName "elemReps"
        xRep <- newName "elemRep"
        -- Map the element conversion over the list with a comprehension whose body is a
        -- case over the element's own clauses.
        return (ListT `AppT` tRep,
            [ToRepCase [VarP xs]
                (CompE [BindS (VarP x) (VarE xs),
                    NoBindS (CaseE (VarE x) [Match pat (NormalB e) [] | ToRepCase [pat] e <- toR])])],
            [FromRepCase (VarP xsRep)
                [CompE [BindS (VarP xRep) (VarE xsRep),
                    NoBindS (CaseE (VarE xRep) [Match pat (NormalB e) [] | FromRepCase pat [e] <- fromR])]]])
    (TupleT 0, _) -> return unitValue
    (TupleT _, ts) -> do
        reps <- mapM typeRepr ts
        let (tRep, toR, fromR) = foldr1 prod reps
        return (tRep,
            [ToRepCase [TupP pats] e | ToRepCase pats e <- toR],
            [FromRepCase pat [TupE es] | FromRepCase pat es <- fromR])
    (ConT con, ts)
        | con == ''() -> return unitValue
        | con == ''Either, [tL, tR] <- ts -> do
            (tRepL, lToR, lFromR) <- typeRepr tL
            (tRepR, rToR, rFromR) <- typeRepr tR
            return (ConT ''Either `AppT` tRepL `AppT` tRepR,
                [ToRepCase [ConP 'Left pats] (ConE 'Left `AppE` e) | ToRepCase pats e <- lToR] ++
                    [ToRepCase [ConP 'Right pats] (ConE 'Right `AppE` e) | ToRepCase pats e <- rToR],
                [FromRepCase (ConP 'Left [pat]) [ConE 'Left `AppE` e] | FromRepCase pat [e] <- lFromR] ++
                    [FromRepCase (ConP 'Right [pat]) [ConE 'Right `AppE` e] | FromRepCase pat [e] <- rFromR])
        | otherwise -> do
            ClassI _ instances <- reify ''Repr
            let knowns = [tycon | ClassInstance{ci_tys = [ConT tycon]} <- instances]
            -- TODO: recognize preexisting higher-arity instances
            if con `elem` knowns && null ts then do
                arg <- newName "arg"
                argRep <- newName "argRep"
                return (ConT ''Rep `AppT` ConT con,
                    [ToRepCase [VarP arg] (VarE 'toRep `AppE` VarE arg)],
                    [FromRepCase (VarP argRep) [VarE 'fromRep `AppE` VarE argRep]])
            else recursiveRepr t0
    _ -> recursiveRepr t0
  where
    -- A field of type @()@: one wildcard pattern in, @()@ back out.
    unitValue :: Representation
    unitValue = (TupleT 0, [ToRepCase [WildP] (TupE [])], [FromRepCase WildP [TupE []]])

-- | Fallback representation: wraps the value in the 'Key' constructor (representation
-- type: 'Key' applied to the type) and unwraps it again in 'fromRep'.
-- 'typeRepr' calls this with the type after 'expandSyns' has been applied.
recursiveRepr :: Type -> ReprM Representation
recursiveRepr t0 = fmap keyed (newName "arg")
  where
    keyed v =
      ( ConT ''Key `AppT` t0
      , [ToRepCase [VarP v] (ConE 'Key `AppE` VarE v)]
      , [FromRepCase (ConP 'Key [VarP v]) [VarE v]]
      )

-- | Representation of a constructor with no fields: matches no argument patterns,
-- produces @()@, and contributes no expressions when converting back.
unit :: Representation
unit = (TupleT 0, [noFields], [nothingBack])
  where
    noFields = ToRepCase [] (TupE [])
    nothingBack = FromRepCase WildP []

-- | Pairs two representations.  Every combination of a left and a right 'toRep' clause
-- becomes one clause whose patterns are concatenated and whose output is a pair; every
-- combination of 'fromRep' clauses matches a pair pattern and concatenates the outputs.
prod :: Representation -> Representation -> Representation
prod (tL, toL, fromL) (tR, toR, fromR) = (pairType, toCases, fromCases)
  where
    pairType = TupleT 2 `AppT` tL `AppT` tR
    toCases =
      [ ToRepCase (psL ++ psR) (TupE [eL, eR])
      | ToRepCase psL eL <- toL
      , ToRepCase psR eR <- toR
      ]
    fromCases =
      [ FromRepCase (TupP [pL, pR]) (esL ++ esR)
      | FromRepCase pL esL <- fromL
      , FromRepCase pR esR <- fromR
      ]

-- | Attaches a data constructor to a representation: 'toRep' patterns are wrapped in a
-- constructor pattern, and 'fromRep' outputs become the constructor's arguments.
conify :: Name -> Representation -> Representation
conify conName (t, toR, fromR) = (t, map wrapTo toR, map wrapFrom fromR)
  where
    wrapTo (ToRepCase args e) = ToRepCase [ConP conName args] e
    wrapFrom (FromRepCase p outs) = FromRepCase p [foldl AppE (ConE conName) outs]

-- | Sum of two representations: the left one's outputs are tagged with 'Left' and the
-- right one's with 'Right'; converting back matches on the corresponding tag.
union :: Representation -> Representation -> Representation
union (tL, toL, fromL) (tR, toR, fromR) =
  ( ConT ''Either `AppT` tL `AppT` tR
  , inject 'Left toL ++ inject 'Right toR
  , project 'Left fromL ++ project 'Right fromR
  )
  where
    inject side cases = [ToRepCase pats (ConE side `AppE` e) | ToRepCase pats e <- cases]
    project side cases = [FromRepCase (ConP side [pat]) es | FromRepCase pat es <- cases]