{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TupleSections              #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE RankNTypes #-}
module Waargonaut.Decode.Internal
  ( CursorHistory' (..)
  , ppCursorHistory
  , DecodeResultT (..)
  , Decoder' (..)
  , withCursor'
  , runDecoderResultT
  , try
  , recordZipperMove
    
  , null'
  , int'
  , text'
  , string'
  , unboundedChar'
  , boundedChar'
  , bool'
  , array'
  , integral'
  , scientific'
  , objTuples'
  , foldCursor'
  , prismDOrFail'
    
  , mapKeepingF
  , mapKeepingFirst
  , mapKeepingLast
    
  , module Waargonaut.Decode.Error
  , module Waargonaut.Decode.ZipperMove
  ) where
import           Control.Applicative           (liftA2, (<|>))
import           Control.Lens                  (Rewrapped, Wrapped (..), (%=),
                                                _1, _Wrapped)
import qualified Control.Lens                  as L
import           Control.Monad                 ((>=>))
import           Control.Monad.Except          (ExceptT (..), MonadError (..),
                                                liftEither, runExceptT)
import           Control.Monad.State           (MonadState (..), StateT (..))
import           Control.Monad.Trans.Class     (MonadTrans (lift))
import           Control.Monad.Error.Hoist                 ((<!?>))
import           Control.Monad.Morph           (MFunctor (..), MMonad (..))
import           Data.Bifunctor                (first)
import           Data.Functor                  (($>))
import           Data.List                     (foldl')
import           Data.Sequence                 (Seq)
import           Data.Text                     (Text)
import           Data.Map                      (Map)
import qualified Data.Map                      as Map
import qualified Data.Vector                   as V
import qualified Data.Witherable               as Wither
import           Data.Scientific               (Scientific)
import qualified Data.Scientific               as Sci
import           Waargonaut.Types              (AsJType (..), JString,
                                                jNumberToScientific,
                                                jsonAssocKey, jsonAssocVal,
                                                _JChar, _JString)
import           Waargonaut.Types.CommaSep     (toList)
import           Waargonaut.Types.JChar        (jCharToUtf8Char)
import           Text.PrettyPrint.Annotated.WL (Doc, (<+>))
import           Waargonaut.Decode.Error       (AsDecodeError (..),
                                                DecodeError (..))
import           Waargonaut.Decode.ZipperMove  (ZipperMove (..), ppZipperMove)
-- | The trail of cursor movements made while decoding. Each entry pairs the
-- 'ZipperMove' that was taken with an @i@ value supplied by whoever
-- recorded the move.
newtype CursorHistory' i = CursorHistory'
  { unCursorHistory' :: Seq (ZipperMove, i)
    -- ^ Recorded moves, each paired with its @i@ value.
  } deriving (Eq, Show)
-- | Pretty-print a cursor history as its moves, in stored order, joined
-- with '<+>'. The @i@ value kept alongside each move is not rendered.
ppCursorHistory
  :: CursorHistory' i
  -> Doc a
ppCursorHistory (CursorHistory' moves) =
  foldr (\entry rest -> ppZipperMove (fst entry) <+> rest) mempty moves
-- | Allows lens's '_Wrapped' iso to be used on 'CursorHistory''; the
-- rewrapped type is forced to be the very same @CursorHistory' i@.
instance (t ~ CursorHistory' i) => Rewrapped (CursorHistory' i) t
-- | Exposes the underlying 'Seq' of moves so the history can be viewed or
-- modified through '_Wrapped'.
instance Wrapped (CursorHistory' i) where
  type Unwrapped (CursorHistory' i) = Seq (ZipperMove, i)
  _Wrapped' = L.iso unCursorHistory' CursorHistory'
  {-# INLINE _Wrapped' #-}
-- | A decoding computation over the base monad @f@. Failures of type @e@
-- go through 'ExceptT'; the 'CursorHistory'' is threaded by 'StateT'
-- underneath it, so the history recorded up to a failure is still
-- available after the error is raised.
newtype DecodeResultT i e f a = DecodeResultT
  { runDecodeResult :: ExceptT e (StateT (CursorHistory' i) f) a
  } deriving ( Functor, Applicative, Monad
             , MonadError e
             , MonadState (CursorHistory' i)
             )
-- | Lift a base-monad action through both the 'StateT' and 'ExceptT'
-- layers. The lifted action raises no decode error and leaves the cursor
-- history untouched.
instance MonadTrans (DecodeResultT i e) where
  lift action = DecodeResultT (lift (lift action))
-- | Swap the base monad by applying @nat@ beneath both the 'ExceptT' and
-- 'StateT' layers.
instance MFunctor (DecodeResultT i e) where
  hoist nat = DecodeResultT . hoist (hoist nat) . runDecodeResult
-- | Flatten a 'DecodeResultT' whose base monad is itself interpreted by @t@.
--
-- The inner computation @dr@ is run to completion in its own base monad,
-- starting from an empty history. @t@ lifts that run into the outer
-- 'DecodeResultT'. The inner run's history is then /appended/ to the outer
-- history, and finally the inner result (or error) is re-raised.
--
-- Appending, rather than 'put'-ing the inner history over the outer one,
-- keeps the moves recorded before 'embed' and those recorded by @t@
-- itself. Overwriting the history would violate the 'MMonad' laws
-- @embed lift = id@ and @embed f (lift m) = f m@ whenever the outer
-- history is non-empty.
instance MMonad (DecodeResultT i e) where
  embed t dr = DecodeResultT $ do
    (e, hist) <- runDecodeResult (t (runner dr))
    _Wrapped %= (`mappend` unCursorHistory' hist)
    liftEither e
      where
        runner = flip runStateT (CursorHistory' mempty)
          . runExceptT . runDecodeResult
-- | A decoder: a function from a cursor of type @c@ to a 'DecodeResultT'
-- computation that produces an @a@.
newtype Decoder' c i e f a = Decoder'
  { runDecoder' :: c -> DecodeResultT i e f a
  } deriving (Functor)
-- | Decoders combined with '<*>' are both run against the same cursor.
--
-- 'pure' ignores the cursor and succeeds with the given value. It must be
-- built from the underlying 'DecodeResultT''s 'pure': writing
-- @pure = pure@ here resolves to this very instance and loops forever.
instance Monad f => Applicative (Decoder' c i e f) where
  pure a     = Decoder' (const (pure a))
  aToB <*> a = Decoder' $ \c -> runDecoder' aToB c <*> runDecoder' a c
-- | Bind runs the first decoder on the cursor, then runs the decoder chosen
-- by the continuation on that very same cursor. 'return' keeps its
-- default definition ('pure').
instance Monad f => Monad (Decoder' c i e f) where
  Decoder' runFirst >>= next = Decoder' $ \cursor -> do
    x <- runFirst cursor
    runDecoder' (next x) cursor
-- | Lift a base-monad action into a decoder that ignores its cursor.
instance MonadTrans (Decoder' c i e) where
  lift action = Decoder' (\_ -> lift action)
-- | Swap the base monad of a decoder by hoisting the 'DecodeResultT' it
-- produces for each cursor.
instance MFunctor (Decoder' c i e) where
  hoist nat decoder = Decoder' $ \cursor -> hoist nat (runDecoder' decoder cursor)
-- | Build a decoder from a function of the current cursor. This is simply
-- the 'Decoder'' constructor under a function name.
withCursor'
  :: (c -> DecodeResultT i e f a)
  -> Decoder' c i e f a
withCursor' decodeAt = Decoder' decodeAt
-- | Run a decoding computation starting from an empty cursor history.
--
-- On failure the error is paired with the history recorded up to that
-- point. On success the history is discarded and only the value is
-- returned.
runDecoderResultT
  :: Monad f
  => DecodeResultT i DecodeError f a
  -> f (Either (DecodeError, CursorHistory' i) a)
runDecoderResultT decoder =
  attachHistory <$> runStateT (runExceptT (runDecodeResult decoder)) emptyHistory
  where
    emptyHistory = CursorHistory' mempty
    attachHistory (result, hist) = first (\err -> (err, hist)) result
-- | Append a move, together with its @i@ value, to the end of the cursor
-- history held in state.
recordZipperMove :: MonadState (CursorHistory' i) m => ZipperMove -> i -> m ()
recordZipperMove dir i =
  L._Wrapped %= \moves -> L.snoc moves (dir, i)
-- | Run an action. A successful result is returned in 'Just'; a thrown
-- error is caught and discarded, and the action then yields 'Nothing'.
try :: MonadError e m => m a -> m (Maybe a)
try action = (Just <$> action) `catchError` \_ -> pure Nothing
-- | Run decoder @d@ at cursor @c@ and focus its result through prism @p@.
-- If the prism does not match, @e@ is raised in the 'DecodeResultT' error
-- channel.
--
-- Only @Monad f@ is required. The error is thrown by 'DecodeResultT''s own
-- 'ExceptT' layer, which has a 'MonadError' instance for any monadic base.
-- Demanding @MonadError e f@ of the base monad would needlessly rule out
-- pure base monads such as @Identity@. The @AsDecodeError e@ constraint is
-- not used by the body; it is kept so existing signatures stay compatible.
prismDOrFail'
  :: ( AsDecodeError e
     , Monad f
     )
  => e                  -- ^ error raised when the prism does not match
  -> L.Prism' a b       -- ^ prism applied to the decoded value
  -> Decoder' c i e f a -- ^ decoder for the value to inspect
  -> c                  -- ^ cursor to run the decoder at
  -> DecodeResultT i e f b
prismDOrFail' e p d c =
  runDecoder' (L.preview p <$> d) c <!?> e
-- | Extract the contents of a JSON string as 'Text'; 'Nothing' when the
-- value is not a string.
text' :: AsJType a ws a => a -> Maybe Text
text' a = a L.^? _JStr . _1 . L.re _JString
-- | Extract a JSON string as a 'String'. Each stored character is converted
-- with '_JChar' in the review direction. Returns 'Nothing' when the value is
-- not a string.
string' :: AsJType a ws a => a -> Maybe String
string' a = V.toList . V.map (L.review _JChar) <$> a L.^? _JStr . _1 . _Wrapped
-- | Take the first character of a JSON string and convert it with
-- 'jCharToUtf8Char', which may reject it.
--
-- Returns 'Nothing' for non-strings and empty strings. Any characters
-- after the first are ignored.
boundedChar' :: AsJType a ws a => a -> Maybe Char
boundedChar' a = L.preview (_JStr . _1 . _Wrapped . L._head) a >>= jCharToUtf8Char
-- | Take the first character of a JSON string, converting it with '_JChar'
-- in the review direction.
--
-- Returns 'Nothing' for non-strings and empty strings. Any characters
-- after the first are ignored.
unboundedChar' :: AsJType a ws a => a -> Maybe Char
unboundedChar' a = L.review _JChar <$> a L.^? _JStr . _1 . _Wrapped . L._head
-- | Read a JSON number as 'Scientific' via 'jNumberToScientific'. Returns
-- 'Nothing' for non-numbers or when that conversion fails.
scientific' :: AsJType a ws a => a -> Maybe Scientific
scientific' a = a L.^? _JNum . _1 >>= jNumberToScientific
-- | Read a JSON number as a bounded integral type. Returns 'Nothing' when
-- the number is not an integer or does not fit in @i@
-- (see 'Sci.toBoundedInteger').
integral' :: (Bounded i , Integral i , AsJType a ws a) => a -> Maybe i
integral' json = (scientific' >=> Sci.toBoundedInteger) json
-- | 'integral'' specialised to 'Int'.
int' :: AsJType a ws a => a -> Maybe Int
int' json = integral' json
-- | Read a JSON boolean; 'Nothing' when the value is not a boolean.
bool' :: AsJType a ws a => a -> Maybe Bool
bool' json = json L.^? _JBool . _1
-- | Returns @Just ()@ exactly when the value is JSON @null@.
null' :: AsJType a ws a => a -> Maybe ()
null' json = json L.^? _JNull $> ()
-- | Decode every element of a JSON array with @f@, dropping the elements
-- for which @f@ returns 'Nothing'. A non-array yields @[]@.
array' :: AsJType a ws a => (a -> Maybe b) -> a -> [b]
array' f = Wither.mapMaybe f . L.toListOf (_JArr . _1 . L.folded)
-- | Decode each member of a JSON object into a @(key, value)@ pair, using
-- @kF@ on the member's key and @vF@ on its value.
--
-- Members are processed in the order 'toList' yields them. A value that is
-- not an object yields @pure []@.
objTuples'
  :: ( Applicative f
     , AsJType a ws a
     )
  => (JString -> f k) -- ^ key decoder
  -> (a -> f b)       -- ^ value decoder
  -> a                -- ^ JSON value expected to be an object
  -> f [(k, b)]
objTuples' kF vF =
  traverse decodeMember . L.toListOf (_JObj . _1 . _Wrapped . L.folding toList)
  where
    decodeMember member =
      liftA2 (,) (kF (member L.^. jsonAssocKey)) (vF (member L.^. jsonAssocVal))
-- | Fold over a run of elements reachable from a starting cursor.
--
-- At each cursor the element decoder is run and its result is combined
-- into the accumulator with @scons@. The cursor is then advanced with
-- @mvCurs@ and the process repeats.
--
-- * If decoding the element fails, the accumulator built so far is
--   returned, without that element.
-- * If moving to the next cursor fails, the accumulator including the
--   current element is returned.
--
-- NOTE(review): because element failures are caught by 'try', a malformed
-- element silently ends the fold instead of raising an error. Confirm
-- this is intended.
foldCursor'
  :: Monad f
  => b                            -- ^ initial accumulator
  -> (b -> a -> b)                -- ^ combine accumulator with a decoded element
  -> (c -> DecodeResultT i e f c) -- ^ move from one element's cursor to the next
  -> Decoder' c i e f a           -- ^ decoder for a single element
  -> c                            -- ^ cursor positioned at the first element
  -> DecodeResultT i e f b
foldCursor' empty scons mvCurs elemD = loop empty
  where
    loop acc cur = try (runDecoder' elemD cur) >>= \found -> case found of
      Nothing -> pure acc
      Just x  -> do
        let acc' = scons acc x
        next <- try (mvCurs cur)
        maybe (pure acc') (loop acc') next
-- | Decode a JSON object into a 'Map', resolving duplicate keys with @f@.
--
-- Members are visited in the order 'objTuples'' returns them. For each
-- member @(k, v)@ the entry for @k@ becomes @f v existing@, where
-- @existing@ is what the /earlier/ members left for @k@ ('Nothing' if
-- none). Returning 'Nothing' from @f@ removes the key. A value that is not
-- an object decodes to an empty map.
mapKeepingF
  :: ( Ord k
     , Applicative f
     , AsJType a ws a
     )
  => (t -> Maybe v -> Maybe v) -- ^ combine an incoming value with the stored one
  -> (JString -> f k)          -- ^ key decoder
  -> (a -> f t)                -- ^ value decoder
  -> a                         -- ^ JSON value expected to be an object
  -> f (Map k v)
mapKeepingF f kF vF a =
  foldl' insertMember Map.empty <$> objTuples' kF vF a
  where
    -- A strict left fold visits members front to back, so "existing" really
    -- means "stored by an earlier member". A right fold would visit them
    -- back to front and invert which duplicate wins (making a keep-first
    -- combiner keep the last occurrence). The strict fold also avoids
    -- building a chain of thunks.
    insertMember acc (k, v) = Map.alter (f v) k acc
-- | Decode a JSON object into a 'Map' via 'mapKeepingF'. The combiner keeps
-- any value already stored for a key and only inserts the incoming value
-- when the key is absent.
--
-- NOTE(review): this keeps the first occurrence only if 'mapKeepingF'
-- visits members front to back. Confirm against its fold order.
mapKeepingFirst
  :: ( Ord k
     , Applicative f
     , AsJType a ws a
     )
  => (JString -> f k)
  -> (a -> f b)
  -> a
  -> f (Map k b)
mapKeepingFirst =
  mapKeepingF (\incoming stored -> stored <|> Just incoming)
-- | Decode a JSON object into a 'Map' via 'mapKeepingF'. The combiner
-- always replaces the stored value with the incoming one.
--
-- NOTE(review): this keeps the last occurrence only if 'mapKeepingF'
-- visits members front to back. Confirm against its fold order.
mapKeepingLast
  :: ( Ord k
     , Applicative f
     , AsJType a ws a
     )
  => (JString -> f k)
  -> (a -> f b)
  -> a
  -> f (Map k b)
mapKeepingLast =
  mapKeepingF (\incoming stored -> Just incoming <|> stored)