{-# LANGUAGE TemplateHaskell #-}
{- | Description: template haskell to help emulate partial type signatures

Example usage (GHC-7.8):

> sigs
>     [| ["f1" :: a -> b -> (a, Int),
>         "f1" :: b -> a -> (Char, a) ] |]
> 
> f1 x y | False = $(unionSigs [| f1 x y |])
> f1 x y = undefined -- (x,y)


A GHC-7.6 compatible version must be slightly longer,
to work around the extra typechecking that GHC-7.6
performs on [| |] brackets:

> sigs [| do
>   f2 <- Nothing
>   Just [ f2 :: a -> b -> (a, Int),
>          f2 :: b -> a -> (Char, a) ]
>   |]
>   
> f2 x y | False = $(unionSigs [| f2  x y |])
> f2 x y = undefined -- (x,y)

If the expression splice generated by 'unionSigs' is left out,

> sigs
>     [| ["g" :: a -> b -> (a, Int),
>         "g" :: b -> a -> (Char, a) ] |]
> 
> g x y = undefined -- (x,y)

then @g@ gets the most general type @(g :: t)@, and the two
functions generated by 'sigs' can be used to restrict the type of @g@:

> partialTypeSig_g1 :: (t -> t1 -> (t, Int)) -> t -> t1 -> (t, Int)
> partialTypeSig_g1 = id
> 
> partialTypeSig_g2 :: (t -> t1 -> (Char, t1)) -> t -> t1 -> (Char, t1)
> partialTypeSig_g2 = id

-}
module PartialTypeSigs
  ( sigs,
    unionSigs,
   ) where


import qualified Data.Map as M
import Data.IORef
import System.IO.Unsafe
import Data.Maybe
import Language.Haskell.TH
import Data.Generics
import Language.Haskell.TH.Syntax
import Data.Monoid
import Control.Monad
import Data.Either

{-# NOINLINE m #-}
-- | Counters shared between splices, keyed by function name.
--
-- 'sigs' clears this map, 'unifiesWith1' bumps the counter for a name each
-- time it generates a @partialTypeSig_*@ helper, and 'unionSigs' reads the
-- final count back to know how many helpers to reference.
--
-- This is a global 'IORef' created with 'unsafePerformIO' so that separate
-- splices can see the same state. The NOINLINE pragma above is essential:
-- without it GHC could inline the definition and create several distinct
-- 'IORef's.
m :: IORef (M.Map String Int)
m = unsafePerformIO $ newIORef M.empty

{- | Scan the quoted expression for type signatures. Every subexpression
of the form

> "functionName" :: t
> functionName :: t

generates a helper function, numbered per function name starting at 1:

> partialTypeSig_functionName1 x = (x `asTypeOf` functionName) `asTypeOf` (undefined :: t)

Note that the above function is not the same as

> badId x = x `asTypeOf` (functionName :: t)

which requires that @t@ be more specific than @functionName@.

Each call first clears the counters used by 'unionSigs', so 'unionSigs'
only sees the signatures recorded by the most recent 'sigs' splice.
Signatures whose left-hand side is neither a variable nor a string
literal are skipped, and a warning lists them.
-}
sigs :: ExpQ -> DecsQ
sigs esQ = do
    -- start numbering from 1 again for every 'sigs' splice
    runIO $ writeIORef m M.empty
    es <- esQ

    let nts :: [(Exp, Type)]
        nts = everything (<>)
                (\x -> [ (e, t) | SigE e t <- maybeToList (cast x) ])
                es

        -- Left: a function name we can generate a helper for;
        -- Right: a left-hand side we do not understand (reported below)
        classify :: (Exp, Type) -> Either (String, Type) Exp
        classify (VarE n, t)           = Left (nameBase n, t)
        classify (LitE (StringL n), t) = Left (n, t)
        classify (e, _)                = Right e

        (ntsGood, ntsBad) = partitionEithers (map classify nts)

    -- pretty-print the offending expressions rather than dumping the raw AST
    unless (null ntsBad) $ reportWarning
        $ "Don't know how to interpret the left-hand-side of '::': "
          ++ pprint ntsBad

    fmap concat $ mapM (\(n, t) -> unifiesWith1 n (return t)) ntsGood
    
-- | Name prefix of the helpers generated by 'sigs' and referenced by
-- 'unionSigs': the @i@-th signature recorded for function @f@ becomes the
-- helper @partialTypeSig_f\<i\>@.
unifiesWithPrefix :: String
unifiesWithPrefix = "partialTypeSig_"

{- | Record one more signature for the function named @e@ and generate its
helper. The counter for @e@ lives in 'm' and is incremented on every call,
so the @k@-th call for a name since the last 'sigs' reset produces
@partialTypeSig_\<e\>\<k\>@. For example, the first

> unifiesWith1 "f" [t| forall a b c. a -> b -> c |]

generates

> partialTypeSig_f1 x = (x `asTypeOf` f) `asTypeOf` (undefined :: a -> b -> c)
-}
unifiesWith1 :: String -> TypeQ -> DecsQ
unifiesWith1 e t = do
    k <- runIO $ atomicModifyIORef m (bump e)
    unifiesWith2 (unifiesWithPrefix ++ e ++ show k) (dyn e) t
  where
    -- increment the counter for @key@ (starting at 1) and return its new value
    bump key counts =
        let counts' = M.insertWith (+) key 1 counts
        in (counts', M.findWithDefault 1 key counts')

-- | @unifiesWith2 s e t@ declares a single one-argument function named @s@
-- that returns its argument unchanged. Its type is forced to agree with
-- both the expression @e@ and the type @t@:
--
-- > s x = (x `asTypeOf` e) `asTypeOf` (undefined :: t)
unifiesWith2 :: String -> ExpQ -> TypeQ -> DecsQ
unifiesWith2 s e t = do
    arg <- newName "x"
    let body = [| ($(varE arg) `asTypeOf` $e) `asTypeOf` (undefined :: $t) |]
    dec <- funD (mkName s) [clause [varP arg] (normalB body) []]
    return [dec]

{- | Splice this into an unreachable clause, e.g. one guarded by @False@,
of a function whose signatures were recorded by 'sigs' directly above:

> f x y | False = $(unionSigs [| f x y |])

Suppose 'sigs' generated @n@ helpers @partialTypeSig_f1 .. partialTypeSig_fn@.
For each one, this builds @partialTypeSig_fi undefined x y@ and chains the
results together with 'asTypeOf', ending in a call to 'error'. The clause's
arguments and result are therefore unified with every recorded signature.

If no signatures were recorded for @f@, a warning is reported and the
splice expands to a call to 'error'. If the quoted expression is not a
variable applied to arguments, the splice fails with a descriptive error.
-}
unionSigs :: ExpQ -> ExpQ
unionSigs callQ = do
    call <- callQ
    case reverse (unappsErev call) of
      VarE fn : args -> do
        let k = nameBase fn
        -- how many partialTypeSig_<k><i> helpers the last 'sigs' generated
        counts <- runIO (readIORef m)
        maybe noCxt (toExp k args) (M.lookup k counts)
      -- anything else is a usage error; say what was expected
      _ -> fail $ "PartialTypeSigs.unionSigs: expected a function name applied to arguments, such as [| f x y |], but got: "
                  ++ pprint call
  where
    noCxt = do
      reportWarning "PartialTypeSigs.unionSigs: missing a call to PartialTypeSigs.sigs directly above"
      [| error "PartialTypeSigs.unionSigs no context given" |]

    -- @toExp k args n@: apply each helper partialTypeSig_<k><i> (i = 1..n)
    -- to 'undefined' and the call's arguments, then chain the results with
    -- 'asTypeOf' so they all share one result type
    toExp :: String -> [Exp] -> Int -> ExpQ
    toExp k args n =
      foldr (\x y -> [| $x `asTypeOf` $y |])
        [| error "PartialTypeSigs.unionSigs should be applied to an unreachable function clause" |]
        [ appsE ([| $(dyn (unifiesWithPrefix ++ k ++ show i)) undefined |] : map return args)
        | i <- [1 .. n] ]

{- | Split a function application into its parts in reverse order: last
argument first, function head last.

> unappsErev <$> [| x (f y) z |]

is equivalent to

> sequence [ [| z |], [| f y |] , [| x |] ]

A head that is the unit constructor @()@ contributes nothing to the result.
Any other non-application expression is returned as a singleton list.
-}
unappsErev :: Exp -> [Exp]
unappsErev expr = case expr of
    AppE fun arg -> arg : unappsErev fun
    ConE con | con == '() -> []
    _ -> [expr]