{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}

module NumHask.Array.HMatrix
  ( Array(..),
    index,
    mmult,
  ) where

import GHC.Exts (IsList (..))
import GHC.Show (Show (..))
import NumHask.Array.Shape
import NumHask.Prelude as P
import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Devel as H
import qualified Prelude

-- | An array whose shape @s@ is tracked only at the type level; the
-- elements themselves are stored in an hmatrix 'H.Matrix'.
--
-- NOTE(review): the constructor is exported (@Array(..)@), so nothing
-- here checks that the stored matrix's dimensions agree with @s@; code
-- that builds an 'Array' directly must keep them consistent.
newtype Array s a = Array {unArray :: H.Matrix a}
  deriving (Show, NFData, Generic)

-- | Read one element at a multi-dimensional index.
--
-- The index list is turned into a single offset with 'flatten' against the
-- type-level shape @s@. The element is then read from the flattened
-- underlying matrix ('H.flatten'). This is written out explicitly rather
-- than obtained through a @Representable@ instance.
--
-- NOTE(review): this assumes 'flatten' uses the same element order as
-- 'H.flatten'; confirm against "NumHask.Array.Shape".
index ::
  forall s a.
  ( HasShape s,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Array s a ->
  [Int] ->
  a
index arr idx = H.atIndex flatElems offset
  where
    flatElems = H.flatten (unArray arr)
    offset = flatten (shapeVal (toShape @s)) idx

-- | Element-wise addition of two arrays of the same shape.
--
-- 'zero' is a matrix filled with the element type's 'zero'. Its dimensions
-- follow the layout 'fromList' uses: the last dimension of @s@ is the
-- column count, and the product of the leading dimensions is the row count.
instance
  ( Additive a,
    HasShape s,
    H.Container H.Vector a,
    Num a
  ) =>
  Additive (Array s a)
  where
  (+) (Array x1) (Array x2) = Array $ H.add x1 x2

  zero = Array $ H.konst zero (rows, cols)
    where
      -- For a rank-2 shape @[n, m]@ this is @(n, m)@. Other ranks get
      -- total dimensions instead of an irrefutable-pattern failure, and a
      -- scalar shape @[]@ becomes a 1x1 matrix.
      (rows, cols) = case Prelude.reverse (shapeVal (toShape @s)) of
        [] -> (1, 1)
        (c : rest) -> (P.product rest, c)

-- | Element-wise multiplication of two arrays of the same shape. The
-- underlying vectors are combined with 'Prelude.*' via 'H.liftMatrix2'.
--
-- 'one' is a matrix filled with the element type's 'one'. Its dimensions
-- follow the layout 'fromList' uses: the last dimension of @s@ is the
-- column count, and the product of the leading dimensions is the row count.
instance
  ( Multiplicative a,
    HasShape s,
    H.Container H.Vector a,
    Num (H.Vector a),
    Num a
  ) =>
  Multiplicative (Array s a)
  where
  (*) (Array x1) (Array x2) = Array $ H.liftMatrix2 (Prelude.*) x1 x2

  one = Array $ H.konst one (rows, cols)
    where
      -- For a rank-2 shape @[n, m]@ this is @(n, m)@. Other ranks get
      -- total dimensions instead of an irrefutable-pattern failure, and a
      -- scalar shape @[]@ becomes a 1x1 matrix.
      (rows, cols) = case Prelude.reverse (shapeVal (toShape @s)) of
        [] -> (1, 1)
        (c : rest) -> (P.product rest, c)

-- | The scalar type acting on an 'Array' is its element type. The
-- 'MultiplicativeAction' instance scales elements by it, and the
-- 'Hilbert' inner product returns it.
type instance Actor (Array s a) = a

-- | Inner product: multiply the two arrays element by element, then add up
-- every entry of the resulting matrix.
instance
  ( HasShape s,
    P.Distributive a,
    CommutativeRing a,
    Semiring a,
    H.Container H.Vector a,
    Num (H.Vector a),
    Num a
  ) =>
  Hilbert (Array s a)
  where
  Array lhs <.> Array rhs = H.sumElements products
    where
      products = H.liftMatrix2 (Prelude.*) lhs rhs
  {-# INLINE (<.>) #-}

-- | Scale every element by a scalar. '.*' multiplies each element on the
-- right by the scalar, and '*.' multiplies on the left, using the element
-- type's 'Multiplicative' product.
instance
  ( HasShape s,
    Multiplicative a,
    H.Container H.Vector a,
    Num a
  ) =>
  MultiplicativeAction (Array s a)
  where
  Array mat .* k = Array (H.cmap (\e -> e * k) mat)
  {-# INLINE (.*) #-}

  k *. Array mat = Array (H.cmap (\e -> k * e) mat)
  {-# INLINE (*.) #-}

-- | Conversion to and from a flat list of elements.
--
-- 'fromList' truncates the input, or pads it with 'zero', to exactly the
-- number of elements in shape @s@. It then reshapes the result into a
-- matrix whose column count is the last dimension of @s@. A scalar shape
-- @[]@ yields a single column instead of failing on a partial 'last'.
--
-- 'toList' returns the elements of the flattened underlying matrix.
instance
  ( HasShape s,
    Additive a,
    H.Element a
  ) =>
  IsList (Array s a)
  where
  type Item (Array s a) = a

  fromList l =
    Array $ H.reshape cols $ H.fromList $ take size $ l ++ repeat zero
    where
      s = shapeVal (toShape @s)
      size = P.product s
      -- Total replacement for @Prelude.last s@.
      cols = case Prelude.reverse s of
        [] -> 1
        (c : _) -> c

  toList (Array v) = H.toList $ H.flatten v

-- | Matrix product of an @m x k@ array with a @k x n@ array, giving an
-- @m x n@ array. The work is delegated directly to hmatrix's matrix
-- product ('H.<>') on the underlying matrices, with no intermediate list
-- conversion.
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a,
    H.Numeric a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult lhs rhs = Array (unArray lhs H.<> unArray rhs)

-- The 'Actor' type instance for 'Array' is declared once, next to the
-- 'Multiplicative' instance; a second identical declaration is not needed.