{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS -fno-warn-overlapping-patterns #-}
module Helper where
    

import Language.Haskell.TH
import Control.Monad
import Data.Generics
    
-- | Split a (possibly nested) @forall@ type into all of its contexts and the
-- innermost body type.
--
-- Contexts are concatenated outermost-first.  The quantified type-variable
-- binders are discarded.  A type that is not a 'ForallT' yields an empty
-- context paired with the type itself.
decomposeForallT :: Type -> ([Type],Type)
decomposeForallT = collect []
  where
    collect seen (ForallT _ preds inner) = collect (seen ++ preds) inner
    collect seen other                   = (seen, other)


-- | Generate an instance of a single-method class from a proof term.
--
-- For a class @cls@ whose declaration body is exactly one method signature
-- @m@, and a variable @prf@, this produces
--
-- > instance ctx => cls body where
-- >     m = prf
--
-- where @ctx@ and @body@ come from splitting the reified type of @prf@ with
-- 'decomposeForallT', after its names are passed through 'unmangleNames'.
--
-- Fails in 'Q' (aborting the splice with a message) if @prf@ does not reify
-- to a variable ('VarI'), or if @cls@ does not reify to a class ('ClassI')
-- whose body is a single method signature.
lemma :: Name -> Name -> Q [Dec]
lemma cls prf = do
        prfInfo <- reify prf
        prfType <- case prfInfo of
                     VarI _ t _ _ -> return (unmangleNames t)
                     _            -> fail ("expected VarI, got " ++ show prfInfo)

        clsInfo <- reify cls
        methodId <- case clsInfo of
                      ClassI (ClassD _ _ _ _ [SigD m _]) -> return m
                      _ -> fail ("expected ClassI of a single-method class, got "
                                 ++ show clsInfo)

        -- Hoist every forall context of the proof's type into the instance
        -- context; the remaining body type becomes the instance head argument.
        let (instCxt, bodyT) = decomposeForallT prfType

        inst <- instanceD (return instCxt)
                          (conT cls `appT` return bodyT)
                          [valD (varP methodId) (normalB (varE prf)) []]

        return [inst]


-- | Rebuild a 'Name' from its base string, truncated just before the first
-- @'['@.
--
-- Any module qualifier is dropped (via 'nameBase') and the result is a plain
-- 'mkName' name.  NOTE(review): the @'['@ cut presumably strips a suffix that
-- reified names carry in this GHC version — confirm.
unmangleName :: Name -> Name
unmangleName n = mkName (takeWhile (/= '[') (nameBase n))

-- | Apply 'unmangleName' to every 'Name' found anywhere inside a value,
-- using syb's 'everywhere' traversal; all other nodes are left unchanged.
unmangleNames :: (Data a) => a -> a
unmangleNames term = everywhere renameNode term
  where
    -- Generic per-node step: rewrites 'Name' nodes, identity elsewhere.
    renameNode :: Data b => b -> b
    renameNode = mkT unmangleName

-- | Run 'lemma' for each proof name against the same class and concatenate
-- the generated declarations, in the order the proofs are given.
lemmata :: Name -> [Name] -> Q [Dec]
lemmata cls xs = do
    perProof <- forM xs (lemma cls)
    return (concat perProof)