{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}

-- | safe-typed n-dimensional arrays
module NumHask.Array
  ( Array(..)
  , SomeArray(..)
  , row
  , col
  , unsafeRow
  , unsafeCol
  , slice
  , unsafeSlice
  , index
  , unsafeIndex
  , foldAlong
  , mapAlong
  , concatenate
  , zipWith
  , transpose
  , squeeze
  , (><)
  , mmult
  , fromList
  ) where

import Data.Distributive
import Data.Functor.Rep
import Data.Promotion.Prelude
import Data.Singletons
import Data.Singletons.Prelude
import Data.Singletons.TypeLits
import GHC.Exts
import GHC.Show
import GHC.Generics (Generic1)
-- import Control.DeepSeq (NFData1)
import NumHask.Array.Constraints
import NumHask.Prelude hiding (All, Map, (><), mmult, show, row, col, zipWith, transpose)
import qualified Data.Vector as V
import qualified NumHask.Prelude as P
import Data.Kind

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> let a = [1..24] :: Array '[2,3,4] Int
-- >>> let v = [1,2,3] :: Array '[3] Int

-- | an n-dimensional array whose shape @r@ lives at the type level.
--
-- The elements are held in one flat 'V.Vector' in row-major order (the last
-- dimension varies fastest). Because the shape is a type, the
-- 'Representable' instance can build an array from nothing via 'tabulate',
-- without needing an existing array to copy the shape from.
--
-- >>> a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
newtype Array (r :: [Nat]) a =
  Array (V.Vector a)
  deriving (Functor, Eq, Foldable, Generic, Generic1, NFData)

-- | an n-dimensional array whose shape is only known at runtime: a list of
-- dimensions paired with the flat element vector.
data SomeArray a = SomeArray [Int] (V.Vector a)
  deriving (Functor, Eq, Foldable)

-- | forget the type-level shape of an 'Array', recording it instead as a
-- value-level dimension list inside a 'SomeArray'. The element vector is
-- shared, not copied.
someArray :: (SingI r) => Array (r :: [Nat]) a -> SomeArray a
someArray arr =
  case arr of
    Array flat -> SomeArray (shape arr) flat

-- | the shape of an 'Array' is read off its type-level dimension list; the
-- array value itself is never inspected.
instance forall (r :: [Nat]). (SingI r) => HasShape (Array r) where
  type Shape (Array r) = [Int]
  shape _ = fromIntegral <$> fromSing (sing :: Sing r)

-- | convert an n-dimensional index into a flat (row-major) vector offset.
--
-- The first argument is the shape, the second the index; each index
-- component is weighted by the product of all dimensions to its right.
--
-- >>> ind [2,3,4] [1,1,1]
-- 17
ind :: [Int] -> [Int] -> Int
ind ns xs = sum (P.zipWith (*) xs strides)
  where
    -- strides !! k == product (drop (k + 1) ns)
    strides = drop 1 (scanr (*) 1 ns)

-- | convert a flat (row-major) vector offset back into an n-dimensional
-- index for the given shape; the inverse of 'ind' for in-range offsets.
--
-- Works from the last (fastest-varying) dimension outwards, peeling off one
-- index component per dimension with 'divMod'.
--
-- >>> unind [2,3,4] 17
-- [1,1,1]
unind :: [Int] -> Int -> [Int]
unind ns x = fst (foldr step ([], x) ns)
  where
    step dim (digits, remaining) =
      let (q, r) = divMod remaining dim
      in (r : digits, q)

-- | distribute a functor of arrays into an array of functors: element @i@
-- of the result gathers element @i@ of every inner array. Reads use
-- 'V.unsafeIndex', so every inner array is assumed to hold at least as many
-- elements as the shape @r@ implies.
instance forall r. (SingI r) => Distributive (Array r) where
  distribute outer = Array (V.generate count slot)
    where
      slot i = fmap (\(Array v) -> V.unsafeIndex v i) outer
      -- total element count; the empty (scalar) shape gives product [] == 1
      count = product (fromInteger <$> fromSing (sing :: Sing r))

-- | arrays are represented by their n-dimensional index (a list of 'Int's,
-- one per dimension). 'tabulate' fills the flat vector in row-major order
-- via 'unind'; 'index' maps back with 'ind' and a bounds-checked read.
instance forall (r :: [Nat]). (SingI r) => Representable (Array r) where
  type Rep (Array r) = [Int]
  tabulate f = Array (V.generate (product dims) (f . unind dims))
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
  index (Array flat) ix = flat V.! ind dims ix
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)

-- | build an array from a flat, row-major list. The list is truncated to the
-- element count implied by the shape, or padded with @0@ if it is too short.
instance (SingI r, Num a) => IsList (Array (r :: [Nat]) a) where
  type Item (Array r a) = a
  fromList xs = Array (V.fromList (take count (xs ++ repeat 0)))
    where
      -- the empty (scalar) shape gives product [] == 1 element
      count = product (fromIntegral <$> fromSing (sing :: Sing r))
  toList (Array flat) = V.toList flat

-- | render as nested, bracketed lists. Rank-1 rows are printed on one line;
-- each higher-rank layer puts its sub-arrays on separate lines, indented so
-- that they line up under the opening brackets.
instance (Show a) => Show (SomeArray a) where
  show arr@(SomeArray dims _) = render arr
    where
      rank = length dims
      render sub@(SomeArray subDims flat) =
        case length subDims of
          0 -> show (V.head flat)
          1 -> "[" ++ intercalate ", " (fmap show (GHC.Exts.toList flat)) ++ "]"
          k ->
            let sep = ",\n" ++ replicate (rank - k + 1) ' '
            in "[" ++ intercalate sep (fmap render (flatten1 sub)) ++ "]"

-- | split a 'SomeArray' along its outermost dimension into its sub-arrays,
-- each sharing (not copying) a slice of the original vector. A scalar
-- (empty shape) yields no sub-arrays.
flatten1 :: SomeArray a -> [SomeArray a]
flatten1 (SomeArray rep v) =
  [ SomeArray (drop 1 rep) (V.unsafeSlice (k * inner) inner v)
  | k <- take outer [0 ..]
  ]
  where
    (outer, inner) =
      case rep of
        [] -> (0, 1)
        d : ds -> (d, product ds)

-- | shown by forgetting the type-level shape and delegating to the
-- 'SomeArray' instance.
instance (Show a, SingI r) => Show (Array (r :: [Nat]) a) where
  show arr = show (someArray arr)

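-- NFData for 'Array' is already derived on the newtype declaration (via
-- GeneralizedNewtypeDeriving), so no hand-written instance is needed.
-- For reference, a manual instance would read:
-- instance NFData a => NFData (Array (r :: [Nat]) a) where
--     rnf (Array v) = rnf v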

-- ** Operations
-- | outer product: the result's shape is the concatenation of both shapes,
-- and each entry is the product of the matching entries of @m@ and @n@.
--
-- todo: reconcile with numhask version
--
-- >>> v >< v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
(><) ::
     forall (r :: [Nat]) (s :: [Nat]) a.
     (CRing a, SingI r, SingI s, SingI (r :++ s))
  => Array r a
  -> Array s a
  -> Array (r :++ s) a
(><) m n = tabulate entry
  where
    -- the leading 'rank' index components address m, the rest address n
    rank = length (shape m)
    entry ix = index m (take rank ix) * index n (drop rank ix)
  where
    dimm = length (shape m)

-- | matrix product of two rank-2 arrays: entry @[i, j]@ of the result is
-- row @i@ of the left operand dotted ('<.>') with column @j@ of the right.
--
-- >>> let a = [1, 2, 3, 4] :: Array '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Array '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
-- >>> b
-- [[5, 6],
--  [7, 8]]
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult ::
     forall m n k a.
     (Semiring a, Num a, CRing a, KnownNat m, KnownNat n, KnownNat k)
  => Array '[ m, k] a
  -> Array '[ k, n] a
  -> Array '[ m, n] a
mmult x y = tabulate cell
  where
    -- tabulate over a '[m, n] shape only ever supplies two-element indices
    cell = \[i, j] -> unsafeRow i x <.> unsafeCol j y

-- | extract row @i@ of a matrix, where @i@ is a type-level index that is
-- statically required to be less than the row count @m@.
row ::
     forall i a m n. (KnownNat m, KnownNat n, KnownNat i, (i :< m) ~ 'True)
  => Proxy i
  -> Array '[ m, n] a
  -> Array '[ n] a
row p = unsafeRow (fromIntegral (fromSing (singByProxy p)))

-- | extract row @i@ of a matrix as a view over the original vector.
-- No bounds check is made ('V.unsafeSlice'); @i@ must be in @[0, m)@.
unsafeRow ::
     forall a m n. (KnownNat m, KnownNat n)
  => Int
  -> Array '[ m, n] a
  -> Array '[ n] a
unsafeRow i t@(Array a) = Array (V.unsafeSlice start n a)
  where
    [_, n] = shape t
    start = i * n

-- | extract column @j@ of a matrix, where @j@ is a type-level index that is
-- statically required to be less than the column count @n@.
col ::
     forall j a m n. (KnownNat m, KnownNat n, KnownNat j, (j :< n) ~ 'True)
  => Proxy j
  -> Array '[ m, n] a
  -> Array '[ m] a
col p = unsafeCol (fromIntegral (fromSing (singByProxy p)))

-- | extract column @j@ of a matrix into a fresh vector. @j@ is not checked
-- at the type level; each element read goes through the bounds-checked 'V.!'.
unsafeCol ::
     forall a m n. (KnownNat m, KnownNat n)
  => Int
  -> Array '[ m, n] a
  -> Array '[ m] a
unsafeCol j t@(Array a) = Array (V.generate m pick)
  where
    [m, n] = shape t
    -- element (r, j) sits at flat offset j + r * n in row-major order
    pick r = a V.! (j + r * n)

-- | look up one element by its n-dimensional index. The index is not
-- validated per dimension; only the final flat offset is bounds-checked.
--
-- >>> unsafeIndex a [0,2,1]
-- 10
unsafeIndex :: SingI r => Array r a -> [Int] -> a
unsafeIndex t@(Array a) i = a V.! offset
  where
    offset = ind (shape t) i

-- | select, for each dimension, a list of indices to keep, and return every
-- combination of them (in row-major order) as a new array. The result shape
-- @r0@ is unchecked and is chosen by the caller.
--
-- >>> unsafeSlice [[0,1],[2],[1,2]] a :: Array '[2,1,2] Int
-- [[[10, 11]],
--  [[22, 23]]]
unsafeSlice :: (SingI r) => [[Int]] -> Array r a -> Array r0 a
unsafeSlice s t = Array (V.fromList (unsafeIndex t <$> sequence s))

-- | Slice xs = Map Length xs
--
-- The result shape of 'slice': one dimension per selection list, whose size
-- is the number of indices selected in that list.
type family Slice (xss :: [[Nat]]) :: [Nat] where
  Slice xss = Map LengthSym0 xss

-- | AllLT xs n = All (n >) xs
--
-- Defunctionalisation symbols, so that the two-argument predicate
-- "every element of @l@ is less than @n@" can be passed to a type-level
-- higher-order family such as 'ZipWith' (see the constraint on 'slice').
-- 'AllLTSym0' is the unapplied symbol; applying it to @l@ yields 'AllLTSym1'.
data AllLTSym0 (a :: TyFun [Nat] (TyFun Nat Bool -> Type))

-- | 'AllLTSym0' partially applied to the index list @l@.
data AllLTSym1 (l :: [Nat]) (a :: TyFun Nat Bool)

type instance Apply AllLTSym0 l = AllLTSym1 l

-- '(:>$$) n' is ':>' with its first argument fixed to @n@, so this checks
-- @n :> x@ for every @x@ in @l@.
type instance Apply (AllLTSym1 l) n = All ((:>$$) n) l

-- | type-checked version of 'unsafeSlice'. The index selection is a
-- type-level list of lists, and every selected index is statically required
-- to be smaller than the corresponding dimension of @r@.
--
-- >>> slice (Proxy :: Proxy '[ '[0,1],'[2],'[1,2]]) a
-- [[[10, 11]],
--  [[22, 23]]]
slice ::
     forall s r a. (SingI s, SingI r, And (ZipWith AllLTSym0 s r) ~ 'True)
  => Proxy s
  -> Array r a
  -> Array (Slice s) a
slice p = unsafeSlice ((fmap . fmap) fromInteger (fromSing (singByProxy p)))

-- | split a vector into consecutive chunks of length @i@ (the last one may be
-- shorter), prepending each chunk onto @acc@. The chunks therefore come back
-- in REVERSE order; callers undo this with a 'reverse' or a consing fold.
--
-- NOTE(review): for a non-empty @v@ and @i <= 0@ this never terminates. The
-- callers in this module pass a product of dimensions, which is 0 only when a
-- dimension is 0 (and then, assuming a well-formed array, @v@ is empty).
chunkItUp :: [V.Vector a] -> Int -> V.Vector a -> [V.Vector a]
chunkItUp acc i v
  | V.null v = acc
  | otherwise = chunkItUp (chunk : acc) i rest
  where
    (chunk, rest) = V.splitAt i v

-- | combine two arrays of the same shape elementwise.
--
-- The element types of the two inputs and of the result may all differ; the
-- shapes are forced to agree by the type. As with 'V.zipWith', the result has
-- the length of the shorter underlying vector, which for well-formed arrays
-- equals the shape's element count.
--
-- >>> NumHask.Array.zipWith (+) v v
-- [2, 4, 6]
zipWith :: (a -> b -> c) -> Array s a -> Array s b -> Array s c
zipWith fn (Array a) (Array b) = Array $ V.zipWith fn a b

-- | split the array into consecutive blocks covering every dimension from
-- @s@ onward, apply @f@ to each block, and concatenate the results in order.
--
-- >>> foldAlong (Proxy :: Proxy 1) (\_ -> ([0..3] :: Array '[4] Int)) a
-- [[0, 1, 2, 3],
--  [0, 1, 2, 3]]
--
-- todo: resolution of a primitive and a scalar eg
--        Expected type: Array '[10] Int -> Array '[] Int
--        Actual type: Array '[10] (Array '[] Int) -> Array '[] Int
foldAlong ::
     forall s vw uvw uw w a.
     ( SingI s
     , SingI uvw
     , uw ~ (Fold s uvw)
     , w ~ (Drop 1 vw)
     , vw ~ (TailModule s uvw)
     )
  => Proxy s
  -> (Array vw a -> Array w a)
  -> Array uvw a
  -> Array uw a
foldAlong s_ f a@(Array v) = Array (V.concat (reverse (apply <$> blocks)))
  where
    s = (fromInteger . fromSing . singByProxy) s_
    apply x =
      case f (Array x) of
        Array out -> out
    -- chunkItUp yields the blocks last-first, hence the 'reverse' above
    blocks = chunkItUp [] (product (drop s (shape a))) v

-- | split the array into consecutive blocks covering every dimension from
-- @s@ onward, replace each block with @f@ applied to it, and reassemble.
--
-- >>> mapAlong (Proxy :: Proxy 0) (\x -> NumHask.Array.zipWith (*) x x) a
-- [[[1, 4, 9, 16],
--   [25, 36, 49, 64],
--   [81, 100, 121, 144]],
--  [[169, 196, 225, 256],
--   [289, 324, 361, 400],
--   [441, 484, 529, 576]]]
mapAlong ::
     forall s uvw vw a. (SingI s, SingI uvw, vw ~ (HeadModule s uvw))
  => Proxy s
  -> (Array vw a -> Array vw a)
  -> Array uvw a
  -> Array uvw a
mapAlong s_ f a@(Array v) = Array (V.concat (reverse (apply <$> blocks)))
  where
    s = (fromInteger . fromSing . singByProxy) s_
    apply x =
      case f (Array x) of
        Array out -> out
    -- chunkItUp yields the blocks last-first, hence the 'reverse' above
    blocks = chunkItUp [] (product (drop s (shape a))) v

-- | concatenate two arrays along dimension @s@. Elements of the first array
-- come before those of the second along that dimension.
--
-- Both arrays are cut into blocks spanning dimensions @s@ onward, and the
-- blocks are interleaved: first block of @r@, first block of @t@, second
-- block of @r@, and so on.
--
-- >>> concatenate (Proxy :: Proxy 2) a a
-- [[[1, 2, 3, 4, 1, 2, 3, 4],
--   [5, 6, 7, 8, 5, 6, 7, 8],
--   [9, 10, 11, 12, 9, 10, 11, 12]],
--  [[13, 14, 15, 16, 13, 14, 15, 16],
--   [17, 18, 19, 20, 17, 18, 19, 20],
--   [21, 22, 23, 24, 21, 22, 23, 24]]]
--
-- >>> concatenate (Proxy :: Proxy 0) v ([4,5,6] :: Array '[3] Int)
-- [1, 2, 3, 4, 5, 6]
concatenate ::
     forall s r t a. (SingI s, SingI r, SingI t, (IsValidConcat s t r) ~ 'True)
  => Proxy s
  -> Array r a
  -> Array t a
  -> Array (Concatenate s t r) a
concatenate s_ r@(Array vr) t@(Array vt) =
  Array . V.concat $ (concat . reverse . P.transpose) [rChunks, tChunks]
  where
    s = (fromInteger . fromSing . singByProxy) s_
    -- each list is last-first (chunkItUp reverses), so transposing pairs the
    -- k-th blocks from the end and the 'reverse' restores forward order;
    -- rChunks must come first in the pair so r's block precedes t's.
    rChunks = chunkItUp [] (product $ drop s $ shape r) vr
    tChunks = chunkItUp [] (product $ drop s $ shape t) vt

-- | re-type an array with the transposed shape 'Transpose' @s@ (the example
-- below shows @'[2,3,4]@ becoming @'[4,3,2]@).
--
-- NOTE(review): the body only changes the type; the flat vector is NOT
-- permuted. As the example shows, the elements keep their original
-- row-major order, so this behaves like a reshape to the reversed shape
-- rather than a true axis transpose. (A true transpose would start
-- @[[[1, 13], [5, 17], [9, 21]], ...@.) Confirm whether this is intended.
-- A real fix needs 'SingI' constraints on @s@ and @t@, which changes the
-- signature.
--
-- >>> NumHask.Array.transpose a
-- [[[1, 2],
--   [3, 4],
--   [5, 6]],
--  [[7, 8],
--   [9, 10],
--   [11, 12]],
--  [[13, 14],
--   [15, 16],
--   [17, 18]],
--  [[19, 20],
--   [21, 22],
--   [23, 24]]]
transpose ::
     forall s t a. (t ~ Transpose s)
  => Array s a
  -> Array t a
transpose (Array x) = Array x

-- | remove the size-1 dimensions from the shape (the example below turns
-- @'[2,1,3,4,1]@ into @'[2,3,4]@). Dropping unit dimensions does not change
-- the row-major order of the elements, so the flat vector is reused as-is
-- and only the type-level shape changes.
--
-- >>> let a = [1..24] :: Array '[2,1,3,4,1] Int
-- >>> a
-- [[[[[1],
--     [2],
--     [3],
--     [4]],
--    [[5],
--     [6],
--     [7],
--     [8]],
--    [[9],
--     [10],
--     [11],
--     [12]]]],
--  [[[[13],
--     [14],
--     [15],
--     [16]],
--    [[17],
--     [18],
--     [19],
--     [20]],
--    [[21],
--     [22],
--     [23],
--     [24]]]]]
-- >>> squeeze a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
squeeze ::
     forall s t a. (t ~ Squeeze s)
  => Array s a
  -> Array t a
squeeze arr =
  case arr of
    Array flat -> Array flat

-- Additive structure, lifted pointwise from the element type through the
-- 'Representable' instance: addition combines matching elements, zero is
-- the constant array, and negation maps over every element.
instance (SingI r, AdditiveMagma a) => AdditiveMagma (Array r a) where
  plus x y = liftR2 plus x y

instance (SingI r, AdditiveUnital a) => AdditiveUnital (Array r a) where
  zero = pureRep zero

-- associativity and commutativity carry over elementwise; no methods needed
instance (SingI r, AdditiveAssociative a) =>
         AdditiveAssociative (Array r a)

instance (SingI r, AdditiveCommutative a) =>
         AdditiveCommutative (Array r a)

instance (SingI r, AdditiveInvertible a) => AdditiveInvertible (Array r a) where
  negate x = fmapRep negate x

instance (SingI r, Additive a) => Additive (Array r a)

instance (SingI r, AdditiveGroup a) => AdditiveGroup (Array r a)

-- Multiplicative structure, lifted pointwise (Hadamard-style) from the
-- element type: multiplication combines matching elements, one is the
-- constant array, and reciprocal maps over every element.
instance (SingI r, MultiplicativeMagma a) =>
         MultiplicativeMagma (Array r a) where
  times x y = liftR2 times x y

instance (SingI r, MultiplicativeUnital a) =>
         MultiplicativeUnital (Array r a) where
  one = pureRep one

-- associativity and commutativity carry over elementwise; no methods needed
instance (SingI r, MultiplicativeAssociative a) =>
         MultiplicativeAssociative (Array r a)

instance (SingI r, MultiplicativeCommutative a) =>
         MultiplicativeCommutative (Array r a)

instance (SingI r, MultiplicativeInvertible a) =>
         MultiplicativeInvertible (Array r a) where
  recip x = fmapRep recip x

instance (SingI r, Multiplicative a) => Multiplicative (Array r a)

instance (SingI r, MultiplicativeGroup a) =>
         MultiplicativeGroup (Array r a)

-- Marker instances for the ring-like classes. They have no methods of their
-- own; they rely on the elementwise additive and multiplicative instances
-- above for 'Array' @r a@, whenever the element type has the same structure.
instance (SingI r, MultiplicativeMagma a, Additive a) =>
         Distribution (Array r a)

instance (SingI r, Semiring a) => Semiring (Array r a)

instance (SingI r, Ring a) => Ring (Array r a)

instance (SingI r, CRing a) => CRing (Array r a)

instance (SingI r, Field a) => Field (Array r a)

-- Elementwise exponential/logarithm; other ExpField methods use the
-- class defaults.
instance (SingI r, ExpField a) => ExpField (Array r a) where
  exp x = fmapRep exp x
  log x = fmapRep log x

-- An array "is NaN" when any of its elements is.
instance (SingI r, BoundedField a) => BoundedField (Array r a) where
  isNaN x = or (fmapRep isNaN x)

-- Sign and absolute value are taken element by element.
instance (SingI r, Signed a) => Signed (Array r a) where
  sign x = fmapRep sign x
  abs x = fmapRep abs x

-- | Euclidean (L2) norm: square root of the sum of the squared elements.
--
-- Elements are squared with @x * x@ rather than @x ** (one + one)@. This
-- avoids depending on how a given 'ExpField' instance implements '**' (a
-- definition of the form @exp (log x * y)@ would give NaN for negative @x@),
-- and it is cheaper.
instance (ExpField a) => Normed (Array r a) a where
  size r = sqrt $ foldr (+) zero $ (\x -> x * x) <$> r

-- An array is near zero, or about equal to another, only when every element
-- (or every matching pair of elements) is.
instance (SingI r, Epsilon a) => Epsilon (Array r a) where
  nearZero x = and (fmapRep nearZero x)
  aboutEqual x y = and (liftR2 aboutEqual x y)

-- Distance is the 'Normed' size of the elementwise difference.
instance (SingI r, ExpField a) => Metric (Array r a) a where
  distance x y = size (x - y)

-- Elementwise divMod: the quotients and remainders of matching elements are
-- gathered into two separate arrays.
instance (SingI r, Integral a) => Integral (Array r a) where
  divMod x y = (fst <$> qr, snd <$> qr)
    where
      qr = liftR2 divMod x y

-- Inner product: sum of the elementwise products.
instance (CRing a, Num a, Semiring a, SingI r) => Hilbert (Array r) a where
  x <.> y = sum (liftR2 (*) x y)