{-# LANGUAGE CPP                        #-}
{-# LANGUAGE Safe                       #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE NoImplicitPrelude          #-}
{-# LANGUAGE RebindableSyntax           #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE RankNTypes               #-}

module Data.Semimodule.Basis where

import safe Control.Applicative
import safe Control.Category (Category, (>>>))
import safe Data.Algebra
import safe Data.Bool
import safe Data.Distributive
import safe Data.Foldable as Foldable (fold, foldl')
import safe Data.Functor.Rep
import safe Data.Group
import safe Data.Magma
import safe Data.Profunctor
import safe Data.Semifield
import safe Data.Semigroup.Foldable as Foldable1
import safe Data.Semimodule
import safe Data.Semiring
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)
import safe qualified Prelude as P

-- | Constraint synonym: @f@ is 'Free' and its representing index type
-- ('Rep') is exactly @b@.
type Basis b f = (Rep f ~ b, Free f)

-------------------------------------------------------------------------------
-- V2
-------------------------------------------------------------------------------

-- | A two-component vector with strict fields.
data V2 a
  = V2 !a !a
  deriving (Eq, Ord, Show)

-- | Component-wise vector addition.
--
-- >>> V2 1 2 <> V2 3 4
-- V2 4 6
--
instance (Additive-Semigroup) a => Semigroup (V2 a) where
  u <> v = mzipWithRep (+) u v

-- | Component-wise addition of 'V2' values under the 'Additive' wrapper
-- (matrix addition).
--
-- >>> m23 1 2 3 4 5 6 <> m23 7 8 9 1 2 3 :: M23 Int
-- V2 (V3 8 10 12) (V3 5 7 9)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V2 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The zero vector.
instance (Additive-Monoid) a => Monoid (V2 a) where
  mempty = V2 zero zero

-- | The zero vector, wrapped in 'Additive'.
instance (Additive-Monoid) a => Monoid (Additive (V2 a)) where
  mempty = pure (V2 zero zero)

-- | Component-wise vector subtraction.
--
-- >>> V2 1 2 << V2 3 4
-- V2 (-2) (-2)
--
instance (Additive-Group) a => Magma (V2 a) where
  u << v = mzipWithRep (-) u v

-- | Component-wise subtraction under the 'Additive' wrapper
-- (matrix subtraction).
--
-- >>> m23 1 2 3 4 5 6 << m23 7 8 9 1 2 3 :: M23 Int
-- V2 (V3 (-6) (-6) (-6)) (V3 3 3 3)
--
instance (Additive-Group) a => Magma (Additive (V2 a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining group-like classes have empty bodies and rely entirely
-- on their default methods.
instance (Additive-Group) a => Quasigroup (V2 a)
instance (Additive-Group) a => Loop (V2 a)
instance (Additive-Group) a => Group (V2 a)

instance (Additive-Group) a => Quasigroup (Additive (V2 a))
instance (Additive-Group) a => Loop (Additive (V2 a))
instance (Additive-Group) a => Group (Additive (V2 a))

-- | Scalar multiplication: each component is multiplied on the left by
-- the scalar.
instance Semiring a => Semimodule a (V2 a) where
  s *. v = fmap (s *) v
  {-# INLINE (*.) #-}

instance Functor V2 where
  fmap g (V2 x y) = V2 (g x) (g y)
  {-# INLINE fmap #-}
  (<$) x _ = V2 x x
  {-# INLINE (<$) #-}

-- | Folds visit the components left to right. A 'V2' is never empty and
-- always holds 'two' elements.
instance Foldable V2 where
  foldMap g (V2 x y) = g x <> g y
  {-# INLINE foldMap #-}
  null = const False
  length = const two

instance Foldable1 V2 where
  foldMap1 g (V2 x y) = g x <> g y
  {-# INLINE foldMap1 #-}

-- | Distribute a functor of vectors into a vector of functors by
-- projecting each component.
instance Distributive V2 where
  distribute wv = V2 (fmap px wv) (fmap py wv)
    where
      px (V2 x _) = x
      py (V2 _ y) = y
  {-# INLINE distribute #-}

-- | 'V2' is indexed by the basis 'E2': 'E21' selects the first
-- component and 'E22' the second.
instance Representable V2 where
  type Rep V2 = E2

  tabulate g = V2 (g E21) (g E22)
  {-# INLINE tabulate #-}

  index (V2 x y) i = case i of
    E21 -> x
    E22 -> y
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on two real dimensions
-------------------------------------------------------------------------------

-- | Indices of the standard basis of a two-dimensional space.
data E2
  = E21
  | E22
  deriving (Eq, Ord, Show)

-- | Build a vector over the 'E2' basis from its two coordinates, given in
-- 'E21', 'E22' order.
e2 :: Basis E2 f => a -> a -> f a
e2 x y = tabulate $ \i -> case i of
  E21 -> x
  E22 -> y

-- | Addition of basis indices modulo 2, with 'E21' as zero and 'E22' as one.
instance Semigroup (Additive E2) where
  Additive E21 <> y            = y
  Additive E22 <> Additive E21 = Additive E22
  Additive E22 <> Additive E22 = Additive E21

-- | 'E21' is the additive identity.
instance Monoid (Additive E2) where
  mempty = pure E21

-- | Trivial diagonal algebra on 'E2': the coefficient at each basis index
-- comes only from that index multiplied with itself.
instance Semiring r => Algebra r E2 where
  multiplyWith f =
    let d1 = f E21 E21
        d2 = f E22 E22
    in  \i -> case i of
          E21 -> d1
          E22 -> d2

-- | 'E2' is its own conjugate. The norm evaluates the diagonal product at
-- each index and sums, giving @f E21 * f E21 + f E22 * f E22@.
instance Semiring r => Composition r E2 where
  conjugateWith = id

  normWith f =
    multiplyWith (\a b ->
      multiplyWith (\c d -> f a * f b + f c * f d) E22) E21

-------------------------------------------------------------------------------
-- V3
-------------------------------------------------------------------------------


-- | A three-component vector with strict fields.
data V3 a
  = V3 !a !a !a
  deriving (Eq, Ord, Show)

-- | Component-wise vector addition.
--
-- >>> V3 1 2 3 <> V3 4 5 6
-- V3 5 7 9
--
instance (Additive-Semigroup) a => Semigroup (V3 a) where
  u <> v = mzipWithRep (+) u v

-- | Component-wise addition of 'V3' values under the 'Additive' wrapper
-- (matrix addition).
--
-- >>> V2 (V3 1 2 3) (V3 4 5 6) <> V2 (V3 7 8 9) (V3 1 2 3)
-- V2 (V3 8 10 12) (V3 5 7 9)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V3 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The zero vector.
instance (Additive-Monoid) a => Monoid (V3 a) where
  mempty = V3 zero zero zero

-- | The zero vector, wrapped in 'Additive'.
instance (Additive-Monoid) a => Monoid (Additive (V3 a)) where
  mempty = pure (V3 zero zero zero)

-- | Component-wise vector subtraction.
--
-- >>> V3 1 2 3 << V3 4 5 6
-- V3 (-3) (-3) (-3)
--
instance (Additive-Group) a => Magma (V3 a) where
  u << v = mzipWithRep (-) u v

-- | Component-wise subtraction under the 'Additive' wrapper
-- (matrix subtraction).
--
-- >>> V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9) << V3 (V3 7 8 9) (V3 7 8 9) (V3 7 8 9)
-- V3 (V3 (-6) (-6) (-6)) (V3 (-3) (-3) (-3)) (V3 0 0 0)
--
instance (Additive-Group) a => Magma (Additive (V3 a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining group-like classes have empty bodies and rely entirely
-- on their default methods.
instance (Additive-Group) a => Quasigroup (V3 a)
instance (Additive-Group) a => Loop (V3 a)
instance (Additive-Group) a => Group (V3 a)

instance (Additive-Group) a => Quasigroup (Additive (V3 a))
instance (Additive-Group) a => Loop (Additive (V3 a))
instance (Additive-Group) a => Group (Additive (V3 a))

-- | Scalar multiplication: each component is multiplied on the left by
-- the scalar.
instance Semiring a => Semimodule a (V3 a) where
  a *. f = (a *) <$> f
  {-# INLINE (*.) #-}

instance Functor V3 where
  fmap f (V3 a b c) = V3 (f a) (f b) (f c)
  {-# INLINE fmap #-}
  a <$ _ = V3 a a a
  {-# INLINE (<$) #-}

-- | Folds visit the components left to right. As with the 'V2' and 'V4'
-- instances, 'null' and 'length' are constant.
instance Foldable V3 where
  foldMap f (V3 a b c) = f a <> f b <> f c
  {-# INLINE foldMap #-}
  null _ = False
  -- Built from 'one' and 'two' instead of the literal 3: under
  -- RebindableSyntax a literal desugars to whatever 'fromInteger' is in
  -- scope, and Prelude's is hidden in this module.
  length _ = one + two

instance Foldable1 V3 where
  foldMap1 f (V3 a b c) = f a <> f b <> f c
  {-# INLINE foldMap1 #-}

-- | Distribute a functor of vectors into a vector of functors by
-- projecting each component.
instance Distributive V3 where
  distribute wv = V3 (fmap px wv) (fmap py wv) (fmap pz wv)
    where
      px (V3 x _ _) = x
      py (V3 _ y _) = y
      pz (V3 _ _ z) = z
  {-# INLINE distribute #-}

-- | 'V3' is indexed by the basis 'E3': 'E31', 'E32' and 'E33' select the
-- first, second and third components.
instance Representable V3 where
  type Rep V3 = E3

  tabulate g = V3 (g E31) (g E32) (g E33)
  {-# INLINE tabulate #-}

  index (V3 x y z) i = case i of
    E31 -> x
    E32 -> y
    E33 -> z
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on three real dimensions 
-------------------------------------------------------------------------------

-- | Indices of the standard basis of a three-dimensional space.
data E3
  = E31
  | E32
  | E33
  deriving (Eq, Ord, Show)

-- | Build a vector over the 'E3' basis from its three coordinates, given
-- in 'E31', 'E32', 'E33' order.
e3 :: Basis E3 f => a -> a -> a -> f a
e3 x y z = tabulate $ \i -> case i of
  E31 -> x
  E32 -> y
  E33 -> z

-- | Addition of basis indices modulo 3, reading 'E31' as 0, 'E32' as 1 and
-- 'E33' as 2. The left operand is inspected first.
instance Semigroup (Additive E3) where
  Additive E31 <> y            = y

  Additive E32 <> Additive E31 = Additive E32
  Additive E32 <> Additive E32 = Additive E33
  Additive E32 <> Additive E33 = Additive E31

  Additive E33 <> Additive E31 = Additive E33
  Additive E33 <> Additive E32 = Additive E31
  Additive E33 <> Additive E33 = Additive E32

-- | 'E31' is the additive identity.
instance Monoid (Additive E3) where
  mempty = pure E31

-- | Cross-product algebra on 'E3'. The coefficient at each index is the
-- antisymmetric combination of the other two indices, e.g. the 'E31'
-- coefficient is @f E32 E33 - f E33 E32@.
instance Ring r => Algebra r E3 where
  multiplyWith f =
    let wedge a b = f a b - f b a
        c1 = wedge E32 E33
        c2 = wedge E33 E31
        c3 = wedge E31 E32
    in  \i -> case i of
          E31 -> c1
          E32 -> c2
          E33 -> c3

-- | 'E3' is its own conjugate. The norm is the sum of the squared
-- coefficients, @f E31 * f E31 + f E32 * f E32 + f E33 * f E33@.
instance Ring r => Composition r E3 where
  conjugateWith = id

  normWith f = sq E31 + sq E32 + sq E33
    where
      sq b = f b * f b


-------------------------------------------------------------------------------
-- QuaternionBasis
-------------------------------------------------------------------------------

-- | Basis of the quaternion algebra. 'Nothing' indexes the real unit, and
-- @'Just' 'E31'@, @'Just' 'E32'@, @'Just' 'E33'@ index the imaginary units
-- i, j and k.
type QuaternionBasis = Maybe E3

-- | Hamilton's quaternion product, written as the coefficient at each basis
-- element in terms of the pairwise products @f x y@.
instance Ring r => Algebra r QuaternionBasis where
  multiplyWith f = maybe realCoeff imagCoeff
    where
      u  = Nothing
      qi = Just E31
      qj = Just E32
      qk = Just E33
      -- 1*1 = 1 and i*i = j*j = k*k = -1
      realCoeff = f u u - (f qi qi + f qj qj + f qk qk)
      -- the real unit commutes with everything; jk = i, ki = j, ij = k,
      -- and each reversed product contributes with the opposite sign
      ci = f u qi + f qi u + (f qj qk - f qk qj)
      cj = f u qj + f qj u + (f qk qi - f qi qk)
      ck = f u qk + f qk u + (f qi qj - f qj qi)
      imagCoeff E31 = ci
      imagCoeff E32 = cj
      imagCoeff E33 = ck

-- | The unit is the real basis element 'Nothing': its coefficient is @x@,
-- and every imaginary coefficient is 'zero'.
instance Ring r => Unital r QuaternionBasis where
  unitWith x = maybe x (const zero)

-- | Quaternion conjugation and norm.
--
-- 'conjugateWith' keeps the real ('Nothing') coefficient and negates the
-- three imaginary ones. 'normWith' is the real coefficient of
-- @q * conjugate q@, i.e. the sum of the squared coefficients.
instance Ring r => Composition r QuaternionBasis where
  conjugateWith f = maybe re im
    where
      re = f Nothing
      im E31 = negate . f $ Just E31
      im E32 = negate . f $ Just E32
      im E33 = negate . f $ Just E33

  -- Evaluated at 'Nothing', the real unit (cf. 'unitWith'). The index is
  -- named explicitly instead of being spelled as an additive 'zero'.
  normWith f = multiplyWith (\x y -> f x * conjugateWith f y) Nothing

-- | Quaternion inverse: the conjugate divided by 'normWith'.
--
-- The norm depends only on @f@, so it is bound once per @reciprocalWith f@
-- and shared across all basis indices instead of being recomputed for each.
instance Field r => Division r QuaternionBasis where
  reciprocalWith f = \b -> conjugateWith f b / n
    where
      n = normWith f
{-
reciprocal'' x = divq unit x

divq (Quaternion r0 (V3 r1 r2 r3)) (Quaternion q0 (V3 q1 q2 q3)) =
 (/denom) <$> Quaternion (r0*q0 + r1*q1 + r2*q2 + r3*q3) imag
  where denom = q0*q0 + q1*q1 + q2*q2 + q3*q3
        imag = (V3 (r0*q1 + (negate r1*q0) + (negate r2*q3) + r3*q2)
                   (r0*q2 + r1*q3 + (negate r2*q0) + (negate r3*q1))
                   (r0*q3 + (negate r1*q2) + r2*q1 + (negate r3*q0)))

-}

-------------------------------------------------------------------------------
-- V4
-------------------------------------------------------------------------------

-- | A four-component vector with strict fields.
data V4 a
  = V4 !a !a !a !a
  deriving (Eq, Ord, Show)

-- | Component-wise vector addition.
--
-- >>> V4 1 2 3 4 <> V4 5 6 7 8
-- V4 6 8 10 12
--
instance (Additive-Semigroup) a => Semigroup (V4 a) where
  u <> v = mzipWithRep (+) u v

-- | Component-wise addition of 'V4' values under the 'Additive' wrapper
-- (matrix addition).
--
-- >>> m24 1 2 3 4 5 6 7 8 <> m24 1 2 3 4 5 6 7 8 :: M24 Int
-- V2 (V4 2 4 6 8) (V4 10 12 14 16)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V4 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The zero vector.
instance (Additive-Monoid) a => Monoid (V4 a) where
  mempty = V4 zero zero zero zero

-- | The zero vector, wrapped in 'Additive'.
instance (Additive-Monoid) a => Monoid (Additive (V4 a)) where
  mempty = pure (V4 zero zero zero zero)

-- | Component-wise vector subtraction.
--
-- >>> V4 1 2 3 4 << V4 5 6 7 8
-- V4 (-4) (-4) (-4) (-4)
--
instance (Additive-Group) a => Magma (V4 a) where
  (<<) = mzipWithRep (-)

-- | Component-wise subtraction under the 'Additive' wrapper
-- (matrix subtraction).
--
-- >>> V2 (V4 1 2 3 4) (V4 5 6 7 8) << V2 (V4 8 7 6 5) (V4 4 3 2 1)
-- V2 (V4 (-7) (-5) (-3) (-1)) (V4 1 3 5 7)
--
instance (Additive-Group) a => Magma (Additive (V4 a)) where
  (<<) = liftA2 $ mzipWithRep (-)

-- The remaining group-like classes have empty bodies and rely entirely
-- on their default methods.
instance (Additive-Group) a => Quasigroup (V4 a)
instance (Additive-Group) a => Quasigroup (Additive (V4 a))
instance (Additive-Group) a => Loop (V4 a)
instance (Additive-Group) a => Loop (Additive (V4 a))
instance (Additive-Group) a => Group (V4 a)
instance (Additive-Group) a => Group (Additive (V4 a))

-- | Scalar multiplication: each component is multiplied on the left by
-- the scalar.
instance Semiring a => Semimodule a (V4 a) where
  s *. v = fmap (s *) v
  {-# INLINE (*.) #-}

instance Functor V4 where
  fmap g (V4 x y z w) = V4 (g x) (g y) (g z) (g w)
  {-# INLINE fmap #-}
  (<$) x _ = V4 x x x x
  {-# INLINE (<$) #-}

-- | Folds visit the components left to right. A 'V4' is never empty and
-- always holds @two + two@ elements.
instance Foldable V4 where
  foldMap g (V4 x y z w) = g x <> g y <> g z <> g w
  {-# INLINE foldMap #-}
  null = const False
  length = const (two + two)

instance Foldable1 V4 where
  foldMap1 g (V4 x y z w) = g x <> g y <> g z <> g w
  {-# INLINE foldMap1 #-}

-- | Distribute a functor of vectors into a vector of functors by
-- projecting each component.
instance Distributive V4 where
  distribute wv = V4 (fmap px wv) (fmap py wv) (fmap pz wv) (fmap pw wv)
    where
      px (V4 x _ _ _) = x
      py (V4 _ y _ _) = y
      pz (V4 _ _ z _) = z
      pw (V4 _ _ _ w) = w
  {-# INLINE distribute #-}

-- | 'V4' is indexed by the basis 'E4', one constructor per component in
-- order.
instance Representable V4 where
  type Rep V4 = E4

  tabulate g = V4 (g E41) (g E42) (g E43) (g E44)
  {-# INLINE tabulate #-}

  index (V4 x y z w) i = case i of
    E41 -> x
    E42 -> y
    E43 -> z
    E44 -> w
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on four real dimensions
-------------------------------------------------------------------------------

-- | Indices of the standard basis of a four-dimensional space.
data E4
  = E41
  | E42
  | E43
  | E44
  deriving (Eq, Ord, Show)

-- | Build a vector over the 'E4' basis from its four coordinates, given in
-- 'E41' .. 'E44' order.
e4 :: Basis E4 f => a -> a -> a -> a -> f a
e4 x y z w = tabulate $ \i -> case i of
  E41 -> x
  E42 -> y
  E43 -> z
  E44 -> w


-- | A four-component vector with strict fields, indexed by @Either E2 E2@
-- rather than 'E4'.
data VFour a
  = VFour !a !a !a !a
  deriving (Eq, Ord, Show)

-- | Component-wise vector addition.
--
-- >>> VFour 1 2 3 4 <> VFour 5 6 7 8
-- VFour 6 8 10 12
--
instance (Additive-Semigroup) a => Semigroup (VFour a) where
  (<>) = mzipWithRep (+)

-- | Component-wise addition of 'VFour' values under the 'Additive' wrapper
-- (matrix addition).
--
-- >>> V2 (VFour 1 2 3 4) (VFour 5 6 7 8) <> V2 (VFour 1 2 3 4) (VFour 5 6 7 8)
-- V2 (VFour 2 4 6 8) (VFour 10 12 14 16)
--
instance (Additive-Semigroup) a => Semigroup (Additive (VFour a)) where
  (<>) = liftA2 $ mzipWithRep (+)

-- | The zero vector.
instance (Additive-Monoid) a => Monoid (VFour a) where
  mempty = pureRep zero

-- | The zero vector, wrapped in 'Additive'.
instance (Additive-Monoid) a => Monoid (Additive (VFour a)) where
  mempty = pure $ pureRep zero

-- | Component-wise vector subtraction.
--
-- >>> VFour 1 2 3 4 << VFour 5 6 7 8
-- VFour (-4) (-4) (-4) (-4)
--
instance (Additive-Group) a => Magma (VFour a) where
  (<<) = mzipWithRep (-)

-- | Component-wise subtraction under the 'Additive' wrapper
-- (matrix subtraction).
--
-- >>> V2 (VFour 1 2 3 4) (VFour 5 6 7 8) << V2 (VFour 8 7 6 5) (VFour 4 3 2 1)
-- V2 (VFour (-7) (-5) (-3) (-1)) (VFour 1 3 5 7)
--
instance (Additive-Group) a => Magma (Additive (VFour a)) where
  (<<) = liftA2 $ mzipWithRep (-)

-- The remaining group-like classes have empty bodies and rely entirely
-- on their default methods.
instance (Additive-Group) a => Quasigroup (VFour a)
instance (Additive-Group) a => Quasigroup (Additive (VFour a))
instance (Additive-Group) a => Loop (VFour a)
instance (Additive-Group) a => Loop (Additive (VFour a))
instance (Additive-Group) a => Group (VFour a)
instance (Additive-Group) a => Group (Additive (VFour a))

-- | Scalar multiplication: each component is multiplied on the left by
-- the scalar.
instance Semiring a => Semimodule a (VFour a) where
  s *. v = fmap (s *) v
  {-# INLINE (*.) #-}

instance Functor VFour where
  fmap g (VFour x y z w) = VFour (g x) (g y) (g z) (g w)
  {-# INLINE fmap #-}
  (<$) x _ = VFour x x x x
  {-# INLINE (<$) #-}

-- | Folds visit the components left to right. A 'VFour' is never empty
-- and always holds @two + two@ elements.
instance Foldable VFour where
  foldMap g (VFour x y z w) = g x <> g y <> g z <> g w
  {-# INLINE foldMap #-}
  null = const False
  length = const (two + two)

instance Foldable1 VFour where
  foldMap1 g (VFour x y z w) = g x <> g y <> g z <> g w
  {-# INLINE foldMap1 #-}

-- | Distribute a functor of vectors into a vector of functors by
-- projecting each component.
instance Distributive VFour where
  distribute wv = VFour (fmap px wv) (fmap py wv) (fmap pz wv) (fmap pw wv)
    where
      px (VFour x _ _ _) = x
      py (VFour _ y _ _) = y
      pz (VFour _ _ z _) = z
      pw (VFour _ _ _ w) = w
  {-# INLINE distribute #-}

-- | 'VFour' is indexed by @Either E2 E2@. The 'Left' indices select the
-- first two components and the 'Right' indices the last two.
instance Representable VFour where
  type Rep VFour = Either E2 E2

  tabulate g = VFour (g (Left E21)) (g (Left E22)) (g (Right E21)) (g (Right E22))
  {-# INLINE tabulate #-}

  index (VFour x y z w) i = case i of
    Left E21  -> x
    Left E22  -> y
    Right E21 -> z
    Right E22 -> w
  {-# INLINE index #-}