{-# LANGUAGE CPP                        #-}
{-# LANGUAGE Safe                       #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE NoImplicitPrelude          #-}
{-# LANGUAGE RebindableSyntax           #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE RankNTypes               #-}

module Data.Semimodule.Vector where
{- (
    type Basis
  , (*.)
  , (.*)
  , (`dot`)
  , (.*.)
  , triple
  , lerp
  , quadrance
  , qd
  , dirac
  , module Data.Semimodule.Vector
) where
-}
import safe Control.Applicative
import safe Data.Algebra
import safe Data.Bool
import safe Data.Distributive
import safe Data.Functor.Compose
import safe Data.Functor.Rep
import safe Data.Semifield
import safe Data.Semigroup.Foldable as Foldable1
import safe Data.Semimodule
import safe Data.Semimodule.Transform
import safe Data.Semiring
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)

import safe qualified Control.Category as C
import safe qualified Control.Monad as M



-------------------------------------------------------------------------------
-- V1
-------------------------------------------------------------------------------

-- | A one-dimensional vector: a newtype around its single component.
newtype V1 a = V1 a deriving (Eq,Ord,Show)

-- | Extract the single component of a 'V1'.
unV1 :: V1 a -> a
unV1 v = case v of
  V1 a -> a

-- | Vector addition: add the single components.
--
-- >>> V1 1 <> V1 3
-- V1 4
--
instance (Additive-Semigroup) a => Semigroup (V1 a) where
  V1 x <> V1 y = V1 (x + y)

-- | Vector & matrix addition, applied componentwise.
--
-- >>> V1 1 + V1 3 :: V1 Int
-- V1 4
-- >>> m13 1 2 3 + m13 7 8 9 :: M13 Int
-- V1 (V3 8 10 12)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V1 a)) where
  (<>) = liftA2 $ mzipWithRep (+)

-- | The zero vector, built from the component's additive identity.
instance (Additive-Monoid) a => Monoid (V1 a) where
  mempty = V1 zero

-- | The additive identity: the zero vector wrapped in 'Additive'.
instance (Additive-Monoid) a => Monoid (Additive (V1 a)) where
  mempty = pure (V1 zero)


-- | Vector & matrix subtraction, applied componentwise.
--
-- >>> V1 1 - V1 2 :: V1 Int
-- V1 (-1)
-- >>> m12 1 2 - m12 3 4 :: M12 Int
-- V1 (V2 (-2) (-2))
--
instance (Additive-Group) a => Magma (Additive (V1 a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining group structure relies entirely on the classes' default methods.
instance (Additive-Group) a => Quasigroup (Additive (V1 a))
instance (Additive-Group) a => Loop (Additive (V1 a))
instance (Additive-Group) a => Group (Additive (V1 a))

-- | Left scalar multiplication, delegating to 'lscaleDef'.
instance Semiring a => LeftSemimodule a (V1 a) where
  {-# INLINE lscale #-}
  lscale s = lscaleDef s

-- | Right scalar multiplication, delegating to 'rscaleDef'.
instance Semiring a => RightSemimodule a (V1 a) where
  {-# INLINE rscale #-}
  rscale s = rscaleDef s

-- | Bisemimodule over the component semiring; no methods are overridden.
instance Semiring a => Bisemimodule a a (V1 a)

-- | Map a function over the single component.
instance Functor V1 where
  fmap f v = case v of
    V1 a -> V1 (f a)
  {-# INLINE fmap #-}
  x <$ _ = V1 x
  {-# INLINE (<$) #-}

-- | Componentwise applicative, obtained from the 'Representable' instance
-- through 'pureRep' and 'liftR2'.
instance Applicative V1 where
  pure x = pureRep x
  liftA2 f u v = liftR2 f u v

-- | A 'V1' always holds exactly one element.
instance Foldable V1 where
  foldMap f v = case v of
    V1 a -> f a
  {-# INLINE foldMap #-}
  null = const False
  length = const one

-- | Non-empty fold over the single element.
instance Foldable1 V1 where
  foldMap1 f v = case v of
    V1 a -> f a
  {-# INLINE foldMap1 #-}

-- | Distribute by projecting the component out of every inner 'V1'.
instance Distributive V1 where
  distribute f = V1 (fmap unV1 f)
  {-# INLINE distribute #-}

-- | 'V1' is indexed by the single-element basis 'I1'.
instance Representable V1 where
  type Rep V1 = I1

  tabulate f = V1 (f I1)
  {-# INLINE tabulate #-}

  index v I1 = unV1 v
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on one real dimension
-------------------------------------------------------------------------------

-- | The sole basis element of one-dimensional space.
data I1
  = I1
  deriving (Eq, Ord, Show)

-- | Coordinate function for the 'I1' basis: the only index maps to @x@.
i1 :: a -> I1 -> a
i1 x _ = x

-- | Build an 'I1'-indexed container whose only entry is the given value.
fillI1 :: Basis I1 f => a -> f a
fillI1 = tabulate . i1

-- | Addition on the one-point basis: every sum is 'I1'.
instance Semigroup (Additive I1) where
  (<>) _ _ = Additive I1

-- | The identity is the only element, 'I1'.
instance Monoid (Additive I1) where
  mempty = pure I1

-- | The squaring function /N(x) = x^2/ on the real number field forms the
-- primordial composition algebra.
instance Semiring r => Algebra r I1 where
  -- 'M.join' in the function monad passes the index twice: mult f i = f i i.
  mult = M.join

-- | Conjugation is the identity; the norm is @f I1 * f I1@.
instance Semiring r => Composition r I1 where
  conj = C.id

  norm f = mult (\i j -> f i * f j) I1

-------------------------------------------------------------------------------
-- V2
-------------------------------------------------------------------------------

-- | A two-dimensional vector with strict components.
data V2 a = V2 !a !a deriving (Eq,Ord,Show)

-- | Vector & matrix addition, applied componentwise.
--
-- >>> V2 1 2 + V2 3 4
-- V2 4 6
-- >>> m23 1 2 3 4 5 6 + m23 7 8 9 1 2 3 :: M23 Int
-- V2 (V3 8 10 12) (V3 5 7 9)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V2 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The additive identity: the zero vector.
instance (Additive-Monoid) a => Monoid (Additive (V2 a)) where
  mempty = pure (V2 zero zero)

-- | Vector & matrix subtraction, applied componentwise.
--
-- >>> V2 1 2 - V2 3 4
-- V2 (-2) (-2)
--
-- >>> m23 1 2 3 4 5 6 - m23 7 8 9 1 2 3 :: M23 Int
-- V2 (V3 (-6) (-6) (-6)) (V3 3 3 3)
--
instance (Additive-Group) a => Magma (Additive (V2 a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining group structure relies entirely on the classes' default methods.
instance (Additive-Group) a => Quasigroup (Additive (V2 a))
instance (Additive-Group) a => Loop (Additive (V2 a))
instance (Additive-Group) a => Group (Additive (V2 a))


-- | Left scalar multiplication, delegating to 'lscaleDef'.
-- For example, @3 *. V2 1 2 :: V2 Int@ gives @V2 3 6@.
instance Semiring a => LeftSemimodule a (V2 a) where
  {-# INLINE lscale #-}
  lscale s = lscaleDef s

-- | Right scalar multiplication, delegating to 'rscaleDef'.
instance Semiring a => RightSemimodule a (V2 a) where
  {-# INLINE rscale #-}
  rscale s = rscaleDef s

-- | Bisemimodule over the component semiring; no methods are overridden.
instance Semiring a => Bisemimodule a a (V2 a)

-- | Map a function over both components.
instance Functor V2 where
  fmap f v = case v of
    V2 a b -> V2 (f a) (f b)
  {-# INLINE fmap #-}
  x <$ _ = V2 x x
  {-# INLINE (<$) #-}

-- | Componentwise applicative, obtained from the 'Representable' instance
-- through 'pureRep' and 'liftR2'.
instance Applicative V2 where
  pure x = pureRep x
  liftA2 f u v = liftR2 f u v

-- | Fold over the two components, first then second.
instance Foldable V2 where
  foldMap f v = case v of
    V2 a b -> f a <> f b
  {-# INLINE foldMap #-}
  null = const False
  length = const two

-- | Non-empty fold over the two components.
instance Foldable1 V2 where
  foldMap1 f v = case v of
    V2 a b -> f a <> f b
  {-# INLINE foldMap1 #-}

-- | Distribute by projecting each coordinate out of every inner 'V2'.
instance Distributive V2 where
  distribute f = V2 (fmap px f) (fmap py f) where
    px (V2 x _) = x
    py (V2 _ y) = y
  {-# INLINE distribute #-}

-- | 'V2' is indexed by the basis 'I2'.
instance Representable V2 where
  type Rep V2 = I2

  tabulate f = V2 (f I21) (f I22)
  {-# INLINE tabulate #-}

  index (V2 x y) i = case i of
    I21 -> x
    I22 -> y
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on two real dimensions
-------------------------------------------------------------------------------

-- | The two standard basis indices of two-dimensional space.
data I2
  = I21
  | I22
  deriving (Eq, Ord, Show)

-- | Coordinate function for the 'I2' basis: @x@ at 'I21', @y@ at 'I22'.
i2 :: a -> a -> I2 -> a
i2 x y i = case i of
  I21 -> x
  I22 -> y

-- | Build an 'I2'-indexed container with entries @x@ and @y@.
fillI2 :: Basis I2 f => a -> a -> f a
fillI2 x = tabulate . i2 x

-- | Addition modulo 2 on the basis indices, with 'I21' as zero.
instance Semigroup (Additive I2) where
  Additive a <> Additive b = case (a, b) of
    (I21, _)   -> Additive b
    (_, I21)   -> Additive a
    (I22, I22) -> Additive I21

-- | 'I21' is the additive identity.
instance Monoid (Additive I2) where
  mempty = pure I21

-- | Trivial diagonal algebra: the output coordinate at each index @i@ is
-- @f i i@. The two diagonal values are shared across lookups.
instance Semiring r => Algebra r I2 where
  mult f = diag where
    d1 = f I21 I21
    d2 = f I22 I22

    diag i = case i of
      I21 -> d1
      I22 -> d2

-- | Identity conjugation; the norm is the sum of squares of both coordinates,
-- with each pair of indices obtained by reading the diagonal through 'mult'.
instance Semiring r => Composition r I2 where
  conj = C.id

  norm f = mult outer I21 where
    outer i i' = mult (inner i i') I22
    inner i i' j j' = f i * f i' + f j * f j'

-------------------------------------------------------------------------------
-- V3
-------------------------------------------------------------------------------


-- | A three-dimensional vector with strict components.
data V3 a = V3 !a !a !a deriving (Eq,Ord,Show)

-- | Vector & matrix addition, applied componentwise.
--
-- >>> V3 1 2 3 + V3 4 5 6
-- V3 5 7 9
-- >>> V2 (V3 1 2 3) (V3 4 5 6) + V2 (V3 7 8 9) (V3 1 2 3)
-- V2 (V3 8 10 12) (V3 5 7 9)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V3 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The additive identity: the zero vector.
instance (Additive-Monoid) a => Monoid (Additive (V3 a)) where
  mempty = pure (V3 zero zero zero)

-- | Vector & matrix subtraction, applied componentwise.
--
-- >>> V3 1 2 3 - V3 4 5 6
-- V3 (-3) (-3) (-3)
-- >>> V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9) - V3 (V3 7 8 9) (V3 7 8 9) (V3 7 8 9)
-- V3 (V3 (-6) (-6) (-6)) (V3 (-3) (-3) (-3)) (V3 0 0 0)
--
instance (Additive-Group) a => Magma (Additive (V3 a)) where
  x << y = liftA2 (mzipWithRep (-)) x y

-- The remaining group structure relies entirely on the classes' default methods.
instance (Additive-Group) a => Quasigroup (Additive (V3 a))
instance (Additive-Group) a => Loop (Additive (V3 a))
instance (Additive-Group) a => Group (Additive (V3 a))

-- | Left scalar multiplication, delegating to 'lscaleDef'.
instance Semiring a => LeftSemimodule a (V3 a) where
  {-# INLINE lscale #-}
  lscale s = lscaleDef s

-- | Right scalar multiplication, delegating to 'rscaleDef'.
instance Semiring a => RightSemimodule a (V3 a) where
  {-# INLINE rscale #-}
  rscale s = rscaleDef s

-- | Bisemimodule over the component semiring; no methods are overridden.
instance Semiring a => Bisemimodule a a (V3 a)

-- | Map a function over all three components.
instance Functor V3 where
  fmap f v = case v of
    V3 a b c -> V3 (f a) (f b) (f c)
  {-# INLINE fmap #-}
  x <$ _ = V3 x x x
  {-# INLINE (<$) #-}

-- | Componentwise applicative, obtained from the 'Representable' instance
-- through 'pureRep' and 'liftR2'.
instance Applicative V3 where
  pure x = pureRep x
  liftA2 f u v = liftR2 f u v

-- | Fold over the three components in order.
instance Foldable V3 where
  foldMap f (V3 a b c) = f a <> f b <> f c
  {-# INLINE foldMap #-}
  null _ = False
  -- Constant-time length, matching the V1, V2 and V4 instances. It is spelled
  -- with the semiring constants 'one' and 'two' rather than a numeric literal
  -- because RebindableSyntax is enabled and Prelude's 'Num' is hidden here.
  length _ = one + two

-- | Non-empty fold over the three components.
instance Foldable1 V3 where
  foldMap1 f (V3 a b c) = f a <> f b <> f c
  {-# INLINE foldMap1 #-}

-- | Distribute by projecting each coordinate out of every inner 'V3'.
instance Distributive V3 where
  distribute f = V3 (fmap px f) (fmap py f) (fmap pz f) where
    px (V3 x _ _) = x
    py (V3 _ y _) = y
    pz (V3 _ _ z) = z
  {-# INLINE distribute #-}

-- | 'V3' is indexed by the basis 'I3'.
instance Representable V3 where
  type Rep V3 = I3

  tabulate f = V3 (f I31) (f I32) (f I33)
  {-# INLINE tabulate #-}

  index (V3 x y z) i = case i of
    I31 -> x
    I32 -> y
    I33 -> z
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on three real dimensions 
-------------------------------------------------------------------------------

-- | The three standard basis indices of three-dimensional space.
data I3
  = I31
  | I32
  | I33
  deriving (Eq, Ord, Show)

-- | Coordinate function for the 'I3' basis: @x@, @y@, @z@ at 'I31', 'I32', 'I33'.
i3 :: a -> a -> a -> I3 -> a
i3 x y z i = case i of
  I31 -> x
  I32 -> y
  I33 -> z

-- | Build an 'I3'-indexed container with entries @x@, @y@ and @z@.
fillI3 :: Basis I3 f => a -> a -> a -> f a
fillI3 x y = tabulate . i3 x y

-- | Addition modulo 3 on the basis indices, reading 'I31', 'I32', 'I33'
-- as 0, 1, 2.
instance Semigroup (Additive I3) where
  Additive a <> Additive b = case (a, b) of
    (I31, _)   -> Additive b
    (_, I31)   -> Additive a
    (I32, I33) -> Additive I31
    (I33, I32) -> Additive I31
    (I32, I32) -> Additive I33
    (I33, I33) -> Additive I32

-- | 'I31' is the additive identity.
instance Monoid (Additive I3) where
  mempty = pure I31

-- | Cross-product algebra: the coordinate at each index is the antisymmetric
-- difference @f i j - f j i@ over the other two indices, taken in cyclic order.
-- The three coordinates are shared across lookups.
instance Ring r => Algebra r I3 where
  mult f = cross where
    c1 = f I32 I33 - f I33 I32
    c2 = f I33 I31 - f I31 I33
    c3 = f I31 I32 - f I32 I31
    cross k = case k of
      I31 -> c1
      I32 -> c2
      I33 -> c3

-- | Identity conjugation. The norm is the sum of squares of the three
-- coordinates. It reads only the diagonal terms and does not go through
-- this basis's cross-product 'mult'.
instance Ring r => Composition r I3 where
  conj = C.id

  norm f = sq I31 + sq I32 + sq I33 where
    sq i = f i * f i


-------------------------------------------------------------------------------
-- QuaternionBasis
-------------------------------------------------------------------------------

-- | Quaternion basis: 'Nothing' is the real unit and @'Just' i@ ranges over
-- the three imaginary units.
type QuaternionBasis = Maybe I3

-- | Hamilton product on the basis @{1, i, j, k}@: @i^2 = j^2 = k^2 = -1@,
-- @ij = k@, @jk = i@, @ki = j@, and the reversed products are negated.
-- The four output coordinates are shared across lookups.
instance Ring r => Algebra r QuaternionBasis where
  mult f = component where
    re = Nothing
    qi = Just I31
    qj = Just I32
    qk = Just I33

    w = f re re - (f qi qi + f qj qj + f qk qk)
    x = f re qi + f qi re + (f qj qk - f qk qj)
    y = f re qj + f qj re + (f qk qi - f qi qk)
    z = f re qk + f qk re + (f qi qj - f qj qi)

    component q = case q of
      Nothing  -> w
      Just I31 -> x
      Just I32 -> y
      Just I33 -> z

-- | The unit places the scalar on the real basis element and zero on the
-- imaginary ones.
instance Ring r => Unital r QuaternionBasis where
  unit x = maybe x (const zero)

-- | Quaternion conjugation: keep the real coordinate and negate the three
-- imaginary coordinates.
instance Ring r => Composition r QuaternionBasis where
  conj = Tran flipImag where
    flipImag f q = case q of
      Nothing  -> f q
      Just I31 -> negate (f q)
      Just I32 -> negate (f q)
      Just I33 -> negate (f q)

-- | Division using the class's default methods once scalars form a 'Field'.
instance Field r => Division r QuaternionBasis

{-
reciprocal'' x = divq unit x

divq (Quaternion r0 (V3 r1 r2 r3)) (Quaternion q0 (V3 q1 q2 q3)) =
 (/denom) <$> Quaternion (r0*q0 + r1*q1 + r2*q2 + r3*q3) imag
  where denom = q0*q0 + q1*q1 + q2*q2 + q3*q3
        imag = (V3 (r0*q1 + (negate r1*q0) + (negate r2*q3) + r3*q2)
                   (r0*q2 + r1*q3 + (negate r2*q0) + (negate r3*q1))
                   (r0*q3 + (negate r1*q2) + r2*q1 + (negate r3*q0)))

-}

-------------------------------------------------------------------------------
-- V4
-------------------------------------------------------------------------------

-- | A four-dimensional vector with strict components.
data V4 a = V4 !a !a !a !a deriving (Eq,Ord,Show)

-- | Vector & matrix addition, applied componentwise.
--
-- >>> V4 1 2 3 4 + V4 5 6 7 8
-- V4 6 8 10 12
-- >>> m24 1 2 3 4 5 6 7 8 + m24 1 2 3 4 5 6 7 8 :: M24 Int
-- V2 (V4 2 4 6 8) (V4 10 12 14 16)
--
instance (Additive-Semigroup) a => Semigroup (Additive (V4 a)) where
  x <> y = liftA2 (mzipWithRep (+)) x y

-- | The additive identity: the zero vector.
instance (Additive-Monoid) a => Monoid (Additive (V4 a)) where
  mempty = pure (V4 zero zero zero zero)

-- | Vector & matrix subtraction, applied componentwise.
--
-- >>> V4 1 2 3 4 - V4 5 6 7 8 :: V4 Int
-- V4 (-4) (-4) (-4) (-4)
-- >>> m24 1 2 3 4 5 6 7 8 - m24 8 7 6 5 4 3 2 1 :: M24 Int
-- V2 (V4 (-7) (-5) (-3) (-1)) (V4 1 3 5 7)
--
instance (Additive-Group) a => Magma (Additive (V4 a)) where
  (<<) = liftA2 $ mzipWithRep (-)

-- The remaining group structure relies entirely on the classes' default methods.
instance (Additive-Group) a => Quasigroup (Additive (V4 a))
instance (Additive-Group) a => Loop (Additive (V4 a))
instance (Additive-Group) a => Group (Additive (V4 a))

-- | Left scalar multiplication, delegating to 'lscaleDef'.
instance Semiring a => LeftSemimodule a (V4 a) where
  {-# INLINE lscale #-}
  lscale s = lscaleDef s

-- | Right scalar multiplication, delegating to 'rscaleDef'.
instance Semiring a => RightSemimodule a (V4 a) where
  {-# INLINE rscale #-}
  rscale s = rscaleDef s

-- | Bisemimodule over the component semiring; no methods are overridden.
instance Semiring a => Bisemimodule a a (V4 a)

-- | Map a function over all four components.
instance Functor V4 where
  fmap f v = case v of
    V4 a b c d -> V4 (f a) (f b) (f c) (f d)
  {-# INLINE fmap #-}
  x <$ _ = V4 x x x x
  {-# INLINE (<$) #-}

-- | Componentwise applicative, obtained from the 'Representable' instance
-- through 'pureRep' and 'liftR2'.
instance Applicative V4 where
  pure x = pureRep x
  liftA2 f u v = liftR2 f u v

-- | Fold over the four components in order.
instance Foldable V4 where
  foldMap f v = case v of
    V4 a b c d -> f a <> f b <> f c <> f d
  {-# INLINE foldMap #-}
  null = const False
  length = const (two + two)

-- | Non-empty fold over the four components.
instance Foldable1 V4 where
  foldMap1 f v = case v of
    V4 a b c d -> f a <> f b <> f c <> f d
  {-# INLINE foldMap1 #-}

-- | Distribute by projecting each coordinate out of every inner 'V4'.
instance Distributive V4 where
  distribute f = V4 (fmap px f) (fmap py f) (fmap pz f) (fmap pw f) where
    px (V4 x _ _ _) = x
    py (V4 _ y _ _) = y
    pz (V4 _ _ z _) = z
    pw (V4 _ _ _ w) = w
  {-# INLINE distribute #-}

-- | 'V4' is indexed by the basis 'I4'.
instance Representable V4 where
  type Rep V4 = I4

  tabulate f = V4 (f I41) (f I42) (f I43) (f I44)
  {-# INLINE tabulate #-}

  index (V4 x y z w) i = case i of
    I41 -> x
    I42 -> y
    I43 -> z
    I44 -> w
  {-# INLINE index #-}

-------------------------------------------------------------------------------
-- Standard basis on four real dimensions
-------------------------------------------------------------------------------

-- | The four standard basis indices of four-dimensional space.
data I4
  = I41
  | I42
  | I43
  | I44
  deriving (Eq, Ord, Show)

-- | Coordinate function for the 'I4' basis: @x@, @y@, @z@, @w@ at
-- 'I41', 'I42', 'I43', 'I44'.
i4 :: a -> a -> a -> a -> I4 -> a
i4 x y z w i = case i of
  I41 -> x
  I42 -> y
  I43 -> z
  I44 -> w

-- | Build an 'I4'-indexed container with entries @x@, @y@, @z@ and @w@.
fillI4 :: Basis I4 f => a -> a -> a -> a -> f a
fillI4 x y z = tabulate . i4 x y z