{-# LANGUAGE TemplateHaskell #-}

{- |
  Module      :  OrPatterns.Internal
  Copyright   :  (c) Adam Vogt 2011 - 2014
  License     :  BSD3
  Maintainer  :  Adam Vogt <vogt.adam@gmail.com>
  Stability   :  experimental
  Portability :  GHC>=7 -XTemplateHaskell, -XViewPatterns, syb

-}

module OrPatterns.Internal (
    pats,
    tryParseSplits,
    combineSplits,
  ) where

import Control.Monad.Error
import Data.Generics
import Data.List.Split
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Language.Haskell.Meta.Syntax.Translate (ToPat(toPat))
import qualified Language.Haskell.Exts as E
import qualified Data.Map as M
import Data.List

-- | Parser used in 'OrPatterns.o'.
--
-- The input is cut on @" | "@ and each piece is parsed as a Haskell pattern.
-- A piece that does not parse on its own is re-joined with the following
-- piece (see 'tryParseSplits'). The parsed alternatives are then merged into
-- a single pattern by 'combineSplits'.
pats :: String -> Either String PatQ
pats str = do
    alternatives <- tryParseSplits separator parsePatAllExts pieces
    combineSplits alternatives
  where
    separator = " | "
    pieces = splitOn separator str

-- | Parse a list of pieces, gluing adjacent pieces together whenever a piece
-- does not parse on its own.
--
-- Pieces are consumed left to right. When a piece (other than the last)
-- fails to parse, its error is discarded. The piece is then concatenated with
-- @filler@ and the next piece, and the combined piece is retried. The last
-- remaining piece must parse; if it does not, its error is returned.
-- Successful results come back in input order, and an empty input yields
-- @Right []@.
tryParseSplits :: [s] -> ([s] -> Either e r) -> [[s]] -> Either e [r]
tryParseSplits filler parsePat = loop []
  where
    -- `done` holds the parses accepted so far, most recent first.
    loop done (cur : next : rest) =
        either (\_ -> loop done ((cur ++ filler ++ next) : rest))
               (\r -> loop (r : done) (next : rest))
               (parsePat cur)
    loop done [final] = fmap (\r -> reverse (r : done)) (parsePat final)
    loop done [] = Right (reverse done)

-- | Merge the alternatives of an or-pattern into a single view pattern.
--
-- For alternatives @p1 .. pk@ that all bind the variables @v1 .. vn@, the
-- result is the pattern
--
-- > (\x -> case x of { p1 -> Just (v1, .., vn); ..; _ -> Nothing })
-- >     -> Just (v1, .., vn)
--
-- A value that matches any alternative therefore binds @v1 .. vn@ in the
-- enclosing scope. The variables are taken in 'Name' order.
--
-- Returns 'Left' if the list of alternatives is empty, or if the alternatives
-- do not all bind the same set of variables.
combineSplits :: [Pat] -> Either String PatQ
combineSplits opts = case map patBinders opts of
    [] -> Left "No alternatives to combine"
    counts@(firstCounts : _) -> do
        -- Haskell forbids repeating a variable within one pattern. Each name
        -- must therefore be bound exactly once in every alternative, so its
        -- total count equals the number of alternatives.
        let nAlts = length counts
        -- Use `Left` directly: `fail` in `Either` is `error` (base >= 4.3)
        -- or a missing MonadFail instance (GHC >= 8.8), not a `Left`.
        unless (all (== nAlts) (M.elems (M.unionsWith (+) counts))) $
            Left "Equations do not bind equal variables"

        let vars = M.keys firstCounts
            dest = [| Just $(tupE (map varE vars)) |]
            destP = conP 'Just [tupP (map varP vars)]
            alternatives =
                [ match (return p) (normalB dest) [] | p <- opts ]
                ++ [ match wildP (normalB [| Nothing |]) [] ]

        return $ viewP [| \x -> $(caseE [| x |] alternatives) |] destP
  where
    -- Count how often each variable is bound by a pattern.
    -- Both plain variables ('VarP') and as-pattern names ('AsP') bind.
    -- The expression half of a view pattern ('ViewP') is skipped, because
    -- names introduced there (lambdas, lets) are local to that expression
    -- and are not bound by the pattern.
    patBinders :: Data d => d -> M.Map Name Int
    patBinders node = case cast node of
        Just (VarP n) -> M.singleton n 1
        Just (AsP n p) -> M.insertWith (+) n 1 (patBinders p)
        Just (ViewP _ p) -> patBinders p
        _ -> M.unionsWith (+) (gmapQ patBinders node)


-- | Parse one Haskell pattern with 'allExtensionsMode' and translate it into a
-- Template Haskell 'Pat'. A parse failure is reported as 'Left' (see
-- 'toEither').
parsePatAllExts :: String -> Either String Pat
parsePatAllExts = toEither . E.parsePatWithMode allExtensionsMode

-- | Convert a haskell-src-exts parse result into a Template Haskell pattern.
-- A successful parse is translated with 'toPat'. Any other result is
-- rendered with 'show' and returned as 'Left'.
toEither :: (ToPat a, Show a) => E.ParseResult a -> Either String Pat
toEither result
    | E.ParseOk parsed <- result = Right (toPat parsed)
    | otherwise = Left (show result)

-- | Parse mode used for pattern alternatives. It is 'E.defaultParseMode'
-- with fixities left unset and the extensions listed below enabled.
allExtensionsMode :: E.ParseMode
allExtensionsMode = E.defaultParseMode
    { E.fixities = Nothing -- these get filled in later
    , E.extensions = map E.EnableExtension enabledExtensions
    }
  where
    -- Despite the mode's name, only these extensions are switched on.
    enabledExtensions =
        [ E.ImplicitParams
        , E.BangPatterns
        , E.NamedFieldPuns
        , E.PatternGuards
        , E.TypeFamilies
        , E.UnicodeSyntax
        , E.TypeOperators
        , E.RecordWildCards
        , E.LambdaCase
        , E.ViewPatterns
        , E.TupleSections
        , E.NPlusKPatterns
        , E.DataKinds
        , E.PolyKinds
        , E.MultiWayIf
        ]