{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RebindableSyntax  #-}

module Precursor.Algebra.Semiring where

import           Precursor.Data.Bool
import           Data.Int  (Int, Int16, Int32, Int64, Int8)
import           GHC.Float (Double, Float)
import qualified GHC.Num   as P
import           Precursor.Numeric.Num

-- | A <https://en.wikipedia.org/wiki/Semiring Semiring> is like the
-- combination of two 'Precursor.Algebra.Monoid.Monoid's. The first
-- is called '+'; it has the identity element 'zero', and it is
-- commutative. The second is called '*'; it has identity element 'one',
-- and it must distribute over '+'.
--
-- = Laws
--
-- == Normal 'Precursor.Algebra.Monoid.Monoid' laws
--
-- * @(a '+' b) '+' c = a '+' (b '+' c)@
-- * @'zero' '+' a = a '+' 'zero' = a@
-- * @(a '*' b) '*' c = a '*' (b '*' c)@
-- * @'one' '*' a = a '*' 'one' = a@
--
-- == Commutativity of '+'
--
-- * @a '+' b = b '+' a@
--
-- == Distribution of '*' over '+'
--
-- * @a '*' (b '+' c) = (a '*' b) '+' (a '*' c)@
-- * @(a '+' b) '*' c = (a '*' c) '+' (b '*' c)@
--
-- == Annihilation
--
-- * @'zero' '*' a = a '*' 'zero' = 'zero'@
--
-- Annihilation is a law in its own right. Unlike in a ring, it cannot be
-- derived from the laws above, because a semiring has no additive inverses
-- to cancel with. For example, the two-element set @{0, x}@ with both
-- operations taken as @max@ and @'zero' = 'one' = 0@ satisfies every other
-- law, yet @'zero' '*' x = x@.
--
-- = Default methods
--
-- A type with both a 'Num' instance and a "GHC.Num" instance may omit
-- every method:
--
-- * 'one' and 'zero' default to the literals @1@ and @0@;
-- * '+' and '*' default to the "GHC.Num" operators.
class Semiring a where
  -- | The identity of '*'.
  one  :: a
  -- | The identity of '+'.
  zero :: a
  infixl 7 *
  -- | An associative binary operation, which distributes over '+'.
  (*)  :: a -> a -> a
  -- The fixity declaration comes before the doc comment (mirroring '*'
  -- above), so that the Haddock comment attaches to the signature.
  infixl 6 +
  -- | An associative, commutative binary operation.
  (+)  :: a -> a -> a

  -- Under RebindableSyntax these literals use whichever fromInteger is in
  -- scope in this module (presumably the one from "Precursor.Numeric.Num";
  -- verify). That is why the constraint here is the unqualified 'Num',
  -- not 'P.Num'.
  default one :: Num a => a
  default zero :: Num a => a
  one = 1
  zero = 0

  -- The operator defaults delegate directly to base's arithmetic.
  default (+) :: P.Num a => a -> a -> a
  default (*) :: P.Num a => a -> a -> a
  (+) = (P.+)
  (*) = (P.*)

-- Integral types. Every method comes from the class defaults: the
-- literals @0@ and @1@, and the "GHC.Num" operators. GHC's fixed-width
-- types wrap on overflow, so they behave as arithmetic modulo @2^n@.
instance Semiring P.Integer
instance Semiring Int
instance Semiring Int8
instance Semiring Int16
instance Semiring Int32
instance Semiring Int64

-- Floating-point types, also built entirely from the defaults. These
-- satisfy the laws only approximately:
--
-- * IEEE rounding makes '+' and '*' only approximately associative and
--   distributive;
-- * @0 '*' x@ is NaN rather than 'zero' when @x@ is infinite or NaN.
instance Semiring Float
instance Semiring Double

-- | The two-element Boolean semiring:
--
-- * 'False' is 'zero' and 'True' is 'one';
-- * '+' is '||' (disjunction) and '*' is '&&' (conjunction).
instance Semiring Bool where
  zero = False
  one  = True
  p + q = p || q
  p * q = p && q

-- | Pointwise lifting: functions into a semiring form a semiring.
--
-- * 'zero' and 'one' are the constant functions.
-- * '+' and '*' combine the two results obtained at each argument.
instance Semiring b => Semiring (a -> b) where
  zero = \_ -> zero
  one  = \_ -> one
  f + g = \x -> f x + g x
  f * g = \x -> f x * g x