{- -----------------------------------------------------------------------------
Copyright 2019-2021 Kevin P. Barry

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----------------------------------------------------------------------------- -}

-- Author: Kevin P. Barry [ta0kira@gmail.com]

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}

module Base.CompilerError (
  CollectErrorsM(..),
  ErrorContextM(..),
  ErrorContextT(..),
  (<??),
  (??>),
  (<!!),
  (!!>),
  collectAllM_,
  collectFirstM_,
  emptyErrorM,
  errorFromIO,
  isCompilerError,
  isCompilerErrorM,
  isCompilerSuccess,
  isCompilerSuccessM,
  mapCompilerM,
  mapCompilerM_,
  mapErrorsM,
  mergeObjectsM,
  silenceErrorsM,
  tryCompilerM,
) where

import Control.Monad.IO.Class
import Control.Monad.Trans
import Control.Monad.Trans.State (StateT,mapStateT)
import Data.Functor.Identity
import System.IO.Error (catchIOError)

#if MIN_VERSION_base(4,13,0)
import Control.Monad.Fail ()
#elif MIN_VERSION_base(4,9,0)
import Control.Monad.Fail
#endif


-- | Monads that can raise compiler errors and attach messages to them.
--
-- 'compilerErrorM' is the only method without a default. The remaining
-- defaults are inert: context and summary messages are ignored, warnings
-- and background messages are dropped, and 'resetBackgroundM' returns the
-- computation unchanged.
--
-- For some GHC versions, pattern-matching failures require 'MonadFail', so
-- it is a superclass whenever base provides it (4.9 and later).
#if MIN_VERSION_base(4,9,0)
class (Monad m, MonadFail m) => ErrorContextM m where
#else
class Monad m => ErrorContextM m where
#endif
  -- | Produces an error carrying the given message.
  compilerErrorM :: String -> m a

  -- | Intended to associate a context message with the computation.
  -- Default: ignores the message and returns the computation unchanged.
  withContextM :: m a -> String -> m a
  withContextM = const

  -- | Intended to associate a summary message with the computation.
  -- Default: ignores the message and returns the computation unchanged.
  summarizeErrorsM :: m a -> String -> m a
  summarizeErrorsM = const

  -- | Intended to emit a warning. Default: does nothing.
  compilerWarningM :: String -> m ()
  compilerWarningM = const (return ())

  -- | Intended to record a background message. Default: does nothing.
  compilerBackgroundM :: String -> m ()
  compilerBackgroundM = const (return ())

  -- | Intended to reset background messages for the computation.
  -- Default: returns the computation unchanged.
  resetBackgroundM :: m a -> m a
  resetBackgroundM x = x

-- | Error-reporting monads that can run a group of computations and combine
-- their outcomes.
--
-- NOTE(review): the precise semantics live in the instances; the method
-- descriptions below are inferred from how this module uses them.
class ErrorContextM m => CollectErrorsM m where
  -- | Runs every computation and gathers all of their results.
  collectAllM :: Foldable f => f (m a) -> m [a]
  -- | Gathers the results of the computations that succeed. ('silenceErrorsM'
  -- treats anything other than exactly one result as a failure.)
  collectAnyM :: Foldable f => f (m a) -> m [a]
  -- | Yields the result of the first computation that succeeds.
  -- ('tryCompilerM' relies on this to fall back to 'Nothing'.)
  collectFirstM :: Foldable f => f (m a) -> m a

-- | Monad transformers whose error-reporting computations can be inspected
-- from the underlying monad.
class MonadTrans t => ErrorContextT t where
  -- | Reports, in the underlying monad, whether the computation has an error.
  isCompilerErrorT :: (Monad m, ErrorContextM (t m)) => t m a -> m Bool

  -- | Reports whether the computation succeeded.
  -- Default: the negation of 'isCompilerErrorT'.
  isCompilerSuccessT :: (Monad m, ErrorContextM (t m)) => t m a -> m Bool
  isCompilerSuccessT x = fmap not (isCompilerErrorT x)

  -- | No default. NOTE(review): from the signature, this takes the
  -- computation plus two underlying-monad actions, presumably run on success
  -- and on error respectively. Confirm against the instances.
  ifElseSuccessT :: (Monad m, ErrorContextM (t m)) => t m a -> m () -> m () -> t m a

-- | Infix form of 'withContextM' with the computation on the left:
-- @x <?? msg@ is @withContextM x msg@.
(<??) :: ErrorContextM m => m a -> String -> m a
x <?? msg = withContextM x msg
infixl 1 <??

-- | Infix form of 'withContextM' with the message on the left:
-- @msg ??> x@ is @withContextM x msg@.
(??>) :: ErrorContextM m => String -> m a -> m a
msg ??> x = withContextM x msg
infixr 1 ??>

-- | Infix form of 'summarizeErrorsM' with the computation on the left:
-- @x <!! msg@ is @summarizeErrorsM x msg@.
(<!!) :: ErrorContextM m => m a -> String -> m a
x <!! msg = summarizeErrorsM x msg
infixl 1 <!!

-- | Infix form of 'summarizeErrorsM' with the message on the left:
-- @msg !!> x@ is @summarizeErrorsM x msg@.
(!!>) :: ErrorContextM m => String -> m a -> m a
msg !!> x = summarizeErrorsM x msg
infixr 1 !!>

-- | Like 'collectAllM', but discards the collected results.
collectAllM_ :: (Foldable f, CollectErrorsM m) => f (m a) -> m ()
collectAllM_ xs = collectAllM xs >> return ()

-- | Like 'collectFirstM', but discards the selected result.
collectFirstM_ :: (Foldable f, CollectErrorsM m) => f (m a) -> m ()
collectFirstM_ xs = collectFirstM xs >> return ()

-- | Applies @f@ to every element and combines the resulting computations
-- with 'collectAllM'.
mapCompilerM :: CollectErrorsM m => (a -> m b) -> [a] -> m [b]
mapCompilerM f xs = collectAllM [f x | x <- xs]

-- | Applies @f@ to every element and combines the resulting computations
-- with 'collectAllM_', discarding the results.
mapCompilerM_ :: CollectErrorsM m => (a -> m b) -> [a] -> m ()
mapCompilerM_ f xs = collectAllM_ [f x | x <- xs]

-- | Wraps the computation's result in 'Just'. 'collectFirstM' falls back to
-- @return Nothing@ when the wrapped computation is not selected.
tryCompilerM :: CollectErrorsM m => m a -> m (Maybe a)
tryCompilerM x = collectFirstM [fmap Just x, return Nothing]

-- | Pure form of 'isCompilerErrorT' for transformers stacked on 'Identity'.
isCompilerError :: (ErrorContextT t, ErrorContextM (t Identity)) => t Identity a -> Bool
isCompilerError x = runIdentity (isCompilerErrorT x)

-- | Pure form of 'isCompilerSuccessT' for transformers stacked on 'Identity'.
isCompilerSuccess :: (ErrorContextT t, ErrorContextM (t Identity)) => t Identity a -> Bool
isCompilerSuccess x = runIdentity (isCompilerSuccessT x)

-- | 'False' when 'collectFirstM' selects the computation (its result is
-- replaced by 'False'); otherwise the fallback 'True'.
isCompilerErrorM :: CollectErrorsM m => m a -> m Bool
isCompilerErrorM x = collectFirstM [fmap (const False) x, return True]

-- | 'True' when 'collectFirstM' selects the computation (its result is
-- replaced by 'True'); otherwise the fallback 'False'.
isCompilerSuccessM :: CollectErrorsM m => m a -> m Bool
isCompilerSuccessM x = collectFirstM [fmap (const True) x, return False]

-- | Raises each message with 'compilerErrorM' (combined via 'mapCompilerM_'),
-- then finishes with 'emptyErrorM'.
mapErrorsM :: CollectErrorsM m => [String] -> m a
mapErrorsM es = do
  mapCompilerM_ compilerErrorM es
  emptyErrorM

-- | An error whose message is the empty string.
emptyErrorM :: CollectErrorsM m => m a
emptyErrorM = compilerErrorM noMessage where
  noMessage = ""

-- | Runs the computation through 'collectAnyM'. Exactly one result is
-- returned as-is; any other outcome becomes 'emptyErrorM' (presumably
-- discarding the computation's own error messages).
silenceErrorsM :: CollectErrorsM m => m a -> m a
silenceErrorsM x = collectAnyM [x] >>= pick where
  pick [y] = return y
  pick _   = emptyErrorM

-- | Lifts an IO action into the monad. Any 'IOError' it throws is caught
-- and re-raised with 'compilerErrorM', using the exception's 'show' text as
-- the message. Other exception types are not caught.
errorFromIO :: (MonadIO m, ErrorContextM m) => IO a -> m a
errorFromIO action = liftIO guarded >>= either compilerErrorM return where
  guarded = fmap Right action `catchIOError` (return . Left . show)

-- | Drops objects that can be merged into another object.
--
-- For fixed x, if f y x succeeds for some y then x is removed. The
-- candidates y are the objects kept so far followed by the objects not yet
-- visited (never x itself). 'collectFirstM' chooses between those checks and
-- the fallback of keeping x. Survivors keep their original relative order.
mergeObjectsM :: CollectErrorsM m => (a -> a -> m b) -> [a] -> m [a]
mergeObjectsM f = go [] where
  -- kept: survivors so far, in visiting order; rest: objects not yet visited
  go kept [] = return kept
  go kept (x:rest) = do
    survivor <- collectFirstM $
      [f y x >> return [] | y <- kept ++ rest] ++ [return [x]]
    go (kept ++ survivor) rest

-- | Error reporting through 'StateT'. Errors, warnings, and background
-- messages are lifted from the underlying monad. Context, summary, and
-- background-reset operations are applied to the underlying
-- state-passing computation via 'mapStateT'.
instance ErrorContextM m => ErrorContextM (StateT s m) where
  compilerErrorM msg      = lift (compilerErrorM msg)
  withContextM x msg      = mapStateT (msg ??>) x
  summarizeErrorsM x msg  = mapStateT (msg !!>) x
  compilerWarningM msg    = lift (compilerWarningM msg)
  compilerBackgroundM msg = lift (compilerBackgroundM msg)
  resetBackgroundM x      = mapStateT resetBackgroundM x