{-# LANGUAGE DeriveFoldable             #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PatternSynonyms            #-}
{-# LANGUAGE TypeFamilies               #-}
module Data.Strict.Forced
    (
      Forced (Forced, getForced)
    , Data.Strict.Forced.map
    , Data.Strict.Forced.traverse
    , (<!>)
    ) where

import           Control.Arrow
import           Control.DeepSeq
import           Data.Hashable
import           Data.Semigroup (Semigroup(..))
import           GHC.Exts
import           GHC.Float
import           Prelude         (Applicative (..), Bounded (..), Enum (..),
                                  Eq (..), Foldable (..), Fractional (..), Functor(..),
                                  Integral (..), Monoid (..), Num (..),
                                  Ord (..), Read (..), Real (..), RealFrac (..),
                                  Show (..), (<$>), (.))
import           Type.Strict

-- | A wrapper whose payload is kept in normal form.
--
-- Values are meant to be built with the 'Forced' pattern, whose builder
-- applies 'force' to the payload. The raw 'Forced_' constructor is not
-- exported from this module.
--
-- NOTE(review): the derived 'Show' instance prints the internal constructor
-- (@Forced_ x@), while the 'Read' instance in this module parses the bare
-- payload, so @read . show@ does not round-trip.
newtype Forced a = Forced_ a
  deriving (Eq, Ord, Show, Hashable, Foldable)

-- | 'Forced' is an instance of 'StrictType' for every @seen@ index and every
-- payload type @a@, with no constraint on @a@. The instance body is empty,
-- so it relies entirely on the class's default methods (if it has any).
-- NOTE(review): the meaning of 'StrictType' lives in "Type.Strict" and is
-- not visible here. Presumably the lack of a constraint on @a@ is justified
-- because the 'Forced' builder normalises the payload; confirm against
-- Type.Strict.
instance StrictType seen (Forced a)

-- | Bidirectional pattern for 'Forced'.
--
-- * Matching binds the wrapped value, and 'getForced' selects it.
-- * Constructing with @Forced x@ stores @'force' x@. Because 'Forced' is a
--   newtype, the payload is reduced to normal form as soon as the wrapper
--   itself is evaluated.
pattern Forced :: NFData a => a -> Forced a
pattern Forced { getForced } <- Forced_ getForced
  where
    Forced payload = Forced_ (force payload)
{-# COMPLETE Forced #-}

-- | Apply a function underneath 'Forced'.
--
-- The input is unwrapped through the raw constructor, so only the result
-- type needs 'NFData'. The result is re-wrapped with the forcing 'Forced'
-- builder.
map :: (NFData a) => (b -> a) -> Forced b -> Forced a
map f wrapped =
  case wrapped of
    Forced_ inner -> Forced (f inner)

-- | Run an effectful function on the payload and wrap its result with the
-- forcing 'Forced' builder. Only the result type needs 'NFData'.
traverse :: (NFData a, Applicative f) => (b -> f a) -> Forced b -> f (Forced a)
traverse f (Forced_ inner) = fmap Forced (f inner)

-- | Apply a wrapped function to a wrapped argument. The result is
-- re-wrapped (and thereby normalised) with the 'Forced' builder.
(<!>) :: NFData a => Forced (t -> a) -> Forced t -> Forced a
(<!>) (Forced_ g) (Forced_ v) = Forced (g v)

-- | Reducing a 'Forced' value to normal form only requires evaluating the
-- wrapped value to weak head normal form.
--
-- Every 'Forced_' in this module is created by the 'Forced' builder as
-- @Forced_ (force x)@. Since 'Forced' is a newtype, evaluating the payload to
-- WHNF runs that 'force', which normalises it completely.
--
-- A constant @rnf _ = ()@ would not even evaluate the wrapper. Then
-- 'deepseq'/'force' over a structure holding unevaluated 'Forced' thunks
-- (e.g. a list of them) would leave those thunks unevaluated, which breaks
-- the 'NFData' contract.
instance NFData (Forced a) where
  rnf (Forced_ a) = rwhnf a

-- | Parses the bare payload (no constructor name) and wraps each parse with
-- the forcing 'Forced' builder.
--
-- NOTE(review): the derived 'Show' instance emits @Forced_ x@, which this
-- parser does not accept, so @read . show@ is not an identity on 'Forced'.
instance (NFData a, Read a) => Read (Forced a) where
  readsPrec d = fmap (\(x, rest) -> (Forced x, rest)) . readsPrec d

-- | The identity is the inner 'mempty', wrapped with the forcing builder.
-- 'mappend' is spelled out explicitly and defers to the 'Semigroup'
-- instance.
instance (Semigroup a, NFData a, Monoid a) => Monoid (Forced a) where
  mappend x y = x <> y
  mempty      = Forced mempty

-- | Combine payloads with the inner '<>' and re-wrap the result, which the
-- 'Forced' builder normalises.
instance (NFData a, Semigroup a) => Semigroup (Forced a) where
  (<>) (Forced x) (Forced y) = Forced (x <> y)

-- | Bounds are the payload type's bounds, wrapped with the forcing builder.
instance (NFData a, Bounded a) => Bounded (Forced a) where
  maxBound = Forced maxBound
  minBound = Forced minBound

-- | Enumeration delegates to the payload's 'Enum' instance. Every produced
-- element is re-wrapped with the forcing 'Forced' builder.
instance (NFData a, Enum a) => Enum (Forced a) where
  succ (Forced x)     = Forced (succ x)
  pred (Forced x)     = Forced (pred x)
  toEnum n            = Forced (toEnum n)
  fromEnum (Forced x) = fromEnum x
  enumFrom (Forced x) = fmap Forced (enumFrom x)
  enumFromThen (Forced x) (Forced next) =
    fmap Forced (enumFromThen x next)
  enumFromTo (Forced x) (Forced limit) =
    fmap Forced (enumFromTo x limit)
  enumFromThenTo (Forced x) (Forced next) (Forced limit) =
    fmap Forced (enumFromThenTo x next limit)

-- | List conversion through the payload's 'IsList' instance. Items are the
-- payload's items; 'fromList' wraps the built value with the forcing builder.
instance (NFData a, IsList a) => IsList (Forced a) where
  type Item (Forced a) = Item a
  fromList items    = Forced (fromList items)
  toList (Forced x) = toList x

-- | Arithmetic on the payload. Every result is re-wrapped with the 'Forced'
-- builder, which applies 'force' to it.
instance (NFData a, Num a) => Num (Forced a) where
  Forced a + Forced b = Forced (a + b)
  Forced a - Forced b = Forced (a - b)
  Forced a * Forced b = Forced (a * b)
  negate      = Forced . negate . getForced
  abs         = Forced . abs . getForced
  -- Must delegate to the payload's 'signum' (not 'abs') so that the 'Num'
  -- law @abs x * signum x == x@ keeps holding for negative values.
  signum      = Forced . signum . getForced
  fromInteger = Forced . fromInteger

-- | Integer division on the payload. Quotients and remainders are each
-- re-wrapped with the forcing builder. The paired results are destructured
-- with a lazy @let@, so the pair itself is produced without first evaluating
-- the division.
instance (NFData a, Integral a) => Integral (Forced a) where
  quot (Forced x) (Forced y) = Forced (x `quot` y)
  rem  (Forced x) (Forced y) = Forced (x `rem` y)
  div  (Forced x) (Forced y) = Forced (x `div` y)
  mod  (Forced x) (Forced y) = Forced (x `mod` y)
  quotRem (Forced x) (Forced y) =
    let (q, r) = quotRem x y
    in  (Forced q, Forced r)
  divMod (Forced x) (Forced y) =
    let (dv, md) = divMod x y
    in  (Forced dv, Forced md)
  toInteger (Forced x) = toInteger x

-- | Fractional operations on the payload, re-wrapped with the forcing builder.
instance (NFData a, Fractional a) => Fractional (Forced a) where
  (/) (Forced x) (Forced y) = Forced (x / y)
  recip (Forced x)          = Forced (recip x)
  fromRational r            = Forced (fromRational r)

-- | Lifts every 'Floating' method to the payload. Each result is re-wrapped
-- with the forcing 'Forced' builder. For 'logBase', the first argument is the
-- base, as in the underlying class.
instance (NFData a, Floating a) => Floating (Forced a) where
  pi                              = Forced pi
  (**) (Forced x) (Forced y)      = Forced (x ** y)
  logBase (Forced base) (Forced x) = Forced (logBase base x)
  exp      (Forced x) = Forced (exp x)
  log      (Forced x) = Forced (log x)
  sqrt     (Forced x) = Forced (sqrt x)
  sin      (Forced x) = Forced (sin x)
  cos      (Forced x) = Forced (cos x)
  tan      (Forced x) = Forced (tan x)
  asin     (Forced x) = Forced (asin x)
  acos     (Forced x) = Forced (acos x)
  atan     (Forced x) = Forced (atan x)
  sinh     (Forced x) = Forced (sinh x)
  cosh     (Forced x) = Forced (cosh x)
  tanh     (Forced x) = Forced (tanh x)
  asinh    (Forced x) = Forced (asinh x)
  acosh    (Forced x) = Forced (acosh x)
  atanh    (Forced x) = Forced (atanh x)
  log1p    (Forced x) = Forced (log1p x)
  expm1    (Forced x) = Forced (expm1 x)
  log1pexp (Forced x) = Forced (log1pexp x)
  log1mexp (Forced x) = Forced (log1mexp x)

-- | Delegates every 'RealFloat' method to the payload. Methods that return a
-- floating value re-wrap it with the forcing builder. Queries (radix, digits,
-- range, decoding, predicates) return the payload's answer unchanged.
instance (NFData a, RealFloat a) => RealFloat (Forced a) where
  floatRadix     (Forced x) = floatRadix x
  floatDigits    (Forced x) = floatDigits x
  floatRange     (Forced x) = floatRange x
  decodeFloat    (Forced x) = decodeFloat x
  encodeFloat mantissa ex   = Forced (encodeFloat mantissa ex)
  exponent       (Forced x) = exponent x
  significand    (Forced x) = Forced (significand x)
  scaleFloat n   (Forced x) = Forced (scaleFloat n x)
  isNaN          (Forced x) = isNaN x
  isInfinite     (Forced x) = isInfinite x
  isDenormalized (Forced x) = isDenormalized x
  isNegativeZero (Forced x) = isNegativeZero x
  isIEEE         (Forced x) = isIEEE x
  atan2 (Forced y) (Forced x) = Forced (atan2 y x)

-- | Rounding delegates to the payload. 'properFraction' re-wraps only the
-- fractional part; its result pair is destructured with a lazy @let@.
instance (NFData a, RealFrac a) => RealFrac (Forced a) where
  properFraction (Forced x) =
    let (whole, frac) = properFraction x
    in  (whole, Forced frac)
  truncate (Forced x) = truncate x
  round    (Forced x) = round x
  ceiling  (Forced x) = ceiling x
  floor    (Forced x) = floor x

-- | The rational value of a 'Forced' number is that of its payload.
instance (NFData a, Real a) => Real (Forced a) where
  toRational (Forced x) = toRational x