{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}

module Data.BidiSpec
    (Spec, SpecGen, SpecParser(..)
    ,mkSpec, parseBySpec, genBySpec, runSpecParser, rsGen, rsParse
    ,spGet, spGets, spCheck, spFromMaybe, spFromEither
    ,rsPair, rsTriple, rsQuadruple, rsMaybe
    ,rsWrap, rsWrapMaybe, rsWrapEither, rsWrapEither',rsCondSeq
    ,rsChoice, rsAlt, rsTagSwitch, rsSwitch, rsCase, rsCaseConst
    ,rsGetSet,rsCheckSet, rsLift, rsUnit, rsZero, rsWith, rsDefault
    ,rsMaybeDefault
    )
where

----------------------------------------
-- STDLIB
----------------------------------------
import Control.Applicative (Alternative, Applicative)
import Control.Monad (MonadPlus(..), guard, liftM)
import Control.Monad.Error (MonadError(..), Error(..))
import Control.Monad.Reader (MonadReader, ReaderT, runReaderT, ask, local, asks)
import Control.Monad.Trans (MonadTrans, lift)

import Data.List (find)

-- ----------------------------------------------------------------------------
--  SpecParser
-- ----------------------------------------------------------------------------
-- | Parser half of a 'Spec': reads from an environment (source) of type @s@
-- and either fails with an error of type @e@ or yields a value of type @a@.
--
-- All instances are inherited from @'ReaderT' s ('Either' e)@ via
-- GeneralizedNewtypeDeriving, so 'mplus' / '<|>' run both sides against the
-- same source and keep the first successful result.
--
-- 'Functor', 'Applicative' and 'Alternative' are derived alongside 'Monad'
-- and 'MonadPlus' because they are superclasses of those classes in
-- base >= 4.8; deriving only 'Monad'/'MonadPlus' does not compile there.
newtype SpecParser s e a
    = SpecParser { unSpecParser :: ReaderT s (Either e) a }
    deriving (Functor, Applicative, Alternative, Monad, MonadPlus)

-- | Errors are raised in the underlying 'Either'.  'catchError' runs the
-- wrapped parser against the current source and passes any error to the
-- handler; on success the result is returned unchanged.
instance Error e => MonadError e (SpecParser s e) where
    throwError = SpecParser . lift . throwError
    catchError (SpecParser action) handler =
        spGet >>= either handler return . runReaderT action

-- | The parser's source is the reader environment: 'ask' returns it and
-- 'local' runs a parser against a modified source.
instance Error e => MonadReader s (SpecParser s e) where
    ask = SpecParser ask
    local f = SpecParser . local f . unSpecParser

-- | Run a parser against a source, returning either the error or the result.
runSpecParser :: SpecParser s e a -> s -> Either e a
runSpecParser parser src = runReaderT (unSpecParser parser) src

-- | Return the whole source the parser is running against.
spGet :: Error e => SpecParser s e s
spGet = asks id

-- | Return a projection of the source.
spGets :: Error e => (s -> a) -> SpecParser s e a
spGets project = liftM project spGet

-- | Succeed with @()@ when @check a@ holds; otherwise fail with @mkErr a@.
spCheck :: Error e => (a -> Bool) -> (a -> e) -> a -> SpecParser s e ()
spCheck check mkErr a
    | check a = return ()
    | otherwise = throwError (mkErr a)

-- | Lift a 'Maybe' into the parser: 'Just' values are returned and 'Nothing'
-- fails with the error supplied by the caller.  The caller provides the
-- error value itself so that each use can carry its own message.
spFromMaybe :: Error e => e -> Maybe a -> SpecParser s e a
spFromMaybe e = maybe (throwError e) return

-- | Lift an 'Either' into the parser: 'Left' becomes a parse error and
-- 'Right' a result.
spFromEither :: Error e => Either e a -> SpecParser s e a
spFromEither = either throwError return

-- ----------------------------------------------------------------------------
--  Spec generator
-- ----------------------------------------------------------------------------

-- | Generator half of a 'Spec': takes the target built so far and a value,
-- and returns an updated target.
type SpecGen tgt a = tgt -> a -> tgt

-- ----------------------------------------------------------------------------
--  Spec data type
-- ----------------------------------------------------------------------------

-- | A bidirectional specification. A single description both parses a value
-- of type @a@ out of a source @src@ (failing with @err@) and generates that
-- value into a target @tgt@.
data Spec err src tgt a = Spec
    { rsGen   :: SpecGen tgt a          -- ^ writes an @a@ into the target
    , rsParse :: SpecParser src err a   -- ^ reads an @a@ from the source
    }

-- | Build a 'Spec' from a parser and a generator (note: parser first).
mkSpec :: SpecParser i e a -> SpecGen o a -> Spec e i o a
mkSpec parser gen = Spec { rsGen = gen, rsParse = parser }

-- | Parse a source with a spec, re-raising any parse error in the caller's
-- 'MonadError' monad.
parseBySpec :: MonadError e m => Spec e i o a -> i -> m a
parseBySpec sp i = either throwError return (runSpecParser (rsParse sp) i)

-- | Generate a value into a target with a spec.  Generation itself is pure;
-- the result is only wrapped with 'return'.
genBySpec :: Monad m => Spec e i o a -> o -> a -> m o
genBySpec sp o a = return (rsGen sp o a)


-- ----------------------------------------------------------------------------
--  Spec combinators
-- ----------------------------------------------------------------------------

-- | Sequence two specs into a spec for a pair.
--
-- Parsing runs @rsA@ and then @rsB@ against the same source.  Generation
-- applies @rsB@'s generator to the incoming target first and @rsA@'s second.
--
-- NOTE(review): 'rsTriple' and 'rsQuadruple' apply their generators in the
-- opposite (first-component-first) order; confirm which order the target
-- type expects.
rsPair :: Error e => Spec e i o a -> Spec e i o b -> Spec e i o (a,b)
rsPair rsA rsB = mkSpec parsePair genPair
    where
      parsePair = rsParse rsA >>= \a -> rsParse rsB >>= \b -> return (a, b)
      genPair rout (a, b) = rsGen rsA (rsGen rsB rout b) a

-- | Sequence three specs into a spec for a triple.  Parsing runs the
-- components in order; generation applies the generators for the first,
-- second and third component to the target in that order.
rsTriple :: Error e =>
            Spec e i o a
         -> Spec e i o b
         -> Spec e i o c
         -> Spec e i o (a,b,c)
rsTriple rsA rsB rsC = mkSpec parseTriple genTriple
    where
      parseTriple =
          rsParse rsA >>= \a ->
          rsParse rsB >>= \b ->
          rsParse rsC >>= \c ->
          return (a, b, c)
      genTriple rout (a, b, c) =
          let afterA = rsGen rsA rout a
              afterB = rsGen rsB afterA b
          in rsGen rsC afterB c

-- | Sequence four specs into a spec for a 4-tuple.  Parsing runs the
-- components in order; generation applies the component generators to the
-- target from first to fourth.
rsQuadruple :: Error e =>
               Spec e i o a
            -> Spec e i o b
            -> Spec e i o c
            -> Spec e i o d
            -> Spec e i o (a,b,c,d)
rsQuadruple rsA rsB rsC rsD = mkSpec parseQuad genQuad
    where
      parseQuad =
          rsParse rsA >>= \a ->
          rsParse rsB >>= \b ->
          rsParse rsC >>= \c ->
          rsParse rsD >>= \d ->
          return (a, b, c, d)
      genQuad rout (a, b, c, d) =
          let afterA = rsGen rsA rout a
              afterB = rsGen rsB afterA b
              afterC = rsGen rsC afterB c
          in rsGen rsD afterC d

-- | Wrap a spec with a total conversion pair.  The conversion never yields
-- 'Nothing', so the message handed to 'rsWrapMaybe' is never demanded.
rsWrap :: Error e => (a -> b, b -> a) -> Spec e i o a -> Spec e i o b
rsWrap (toB, toA) = rsWrapMaybe (error "BidiSpec: rsWrap") (Just . toB, toA)

-- | Make a spec optional.  Parsing yields 'Nothing' if @rsA@ fails.
-- Generating 'Nothing' leaves the target untouched.
rsMaybe :: Error e => Spec e i o a -> Spec e i o (Maybe a)
rsMaybe rsA = mkSpec parseOpt genOpt
    where
      parseOpt = (rsParse rsA >>= return . Just) `mplus` return Nothing
      genOpt rout = maybe rout (rsGen rsA rout)

-- | Like @rsA@, but parsing falls back to @defaultA@ when @rsA@ fails.
-- Generation is exactly @rsA@'s, so the value is always written.
rsMaybeDefault :: Error e => a -> Spec e i o a -> Spec e i o a
rsMaybeDefault defaultA rsA =
    mkSpec (rsParse rsA `mplus` return defaultA) (rsGen rsA)

-- | Wrap a spec with a partial conversion.
--
-- Parsing runs @rsA@ and converts its result with @aToB@.  A 'Nothing'
-- becomes a parse error built with 'strMsg' from @"rsWrapMaybe: "@ followed
-- by @msg@.  Generation converts back with @bToA@ and delegates to @rsA@.
rsWrapMaybe :: Error e =>
               String                        -- ^ error message for the 'Nothing' case
            -> (a -> Maybe b, b -> a)        -- ^ (parse conversion, generate conversion)
            -> Spec e i o a
            -> Spec e i o b
rsWrapMaybe msg (aToB, bToA) rsA = mkSpec rsParseDef rsGenDef
    where rsGenDef rout b = rsGen rsA rout (bToA b)
          rsParseDef = rsParse rsA >>= parseA . aToB
          -- 'throwError' rather than 'fail': 'fail' on ReaderT s (Either e)
          -- may call 'error' (depending on the base/mtl version) instead of
          -- producing a Left that 'mplus'/'catchError' can recover from.
          parseA Nothing = throwError (strMsg ("rsWrapMaybe: " ++ msg))
          parseA (Just b) = return b

-- | Wrap a spec with a conversion that may fail with an error of the
-- spec's own error type.  'Left' results become parse errors unchanged.
rsWrapEither :: Error e =>
                (a -> Either e b, b -> a)     -- wrappers
             -> Spec e i o a
             -> Spec e i o b
rsWrapEither (aToB, bToA) rsA = mkSpec parseB genB
    where
      parseB = rsParse rsA >>= either throwError return . aToB
      genB rout = rsGen rsA rout . bToA

-- | Like 'rsWrapEither', but the conversion may fail with any 'Show'able
-- value.  That value is rendered with 'show' and wrapped with 'strMsg'.
rsWrapEither' :: (Show l, Error e) =>
                 (a -> Either l b, b -> a)     -- wrappers
              -> Spec e i o a
              -> Spec e i o b
rsWrapEither' (toB, toA) =
    rsWrapEither (either (Left . strMsg . show) Right . toB, toA)

-- | Conditional sequencing.
--
-- Parsing runs @pa@ and then the spec @k a@ selected by its result.  If
-- that whole sequence fails, it falls back to @pd@.  Generation computes
-- @a = f b@ and applies @k a@'s generator and then @pa@'s; @pd@ is not used
-- for generation.
rsCondSeq :: Error e =>
             Spec e i o b
          -> (b -> a)
          -> Spec e i o a
          -> (a -> Spec e i o b)
          -> Spec e i o b
rsCondSeq pd f pa k = mkSpec condParser condGen
    where
      condParser = (rsParse pa >>= rsParse . k) `mplus` rsParse pd
      condGen rout b = rsGen pa (rsGen (k a) rout b) a
          where a = f b

-- | 'rsCondSeq' without a way back from @b@ to @a@.  The projection is
-- 'error', so this spec is presumably meant for parsing only ('rsAlt' uses
-- only its parser).
rsChoice :: Error e =>
            Spec e i o b
         -> Spec e i o a
         -> (a -> Spec e i o b)
         -> Spec e i o b
rsChoice pb pa k = rsCondSeq pb (error "rsChoice: undefined") pa k

-- | Choose among a list of alternative specs.
--
-- Parsing tries the alternatives in list order and returns the first
-- successful result.  If every alternative fails, the parse fails via
-- 'rsZero'.
--
-- Generation uses the alternative at the 0-based index @getIdx a@.  An
-- out-of-range index is a programming error.  It is reported with a message
-- naming 'rsAlt', the index and the number of alternatives, rather than the
-- generic error from '!!'.
rsAlt :: Error e => (a -> Int) -> [Spec e i o a] -> Spec e i o a
rsAlt getIdx alts = mkSpec rsParseDef rsGenDef
    where
      rsGenDef rout a =
          case drop idx alts of
            spec : _ | idx >= 0 -> rsGen spec rout a
            _ -> error ("rsAlt: no alternative at index " ++ show idx
                        ++ " (" ++ show (length alts) ++ " alternatives)")
        where idx = getIdx a
      -- Right-nested 'mplus' chain: same order and result as chaining
      -- 'rsChoice' with 'rsLift' over the list, ending in 'rsZero'.
      rsParseDef = foldr (mplus . rsParse) (rsParse rsZero) alts

-- | One branch of a tag-based switch ('rsTagSwitch'): a representative value
-- and the spec used when that branch is selected.
data SpecCase e i o a = SpecCase
    { case_value :: a              -- ^ representative value; only its tag is inspected
    , case_spec  :: Spec e i o a   -- ^ spec for this branch
    }

-- | Build a switch branch from a conversion pair and a spec.
--
-- The representative value applies the @a -> b@ conversion to an undefined
-- argument.  The tag function used with it must therefore not force that
-- argument, e.g. it may look only at the outer constructor.
rsCase :: Error e => (a -> b, b -> a) -> Spec e i o a -> SpecCase e i o b
rsCase isoFns@(aToB, _) specA =
    SpecCase { case_value = aToB (error "rsCase: tagging function requires evaluation")
             , case_spec = rsWrap isoFns specA }

-- | A switch branch for a constant value.  Parsing yields @a@ when the
-- branch spec built by @mkRs@ succeeds.  The placeholder passed to @mkRs@
-- and produced by the back-conversion is expected never to be inspected.
rsCaseConst :: Error e =>
               a                               -- constant to match/generate
            -> (Spec e i o b -> Spec e i o b)  -- continuation
            -> SpecCase e i o a
rsCaseConst a mkRs =
    rsCase (const a, const placeholder) $ mkRs $ rsLift placeholder
    where placeholder = error "rsCaseConst: this value should have been ignored"

-- | 'rsTagSwitch' whose tag is the leading part of 'show': opening
-- parentheses are skipped, and the text up to the first space or comma is
-- kept.  For derived 'Show' instances this is presumably the constructor
-- name.  Because the string is consumed lazily, constructor arguments
-- (e.g. the undefined one from 'rsCase') are not forced.
rsSwitch :: (Error e, Show a) => [SpecCase e i o a] -> Spec e i o a
rsSwitch = rsTagSwitch constructorTag
    where
      constructorTag v = takeWhile (`notElem` " ,") (dropWhile (== '(') (show v))

-- | Choose between several cases by comparing tags.
--
-- Parsing tries each case's parser in list order and returns the first
-- success.  If every case fails (or there are no cases), the parse fails
-- with @strMsg "rsSwitch: No case matched."@.
--
-- Generation computes @tag a@ for the value being generated and uses the
-- spec of the first case whose representative value has the same tag.  If
-- no case matches, it calls 'error'.
rsTagSwitch :: (Error e, Eq t) => (a -> t) -> [SpecCase e i o a] -> Spec e i o a
rsTagSwitch tag cases = mkSpec rsParseDef rsGenDef
    where
      -- Right fold so the "no match" failure is the last resort, reached
      -- only after every case has failed (a left fold put it first, where it
      -- was either never reported or, via 'fail', crashed every parse).
      -- 'throwError' is used because 'fail' on ReaderT s (Either e) may
      -- call 'error' instead of returning a recoverable Left.
      rsParseDef =
          foldr (mplus . rsParse . case_spec) (throwError (strMsg noMatch)) cases
      rsGenDef rout a =
          case find ((tag a ==) . tag . case_value) cases of
            Just matched -> rsGen (case_spec matched) rout a
            Nothing -> error noMatch
      noMatch = "rsSwitch: No case matched."

-- | A spec given by a getter on the source and a setter on the target.
-- Parsing never fails.
rsGetSet :: Error e => (i -> a) -> (o -> a -> o) -> Spec e i o a
rsGetSet get set = Spec { rsGen = set, rsParse = liftM get spGet }

-- | Guard a spec with a check on the source and a transformation on the
-- target.  Parsing runs @parser@ before @rs@.  Generation applies @setfun@ to
-- the target before @rs@'s generator.
rsCheckSet :: Error e =>
              SpecParser i e ()
           -> (o -> o)
           -> Spec e i o a
           -> Spec e i o a
rsCheckSet parser setfun rs = mkSpec (parser >> rsParse rs) (rsGen rs . setfun)

-- | Recover from parse errors.  If the spec's parser fails, the error is
-- turned into a result with @onError@.  Generation is unchanged.
rsDefault :: Error e => (e -> a) -> Spec e u i a -> Spec e u i a
rsDefault onError (Spec generator parser) =
    Spec generator (parser `catchError` (return . onError))

-- | A spec whose parser always yields @x@ without reading the source, and
-- whose generator returns the target unchanged.
rsLift :: Error e => a -> Spec e i o a
rsLift x = Spec { rsGen = \rout _ -> rout, rsParse = return x }

-- | A spec whose parser always fails with @strMsg "rsZero"@ and whose
-- generator returns the target unchanged.
--
-- The failure uses 'throwError' so it stays a recoverable parse error.
-- 'fail' on ReaderT s (Either e) may call 'error' instead, depending on the
-- base/mtl version, and has no instance once MonadFail is separate.
rsZero :: Error e => Spec e i o a
rsZero = mkSpec (throwError (strMsg "rsZero")) const

-- | The trivial spec: parses @()@ without reading the source and leaves the
-- target unchanged.
rsUnit :: Error e => Spec e i o ()
rsUnit = mkSpec (return ()) const

-- | Apply a spec transformer to a spec: plain function application,
-- specialised to specs.
rsWith :: (Spec e i o a -> Spec e i o b) -> Spec e i o a -> Spec e i o b
rsWith transform spec = transform spec