{- -----------------------------------------------------------------------------
Copyright 2019-2021 Kevin P. Barry

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----------------------------------------------------------------------------- -}

-- Author: Kevin P. Barry [ta0kira@gmail.com]

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}

module Base.CompilerError (
  CollectErrorsM(..),
  ErrorContextM(..),
  ErrorContextT(..),
  (<??),
  (??>),
  (<!!),
  (!!>),
  collectAllM_,
  collectFirstM_,
  emptyErrorM,
  errorFromIO,
  isCompilerError,
  isCompilerErrorM,
  isCompilerSuccess,
  isCompilerSuccessM,
  mapCompilerM,
  mapCompilerM_,
  mapErrorsM,
  mergeObjectsM,
  silenceErrorsM,
  tryCompilerM,
) where

import Control.Monad.IO.Class
import Control.Monad.Trans
import Control.Monad.Trans.State (StateT,mapStateT)
import Data.Functor.Identity
import System.IO.Error (catchIOError)

#if MIN_VERSION_base(4,13,0)
import Control.Monad.Fail ()
#elif MIN_VERSION_base(4,9,0)
import Control.Monad.Fail
#endif


-- | Monads that can raise compiler errors and decorate them with context.
--
-- On base >= 4.9 'MonadFail' is an extra superclass: for some GHC versions,
-- pattern-matching failures (e.g. in @do@ blocks) require 'MonadFail'.
#if MIN_VERSION_base(4,9,0)
class (Monad m, MonadFail m) => ErrorContextM m where
#else
class Monad m => ErrorContextM m where
#endif
  -- | Fail with the given error message.
  compilerErrorM :: String -> m a

  -- | Attach a context message to the computation. Instances decide what
  -- that means; the default ignores the message and returns the action.
  withContextM :: m a -> String -> m a
  withContextM action _ = action

  -- | Attach a summary message to the computation. Instances decide what
  -- that means; the default ignores the message and returns the action.
  summarizeErrorsM :: m a -> String -> m a
  summarizeErrorsM action _ = action

  -- | Emit a warning message. The default discards it.
  compilerWarningM :: String -> m ()
  compilerWarningM = const (return ())

  -- | Emit a background message. The default discards it.
  compilerBackgroundM :: String -> m ()
  compilerBackgroundM = const (return ())

  -- | Reset background state for the computation. The default leaves the
  -- computation unchanged.
  resetBackgroundM :: m a -> m a
  resetBackgroundM action = action

-- | Error-reporting monads that can run a collection of computations and
-- combine their outcomes. Exact error semantics are instance-defined.
class ErrorContextM m => CollectErrorsM m where
  -- | Gather the results of all computations, in order.
  -- NOTE(review): presumably fails if any computation fails; confirm
  -- against the instances.
  collectAllM :: Foldable t => t (m r) -> m [r]

  -- | Gather results from the computations.
  -- NOTE(review): callers use this to keep only successful results;
  -- confirm against the instances.
  collectAnyM :: Foldable t => t (m r) -> m [r]

  -- | Produce the result of the first computation that succeeds. Callers in
  -- this module rely on this by appending a fallback computation last.
  collectFirstM :: Foldable t => t (m r) -> m r

-- | Monad transformers whose computations can be checked for errors from
-- inside the underlying monad.
class MonadTrans t => ErrorContextT t where
  -- | Report, in the underlying monad, whether the computation is an error.
  isCompilerErrorT :: (Monad m, ErrorContextM (t m)) => t m a -> m Bool

  -- | Logical negation of 'isCompilerErrorT'.
  isCompilerSuccessT :: (Monad m, ErrorContextM (t m)) => t m a -> m Bool
  isCompilerSuccessT action = fmap not (isCompilerErrorT action)

  -- | Instance-defined.
  -- NOTE(review): presumably runs the first underlying action on success
  -- and the second on error; confirm against the instances.
  ifElseSuccessT :: (Monad m, ErrorContextM (t m)) => t m a -> m () -> m () -> t m a

-- | Infix alias for 'withContextM': @action <?? message@.
(<??) :: ErrorContextM m => m a -> String -> m a
action <?? message = withContextM action message
infixl 1 <??

-- | Argument-flipped infix alias for 'withContextM': @message ??> action@.
(??>) :: ErrorContextM m => String -> m a -> m a
message ??> action = withContextM action message
infixr 1 ??>

-- | Infix alias for 'summarizeErrorsM': @action <!! message@.
(<!!) :: ErrorContextM m => m a -> String -> m a
action <!! message = summarizeErrorsM action message
infixl 1 <!!

-- | Argument-flipped infix alias for 'summarizeErrorsM':
-- @message !!> action@.
(!!>) :: ErrorContextM m => String -> m a -> m a
message !!> action = summarizeErrorsM action message
infixr 1 !!>

-- | Run 'collectAllM' and throw away the collected results.
collectAllM_ :: (Foldable f, CollectErrorsM m) => f (m a) -> m ()
collectAllM_ actions = collectAllM actions >> return ()

-- | Run 'collectFirstM' and throw away the selected result.
collectFirstM_ :: (Foldable f, CollectErrorsM m) => f (m a) -> m ()
collectFirstM_ actions = collectFirstM actions >> return ()

-- | Apply @f@ to every element and combine the computations with
-- 'collectAllM'.
mapCompilerM :: CollectErrorsM m => (a -> m b) -> [a] -> m [b]
mapCompilerM f xs = collectAllM [f x | x <- xs]

-- | Apply @f@ to every element, combine the computations with
-- 'collectAllM_', and discard the results.
mapCompilerM_ :: CollectErrorsM m => (a -> m b) -> [a] -> m ()
mapCompilerM_ f xs = collectAllM_ [f x | x <- xs]

-- | Wrap the result of @x@ in 'Just'. The fallback @return Nothing@ is
-- listed second, so 'collectFirstM' only reaches it when @x@ itself is not
-- selected (i.e. when it does not succeed).
tryCompilerM :: CollectErrorsM m => m a -> m (Maybe a)
tryCompilerM x = collectFirstM [fmap Just x, return Nothing]

-- | Pure form of 'isCompilerErrorT' for transformers stacked on 'Identity'.
isCompilerError :: (ErrorContextT t, ErrorContextM (t Identity)) => t Identity a -> Bool
isCompilerError action = runIdentity (isCompilerErrorT action)

-- | Pure form of 'isCompilerSuccessT' for transformers stacked on 'Identity'.
isCompilerSuccess :: (ErrorContextT t, ErrorContextM (t Identity)) => t Identity a -> Bool
isCompilerSuccess action = runIdentity (isCompilerSuccessT action)

-- | 'False' when @x@ is the computation selected by 'collectFirstM' (i.e. it
-- succeeds), otherwise the fallback 'True'. The result of @x@ is discarded.
isCompilerErrorM :: CollectErrorsM m => m a -> m Bool
isCompilerErrorM x = collectFirstM [whenOk, fallback] where
  whenOk   = x >> return False
  fallback = return True

-- | 'True' when @x@ is the computation selected by 'collectFirstM' (i.e. it
-- succeeds), otherwise the fallback 'False'. The result of @x@ is discarded.
isCompilerSuccessM :: CollectErrorsM m => m a -> m Bool
isCompilerSuccessM x = collectFirstM [whenOk, fallback] where
  whenOk   = x >> return True
  fallback = return False

-- | Raise every message in @es@ via 'compilerErrorM' (combined with
-- 'mapCompilerM_'), then fail with 'emptyErrorM'. With an empty list only
-- 'emptyErrorM' remains, so this never produces a value.
mapErrorsM :: CollectErrorsM m => [String] -> m a
mapErrorsM es = do
  mapCompilerM_ compilerErrorM es
  emptyErrorM

-- | Fail via 'compilerErrorM' with an empty message.
emptyErrorM :: CollectErrorsM m => m a
emptyErrorM = compilerErrorM noMessage where
  noMessage = ""

-- | Run @x@ through 'collectAnyM'. If exactly one result comes back, return
-- it; otherwise fail with 'emptyErrorM'.
-- NOTE(review): presumably this replaces the original error messages of @x@
-- with an empty error; confirm against the 'collectAnyM' instances.
silenceErrorsM :: CollectErrorsM m => m a -> m a
silenceErrorsM x = collectAnyM [x] >>= pick where
  pick [y] = return y
  pick _   = emptyErrorM

-- | Lift an 'IO' action into @m@. Any 'IOError' thrown while running the
-- action is caught and re-raised with 'compilerErrorM', using the
-- shown exception as the message.
errorFromIO :: (MonadIO m, ErrorContextM m) => IO a -> m a
errorFromIO x = liftIO attempt >>= either compilerErrorM return where
  attempt = fmap Right x `catchIOError` (return . Left . show)

-- | Drop objects that can be merged into another object of the list.
--
-- For fixed @x@, if @f y x@ succeeds for some @y@ then @x@ is removed.
-- Objects are visited left to right. Each @x@ is checked against the
-- objects already kept plus those not yet visited, and the kept objects
-- stay in their original order.
mergeObjectsM :: CollectErrorsM m => (a -> a -> m b) -> [a] -> m [a]
mergeObjectsM f = go [] where
  go kept [] = return kept
  go kept (x:rest) = do
    -- The first successful check yields [] (x is dropped); the trailing
    -- fallback keeps [x] when no check succeeds.
    survivor <- collectFirstM (map (absorbs x) (kept ++ rest) ++ [return [x]])
    go (kept ++ survivor) rest
  -- Succeeds with [] when x can be merged into y.
  absorbs x y = f y x >> return []

-- | Lift error handling through 'StateT'. Errors, warnings and background
-- messages are forwarded to the underlying monad with 'lift'. Context,
-- summaries and background resets wrap the underlying computation with
-- 'mapStateT', leaving the state threading untouched.
instance ErrorContextM m => ErrorContextM (StateT a m) where
  compilerErrorM message = lift (compilerErrorM message)
  withContextM action message = mapStateT (message ??>) action
  summarizeErrorsM action message = mapStateT (message !!>) action
  compilerWarningM message = lift (compilerWarningM message)
  compilerBackgroundM message = lift (compilerBackgroundM message)
  resetBackgroundM action = mapStateT resetBackgroundM action