{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP #-}

module Shapes.Linear.MatrixTemplate where

import Data.Monoid
import Language.Haskell.TH

import Shapes.Linear.Template

-- | Name and element count for a matrix with the given @(rows, cols)@
-- dimensions.  The name has the form @M{rows}x{cols}@ (for example @M2x3@)
-- and the count is @rows * cols@.
makeMatrixNL :: (Int, Int) -> (Name, Int)
makeMatrixNL (rows, cols) = (matrixName, elemCount)
  where
    matrixName = mkName (concat ["M", show rows, "x", show cols])
    elemCount = rows * cols

-- | Generate the data declaration for a matrix type of the given
-- @(rows, cols)@ dimensions together with its supporting definitions.
--
-- The type and its single constructor share the name produced by
-- @makeMatrixNL@; the constructor has one field of the value type
-- @_valueN@ per matrix element, with no strictness or unpack annotation.
-- The returned declarations are the data declaration, followed by the
-- output of the per-type generators (lift, lift2, list conversions, Show
-- and Arbitrary), followed by the matrix/vector product generators defined
-- in this module.
makeMatrixType :: ValueInfo -> (Int, Int) -> DecsQ
makeMatrixType vi@ValueInfo{..} dims = do
  perTypeDecs <- concat <$> mapM (\make -> make tyName vi fieldCount) perTypeMakers
  productDecs <- concat <$> mapM (\make -> make vi dims) productMakers
  -- The argument list of dataD changed across template-haskell releases,
  -- hence one call per supported API shape.
#if MIN_VERSION_template_haskell(2,12,0)
  tyDec <- dataD (cxt []) tyName [] Nothing [matrixCon] []
#elif MIN_VERSION_template_haskell(2,11,0)
  tyDec <- dataD (cxt []) tyName [] Nothing [matrixCon] (mapM conT [])
#else
  tyDec <- dataD (cxt []) tyName [] [matrixCon] []
#endif
  return (tyDec : perTypeDecs ++ productDecs)
  where
    (tyName, fieldCount) = makeMatrixNL dims
    -- Single constructor holding every element as its own field.
    matrixCon = normalC tyName (replicate fieldCount fieldTy)
#if MIN_VERSION_template_haskell(2,11,0)
    fieldTy = bangType (bang noSourceUnpackedness noSourceStrictness) (conT _valueN)
#else
    fieldTy = strictType notStrict (conT _valueN)
#endif
    -- Generators that take the type name, value info and element count.
    perTypeMakers =
      [ defineLift
      , defineLift2
      , defineFromList
      , defineToList
      , deriveShow
      , deriveArbitrary
      ]
    -- Generators that take the value info and the matrix dimensions.
    productMakers =
      [ defineMatrixMulVector
      , defineVectorMulMatrix
      , defineDiagMulMatrix
      , defineMatrixMulDiag
      , defineVectorOuterProduct
      ]

-- | Generate the matrix product function @mul{left}x{inner}x{right}@ with
-- type @M{left}x{inner} -> M{inner}x{right} -> M{left}x{right}@.
--
-- The left operand is read as rows of @inner@ elements and the right
-- operand as columns (via @stripes@).  Each field of the result is the
-- expression built by @dotE@ for one row/column pair, emitted row by row.
-- The finished clause is handed to @inlSigDef@ (defined elsewhere;
-- presumably it emits the signature and definition, to be confirmed there).
defineMatrixMul :: ValueInfo -> (Int, Int, Int) -> DecsQ
defineMatrixMul vi@ValueInfo{..} (left, inner, right) = do
  (lhsP, lhsElems) <- conPE lhsN "a" lhsLen
  (rhsP, rhsElems) <- conPE rhsN "b" rhsLen
  let lhsRows = chunks inner lhsElems
      rhsCols = stripes right rhsElems
      -- Row index varies slowest, matching the row-by-row field order.
      cellEs = [dotE vi row col | row <- lhsRows, col <- rhsCols]
      bodyE = appsE (conE outN : cellEs)
  inlSigDef fnName fnType [simpleClause [lhsP, rhsP] bodyE]
  where
    (lhsN, lhsLen) = makeMatrixNL (left, inner)
    (rhsN, rhsLen) = makeMatrixNL (inner, right)
    (outN, _) = makeMatrixNL (left, right)
    fnName = mkName (concat ["mul", show left, "x", show inner, "x", show right])
    fnType = arrowsT [conT lhsN, conT rhsN, conT outN]

-- | Generate @mul{left}x{inner}c@: a matrix applied to a vector on its
-- right, with type @M{left}x{inner} -> V -> V@ where the vector type names
-- come from @makeVectorN inner@ (argument) and @makeVectorN left@ (result).
--
-- The matrix fields are split into rows of @inner@ elements; each result
-- component is the expression built by @dotE@ from one row and the vector.
defineMatrixMulVector :: ValueInfo -> (Int, Int) -> DecsQ
defineMatrixMulVector vi@ValueInfo{..} dims@(left, inner) = do
  (matP, matElems) <- conPE matN "a" matLen
  (vecP, vecElems) <- conPE inVecN "b" inner
  let componentEs = map (\row -> dotE vi row vecElems) (chunks inner matElems)
      bodyE = appsE (conE outVecN : componentEs)
  inlSigDef fnName fnType [simpleClause [matP, vecP] bodyE]
  where
    (matN, matLen) = makeMatrixNL dims
    inVecN = makeVectorN inner
    outVecN = makeVectorN left
    fnName = mkName ("mul" ++ show left ++ "x" ++ show inner ++ "c")
    fnType = arrowsT [conT matN, conT inVecN, conT outVecN]

-- | Generate @mulr{inner}x{right}@: a vector applied to a matrix on its
-- left, with type @V -> M{inner}x{right} -> V@ where the vector type names
-- come from @makeVectorN inner@ (argument) and @makeVectorN right@ (result).
--
-- The matrix fields are regrouped into columns with @stripes@; each result
-- component is the expression built by @dotE@ from the vector and one
-- column.
defineVectorMulMatrix :: ValueInfo -> (Int, Int) -> DecsQ
defineVectorMulMatrix vi@ValueInfo{..} dims@(inner, right) = do
  (vecP, vecElems) <- conPE inVecN "a" inner
  (matP, matElems) <- conPE matN "b" matLen
  let componentEs = [dotE vi vecElems column | column <- stripes right matElems]
      bodyE = appsE (conE outVecN : componentEs)
  inlSigDef fnName fnType [simpleClause [vecP, matP] bodyE]
  where
    inVecN = makeVectorN inner
    (matN, matLen) = makeMatrixNL dims
    outVecN = makeVectorN right
    fnName = mkName ("mulr" ++ show inner ++ "x" ++ show right)
    fnType = arrowsT [conT inVecN, conT matN, conT outVecN]

-- | Generate @muld{inner}x{right}@: left-multiplication of an
-- @inner@ x @right@ matrix by a diagonal matrix that is passed as the
-- vector of its @inner@ diagonal entries.  The result has the same matrix
-- type as the matrix argument.
--
-- The matrix fields are split into rows of @right@ elements and every
-- entry of row i is multiplied (using @_valueMul@, scalar on the left) by
-- the i-th diagonal entry.  Rows are emitted in their original order.
defineDiagMulMatrix :: ValueInfo -> (Int, Int) -> DecsQ
defineDiagMulMatrix ValueInfo{..} dims@(inner, right) = do
  (vecP, diagElems) <- conPE vecN "a" inner
  (matP, matElems) <- conPE matN "b" matLen
  let scaledEs =
        [ infixApp' (varE _valueMul) scalar entry
        | (scalar, row) <- zip diagElems (chunks right matElems)
        , entry <- row
        ]
      bodyE = appsE (conE matN : scaledEs)
  inlSigDef fnName fnType [simpleClause [vecP, matP] bodyE]
  where
    vecN = makeVectorN inner
    (matN, matLen) = makeMatrixNL dims
    fnName = mkName ("muld" ++ show inner ++ "x" ++ show right)
    fnType = arrowsT [conT vecN, conT matN, conT matN]

-- | Generate @mul{left}x{inner}d@: right-multiplication of a
-- @left@ x @inner@ matrix by a diagonal matrix that is passed as the vector
-- of its @inner@ diagonal entries.  The result has the same matrix type as
-- the matrix argument.
--
-- Entry j of every row is multiplied (using @_valueMul@, diagonal entry on
-- the left) by the j-th diagonal entry.  The scaled entries are emitted row
-- by row, which is the field order the other generators in this module
-- assume when they read rows with @chunks@.  Emitting the scaled columns
-- one after another instead would hand the constructor its fields in
-- column-major order and produce a transposed (or, for non-square shapes,
-- scrambled) result.
defineMatrixMulDiag :: ValueInfo -> (Int, Int) -> DecsQ
defineMatrixMulDiag ValueInfo{..} dims@(left, inner) = do
  let vecN = makeVectorN inner
      (matN, len) = makeMatrixNL dims
  (matP, elemVars) <- conPE matN "a" len
  (vecP, diag) <- conPE vecN "b" inner
  let rows = chunks inner elemVars
      -- Pair each diagonal entry with the entry in the same column.
      rowE = zipWith (infixApp' (varE _valueMul)) diag
      -- Keep row-major order for the constructor fields.
      resultE = appsE (conE matN : concatMap rowE rows)
      mulN = mkName $ "mul" ++ show left ++ "x" ++ show inner ++ "d"
      mulC = simpleClause [matP, vecP] resultE
      mulT = arrowsT [matT, vecT, matT]
      vecT = conT vecN
      matT = conT matN
  inlSigDef mulN mulT [mulC]

-- | Generate @mulT{left}x{right}@: the outer product of a vector with
-- @left@ components and a vector with @right@ components, yielding the
-- matrix type named by @makeMatrixNL (left, right)@.
--
-- The fields of the result are the products (using @_valueMul@) of every
-- component of the first vector with every component of the second, with
-- the first vector index varying slowest.
defineVectorOuterProduct :: ValueInfo -> (Int, Int) -> DecsQ
defineVectorOuterProduct ValueInfo{..} dims@(left, right) = do
  (lhsP, lhsElems) <- conPE lhsN "a" left
  (rhsP, rhsElems) <- conPE rhsN "b" right
  let times l r = infixApp' (varE _valueMul) l r
      cellEs = [times l r | l <- lhsElems, r <- rhsElems]
      bodyE = appsE (conE matN : cellEs)
  inlSigDef fnName fnType [simpleClause [lhsP, rhsP] bodyE]
  where
    lhsN = makeVectorN left
    rhsN = makeVectorN right
    (matN, _) = makeMatrixNL dims
    fnName = mkName ("mulT" ++ show left ++ "x" ++ show right)
    fnType = arrowsT [conT lhsN, conT rhsN, conT matN]

-- | Split a list into consecutive pieces of @chunkSize@ elements each; the
-- final piece may be shorter.  An empty list yields no pieces.
--
-- Note: with a non-positive @chunkSize@ and a non-empty list no input is
-- ever consumed, so the result is an endless list of empty pieces.
chunks :: Int -> [a] -> [[a]]
chunks chunkSize = go
  where
    go [] = []
    go rest = take chunkSize rest : go (drop chunkSize rest)

-- | Split a list into rows of @chunkSize@ elements (see @chunks@) and
-- regroup them column-wise with @raggedZip@: the k-th result list holds the
-- k-th element of every row that has one.
stripes :: Int -> [a] -> [[a]]
stripes chunkSize xs = raggedZip (chunks chunkSize xs)

-- | Combine two lists position by position with the monoid operation.
-- Unlike @zipWith@, the result is as long as the longer input: once one
-- list runs out, the remaining elements of the other are kept unchanged.
unevenZip :: Monoid a => [a] -> [a] -> [a]
unevenZip as bs = case (as, bs) of
  ([], _) -> bs
  (_, []) -> as
  (a : as', b : bs') -> (a <> b) : unevenZip as' bs'

-- | Regroup a list of rows by position: the k-th result list collects, in
-- row order, the k-th element of every row that is long enough.  Rows may
-- have different lengths.
raggedZip :: [[a]] -> [[a]]
raggedZip = foldr prependRow []
  where
    -- Wrap each element as a singleton so that combining with the columns
    -- gathered so far puts it at the front of its column.
    prependRow row columns = unevenZip (map (: []) row) columns