{-# LANGUAGE ScopedTypeVariables #-}

module Test.Torch.Types.Instances (
  Ok (..)
, Is (..)
, Named (..)
, IsBottom (..)
-- , Skipped (..)

, SimpleFailure (..)
, UnexpectedValue (..)
, Bottom (..)
) where

import Control.Parallel.Strategies (NFData, using, rnf)
import Control.Exception (evaluate, try, SomeException(..))
import Control.Monad (liftM)
import Control.Monad.Trans (MonadIO, liftIO)

import Test.Torch.Types


{- Tests -}

-- Simple Assertion
-- Ok cond: passes when cond evaluates to True.
data Ok = Ok Bool

instance Test Ok where
    -- The condition is forced via 'eval'; an exception becomes a 'Bottom'
    -- failure, True passes, and False becomes a 'SimpleFailure'.
    run (Ok cond) = liftM (either bottom verdict) (eval cond)
      where
        verdict held = if held then pass else simple_failure

-- Assertion for Equality
data Is where
    -- Is wantEqual expected got: passes when (expected == got) evaluates
    -- to wantEqual; otherwise reports an 'UnexpectedValue'.
    Is :: (Eq a, Show a) => Bool -> a -> a -> Is

instance Test Is where
    -- The comparison itself is forced via 'eval', so an exception thrown
    -- by (==) turns into a 'Bottom' failure instead of escaping.
    run (Is eq expected got) = liftM (either bottom judge) (eval (expected == got))
      where
        judge same
          | same == eq = pass
          | otherwise  = unexpected_value eq expected got

-- Named Assertion
data Named where
    -- Named label test: runs test and prefixes any failure with label.
    Named :: (Test t) => String -> t -> Named

instance Test Named where
    run (Named name test) = liftM relabel (run test)
      where
        relabel Pass           = pass
        relabel (Fail failure) = named_failure name failure
        -- NOTE(review): no clause for a Skip result; one was sketched
        -- alongside the commented-out Skipped test. Re-add it
        -- (returning the result unchanged) if Result gains a Skip constructor.

-- Assertion on whether fully evaluating a value hits bottom (throws)
data IsBottom where
    -- IsBottom wantBottom value: passes when fully evaluating value (via
    -- 'eval') throws an exception exactly when wantBottom is True.
    IsBottom :: (NFData a) => Bool -> a -> IsBottom

instance Test IsBottom where
    run (IsBottom wantBottom value) = do
      outcome <- eval value
      let threw = case outcome of
            Left _  -> True
            Right _ -> False
      return (if threw == wantBottom then pass else simple_failure)

-- Test that is always skipped
-- data Skipped where
--     Skipped :: (Test t) => t -> SkipReason -> Skipped

-- instance Test Skipped where
--     run (Skipped _ reason) = return $ Skip reason


{- Failures -}

-- Very Simple Failure
-- A failure that carries no detail at all.
data SimpleFailure = SimpleFailure

instance Failure SimpleFailure where
    -- Every value describes itself identically.
    describe = const "failed."

-- Failure reporting an unexpected value
data UnexpectedValue where
    -- UnexpectedValue wantEqual expected gotten: fields are, in order,
    -- whether equality was wanted, the expected value, and the value
    -- actually obtained (the order 'unexpected_value' constructs them in).
    UnexpectedValue :: (Show a) => Bool -> a -> a -> UnexpectedValue

instance Failure UnexpectedValue where
    -- Bind the fields in constructor order (eq, expected, gotten) so each
    -- value is printed in its matching position of the message.
    describe (UnexpectedValue eq expected gotten) =
      (if eq then "expected " else "not expected ") ++
        show expected ++ ", but got " ++ show gotten ++ "."

-- Failure of Named Test
data NamedFailure where
    -- A failure tagged with the name of the test that produced it.
    NamedFailure :: (Failure f) => String -> f -> NamedFailure

instance Failure NamedFailure where
    -- Renders as "<name>: <inner description>".
    describe (NamedFailure label inner) = concat [label, ": ", describe inner]

-- Failure raised when running the test hit bottom (threw an exception)
-- Wraps the exception caught while evaluating a test.
data Bottom = Bottom SomeException

instance Failure Bottom where
    -- Unwraps the SomeException and shows the underlying exception.
    describe (Bottom wrapped) = case wrapped of
      SomeException exc -> "failed with exception: " ++ show exc

-- Delegates to the wrapped failure's own description.
-- NOTE(review): SomeFailure and Failure both appear to come from
-- Test.Torch.Types, which would make this an orphan instance; consider
-- moving it there.
instance Failure SomeFailure where
    describe wrapped = case wrapped of
      SomeFailure inner -> describe inner

{- Local Utilities -}

-- Passing result.
pass :: Result
pass = Pass

-- Failing result with no detail.
simple_failure :: Result
simple_failure = Fail SimpleFailure

-- Failing result for a test whose evaluation threw the given exception.
bottom :: SomeException -> Result
bottom exc = Fail (Bottom exc)

-- Failing result for an equality check: (wantEqual, expected, got).
unexpected_value :: (Show a) => Bool -> a -> a -> Result
unexpected_value eq expected = Fail . UnexpectedValue eq expected

-- Evaluates a value inside any MonadIO and captures any exception thrown
-- while doing so, instead of letting it propagate.
eval :: (MonadIO io, NFData a) => a -> io (Either SomeException a)
eval expr = liftIO (try (evaluate forced))
  where
    -- The rnf strategy is applied so evaluation goes past WHNF
    -- (assumes the legacy `parallel` API where rnf is a Strategy).
    forced = expr `using` rnf

-- Failing result that tags a failure with the given test name.
named_failure :: (Failure f) => String -> f -> Result
named_failure name = Fail . NamedFailure name