{-# LANGUAGE TemplateHaskell      #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE ViewPatterns         #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Tensor.Static.TH
-- Copyright   :  (C) 2017 Alexey Vagarenko
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Alexey Vagarenko (vagarenko@gmail.com)
-- Stability   :  experimental
-- Portability :  non-portable
--
----------------------------------------------------------------------------

module Data.Tensor.Static.TH (
      genTensorInstance
) where

import Data.List                  (foldl')
import Data.Tensor.Static         (tensor, Tensor, unsafeFromList, toList, IsTensor)
import Language.Haskell.TH
import qualified Data.List.NonEmpty as N


-- | Generate an 'IsTensor' instance, together with its associated 'Tensor'
--   data instance, for a tensor of the given shape and element type.
--
--   The data instance has a single constructor with one strict, unpacked
--   field per tensor element; the number of fields is the product of all
--   dimensions. The instance defines:
--
--   * @tensor@ as that constructor;
--
--   * @unsafeFromList@, which binds the leading elements of the list to the
--     constructor fields, ignores any surplus elements, and calls 'error'
--     when the list is too short;
--
--   * @toList@, which returns the fields in constructor order.
--
--   Each of the three methods is given an @INLINE@ pragma.
--
--   The splice fails with an error message if any dimension is negative,
--   because a negative number is not a valid type-level natural.
genTensorInstance :: N.NonEmpty Int     -- ^ Dimensions of the tensor.
                  -> Name               -- ^ Type of elements.
                  -> Q [Dec]
genTensorInstance (N.toList -> dimensions) elemTypeName
    -- Reject negative dimensions up front: they would end up in a 'NumTyLit',
    -- and two negative dimensions could even multiply to a positive field count.
    | any (< 0) dimensions =
        fail ("genTensorInstance: tensor dimensions must be non-negative, but got " ++ show dimensions)
    | otherwise = do
        -- The constructor name encodes the shape and the element type,
        -- e.g. dimensions [2, 3] of Float give the base name "Tensor'2'3'Float".
        conName <- newName ("Tensor'" ++ concatMap (\x -> show x ++ "'") dimensions ++ nameBase elemTypeName)

        let fieldCount = product dimensions
        -- One fresh variable per element, reused in every generated pattern and expression.
        fieldNames <- mapM (newName . ('x' :) . show) [0 .. fieldCount - 1]

        let elemType = ConT elemTypeName
            dims     = natListT dimensions
            fields   = replicate fieldCount (Bang SourceUnpack SourceStrict, elemType)

        let dataInstDec = DataInstD
                []                       -- context
                ''Tensor                 -- family name
                [dims, elemType]         -- family params
                (Just StarT)             -- kind
                [NormalC conName fields] -- data constructor with `fieldCount` unpacked fields of type `elemType`
                []                       -- no deriving clauses

        -- Pattern @x0 : x1 : ... : _@; the trailing wildcard lets surplus elements through.
        let fromListPat     = foldr (\name pat -> InfixP (VarP name) '(:) pat) WildP fieldNames
            constructTensor = foldl' (\acc name -> acc `AppE` VarE name) (ConE conName) fieldNames
            tensorPat       = ConP conName (map VarP fieldNames)
            toListBody      = ListE (map VarE fieldNames)
            failBody        = VarE 'error `AppE` LitE (StringL ("Not enough elements to build a Tensor of shape "
                                                                ++ show dimensions))
            -- With zero fields `fromListPat` is already a bare wildcard, so a second
            -- wildcard clause would be unreachable in the generated code.
            shortListClauses
                | fieldCount > 0 = [Clause [WildP] (NormalB failBody) []]
                | otherwise      = []

        let tensorDec          = ValD (VarP 'tensor) (NormalB $ ConE conName) []
            unsafeFromListDec  = FunD 'unsafeFromList
                                      (Clause [fromListPat] (NormalB constructTensor) [] : shortListClauses)
            toListDec          = FunD 'toList [Clause [tensorPat] (NormalB toListBody) []]
            tensorCInstPragmas =
                [ PragmaD (InlineP 'tensor         Inline FunLike AllPhases)
                , PragmaD (InlineP 'unsafeFromList Inline FunLike AllPhases)
                , PragmaD (InlineP 'toList         Inline FunLike AllPhases) ]

        -- The data instance is emitted inside the class instance, next to the methods.
        let tensorCInstDec = InstanceD
                Nothing
                []
                (ConT ''IsTensor `AppT` dims `AppT` elemType)
                ([dataInstDec, tensorDec, unsafeFromListDec, toListDec] ++ tensorCInstPragmas)

        pure [tensorCInstDec]

-- | Build the Template Haskell 'Type' for a promoted (type-level) list whose
--   elements are numeric type literals, one per input number, in order.
--   The empty list maps to the promoted nil.
natListT :: [Int] -> Type
natListT []           = PromotedNilT
natListT (n : rest)   = AppT (AppT PromotedConsT natLit) (natListT rest)
  where
    -- Numeric type literals carry an 'Integer', so widen the 'Int' first.
    natLit = LitT (NumTyLit (toInteger n))