{-# LANGUAGE Safe                       #-}
{-# LANGUAGE TypeFamilies               #-}
module Data.Algebra.Dual where

import safe Control.Applicative
import safe Data.Algebra
import safe Data.Bool
import safe Data.Distributive
import safe Data.Functor.Classes
import safe Data.Functor.Compose
import safe Data.Functor.Rep
import safe Data.Semifield
import safe Data.Semigroup.Foldable as Foldable1
import safe Data.Semimodule
import safe Data.Semimodule.Basis
import safe Data.Semimodule.Transform
import safe Data.Semiring
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)

-- | Short alias for 'Dual'. It is deliberately left unapplied so it can stand
-- in for 'Dual' wherever a type constructor of kind @* -> *@ is expected.
type D = Dual

-- | A <https://en.wikipedia.org/wiki/Dual_number dual number>, stored as two
-- coefficients of the same type.
--
-- The first field is the one selected by 'D21' and the second the one
-- selected by 'D22' (see the 'Representable' instance). Both fields are lazy.
data Dual a = Dual a a
  deriving (Eq, Show)

-- | Lifted 'Show' for 'Dual'. The value is rendered with 'showsBinaryWith' as
-- the constructor name followed by both fields. Each field uses the supplied
-- element printer; the list printer is not needed.
instance Show1 Dual where
  liftShowsPrec showElem _ prec d = case d of
    Dual x y -> showsBinaryWith showElem showElem "Dual" prec x y

-- | 'Dual' is indexed by the two-element basis 'D2'. 'D21' selects the first
-- field and 'D22' selects the second.
instance Representable Dual where
  type Rep Dual = D2

  tabulate g = Dual (g D21) (g D22)

  index (Dual x y) i = case i of
    D21 -> x
    D22 -> y

-- | Distribution is inherited from the 'Representable' structure via the
-- generic 'distributeRep'.
instance Distributive Dual where
  {-# INLINE distribute #-}
  distribute = distributeRep

-- | Maps the function over both fields. The 'Dual' constructor is forced
-- before the result is built.
instance Functor Dual where
  fmap g d = case d of
    Dual x y -> Dual (g x) (g y)

-- | The applicative structure is taken from 'Representable'. 'pureRep' fills
-- both slots with the same value, and 'apRep' applies the functions index by
-- index.
instance Applicative Dual where
  (<*>) = apRep
  pure  = pureRep

-- | Folds over the first field, then the second.
instance Foldable Dual where
  foldMap g d = case d of
    Dual x y -> g x <> g y

-- | Runs the effect for the first field, then the second, and rebuilds a
-- 'Dual' from the results.
instance Traversable Dual where
  traverse g (Dual x y) = fmap Dual (g x) <*> g y

-- | A 'Dual' always holds two elements, so it can be folded with only a
-- 'Semigroup': first field, then second.
instance Foldable1 Dual where
  foldMap1 g d = case d of
    Dual x y -> g x <> g y

-- | Addition of 'Dual' values under the 'Additive' wrapper. The wrapped values
-- are combined with 'liftA2', and the components are zipped with @(+)@ via
-- 'mzipWithRep' (presumably index by index over 'D2'; confirm against its
-- definition).
instance (Additive-Semigroup) a => Semigroup (Additive (Dual a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The additive identity: 'zero' in both fields, wrapped in 'Additive'.
-- This is the same value as @pureRep zero@, since @pureRep = tabulate . const@.
instance (Additive-Monoid) a => Monoid (Additive (Dual a)) where
  mempty = pure (Dual zero zero)

-- | Subtraction of 'Dual' values under the 'Additive' wrapper. This mirrors
-- the 'Semigroup' instance, but zips the components with @(-)@ instead of
-- @(+)@.
instance (Additive-Group) a => Magma (Additive (Dual a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining additive-group structure for 'Dual' defines no methods of its
-- own. Any methods these classes have come from their defaults, which
-- presumably build on the 'Magma' instance above; confirm against the class
-- definitions.
instance (Additive-Group) a => Quasigroup (Additive (Dual a))
instance (Additive-Group) a => Loop (Additive (Dual a))
instance (Additive-Group) a => Group (Additive (Dual a))

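{-
NOTE(review): disabled. These more general instance heads would overlap with
the @Semiring a => LeftSemimodule a (Dual a)@ and
@Semiring a => RightSemimodule a (Dual a)@ instances that follow. That is
presumably why they are commented out; confirm before re-enabling.

instance LeftSemimodule l s => LeftSemimodule l (Dual s) where
  lscale l (Dual a b) = Dual (l *. a) (l *. b)

instance RightSemimodule r s => RightSemimodule r (Dual s) where
  rscale r (Dual a b) = Dual (a .* r) (b .* r)
-}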
-- | A 'Dual' over a semiring @a@ is a left semimodule over @a@. Left scaling is
-- delegated to the library's 'lscaleDef'.
instance Semiring a => LeftSemimodule a (Dual a) where
  {-# INLINE lscale #-}
  lscale = lscaleDef

-- | A 'Dual' over a semiring @a@ is a right semimodule over @a@. Right scaling
-- is delegated to the library's 'rscaleDef'.
instance Semiring a => RightSemimodule a (Dual a) where
  {-# INLINE rscale #-}
  rscale = rscaleDef

-- | Marks 'Dual' as a bisemimodule over @a@ on both sides. No methods are
-- defined here; presumably this ties together the left and right semimodule
-- instances above (confirm against the 'Bisemimodule' class).
instance Semiring a => Bisemimodule a a (Dual a)