{-|
  Copyright   :  (C) 2019, Google Inc.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-} -- needed for constraint on the Fixed instance
#if __GLASGOW_HASKELL__ < 806
{-# OPTIONS_GHC -Wwarn=unused-pattern-binds #-}
#endif

module Clash.Class.AutoReg.Internal
  ( AutoReg (..)
  , deriveAutoReg
  , deriveAutoRegTuples
  )
  where

import           Data.List                    (nub,zipWith4)
import           Data.Maybe                   (fromMaybe,isJust)

import           GHC.Stack                    (HasCallStack)
import           GHC.TypeNats                 (KnownNat,Nat,type (+))
import           Clash.Explicit.Signal
import           Clash.Promoted.Nat
import           Clash.Magic
import           Clash.XException             (NFDataX, deepErrorX)

import           Clash.Sized.BitVector
import           Clash.Sized.Fixed
import           Clash.Sized.Index
import           Clash.Sized.RTree
import           Clash.Sized.Signed
import           Clash.Sized.Unsigned
import           Clash.Sized.Vector           (Vec, lazyV, smap)

import           Data.Int
import           Data.Word
import           Foreign.C.Types              (CUShort)
import           Numeric.Half                 (Half)

import           Language.Haskell.TH.Datatype
import           Language.Haskell.TH.Syntax
import           Language.Haskell.TH.Lib
import           Language.Haskell.TH.Ppr

import           Control.Lens.Internal.TH     (bndrName, conAppsT)

-- $setup
-- >>> import Data.Maybe
-- >>> import Clash.Class.BitPack (pack)
-- >>> :set -fplugin GHC.TypeLits.Normalise
-- >>> :set -fplugin GHC.TypeLits.KnownNat.Solver

-- | 'autoReg' is a "smart" version of 'register'. It does two things:
--
--   1. It splits product types over their fields. For example, given a 3-tuple,
--   the corresponding HDL will end up with three instances of a register (or
--   more if the three fields can be split up similarly).
--
--   2. Given a data type where a constructor indicates (parts) of the data will
--   (not) be updated a given cycle, it will split the data in two parts. The
--   first part will contain the "always interesting" parts (the constructor
--   bits). The second holds the "potentially uninteresting" data (the rest).
--   Both parts will be stored in separate registers. The register holding the
--   "potentially uninteresting" part will only be enabled if the constructor
--   bits indicate they're interesting.
--
--   The most important example of this is "Maybe". Consider "Maybe (Signed 16)";
--   when viewed as bits, a 'Nothing' would look like:
--
--     >>> pack @(Maybe (Signed 16)) Nothing
--     0_...._...._...._....
--
--   and 'Just'
--
--     >>> pack @(Maybe (Signed 16)) (Just 3)
--     1_0000_0000_0000_0011
--
--   In the first case, Nothing, we don't particularly care about updating the
--   register holding the "Signed 16" field, as its bits will be unknown anyway.
--   We can therefore deassert its enable line.
--
-- Making Clash lay it out like this increases the chances of synthesis tools
-- clock gating the registers, saving energy.
--
-- This version of 'autoReg' will split the given data type up recursively. For
-- example, given "a :: Maybe (Maybe Int, Maybe Int)", a total of five registers
-- will be rendered. Both the "interesting" and "uninteresting" enable lines of
-- the inner Maybe types will be controlled by the outer one, in addition to
-- the inner parts controlling their "uninteresting" parts as described in (2).
--
-- The default implementation is just 'register'. If you don't need or want
-- the special features of "AutoReg", you can use that by writing an empty instance.
--
-- > data MyDataType = ...
-- > instance AutoReg MyDataType
--
-- If you have a product type you can use 'deriveAutoReg' to derive an instance.
--
-- "Clash.Prelude" exports an implicit version of this: 'Clash.Prelude.autoReg'
class NFDataX a => AutoReg a where
  -- | Register a value of type @a@ using an explicit clock, reset and enable.
  --
  -- The default implementation is plain 'register'. Instances can override
  -- it to spread the value over several registers.
  autoReg
    :: (HasCallStack, KnownDomain dom)
    => Clock dom -> Reset dom -> Enable dom
    -> a  -- ^ Reset value
    -> Signal dom a
    -> Signal dom a
  autoReg = register

-- All instances in this group have an empty body, so they use the class
-- default for 'autoReg', i.e. plain 'register'.

-- Unit and booleans
instance AutoReg ()
instance AutoReg Bool

-- Floating point and C types
instance AutoReg Double
instance AutoReg Float
instance AutoReg CUShort
instance AutoReg Half

-- Characters
instance AutoReg Char

-- Haskell integral types
instance AutoReg Integer
instance AutoReg Int
instance AutoReg Int8
instance AutoReg Int16
instance AutoReg Int32
instance AutoReg Int64
instance AutoReg Word
instance AutoReg Word8
instance AutoReg Word16
instance AutoReg Word32
instance AutoReg Word64

-- Clash sized types
instance AutoReg Bit
instance AutoReg (BitVector n)
instance AutoReg (Signed n)
instance AutoReg (Unsigned n)
instance AutoReg (Index n)
-- 'Fixed' needs an explicit 'NFDataX' constraint on its underlying
-- representation; this is the constraint that needs UndecidableInstances.
instance NFDataX (rep (int + frac)) => AutoReg (Fixed rep int frac)

-- Splits a 'Maybe' into a constructor tag ('Bool') and its payload, and stores
-- them separately:
--
--   * the tag is stored with plain 'register';
--   * the payload is stored with a recursive 'autoReg' whose enable is gated
--     on the incoming tag, so it is only updated when the input is a 'Just'.
instance AutoReg a => AutoReg (Maybe a) where
  autoReg clk rst en initVal input =
    createMaybe <$> tagR <*> valR
   where
     -- Constructor tag of the input, of the reset value, and its register.
     tag = isJust <$> input
     tagInit = isJust initVal
     tagR = register clk rst en tagInit tag

     -- Payload of the input and of the reset value. For a 'Nothing' there is
     -- no payload, so a 'deepErrorX' placeholder is used instead.
     val = fromMaybe (deepErrorX "autoReg'.val") <$> input
     valInit = fromMaybe (deepErrorX "autoReg'.valInit") initVal

     -- Payload register: enabled only when both 'en' and 'tag' are asserted.
     valR = autoReg clk rst (enable en tag) valInit val

     -- Reassemble the 'Maybe' from the registered tag and payload.
     createMaybe t v = case t of
       True -> Just v
       False -> Nothing

-- Stores every element of the vector with its own 'autoReg'.
instance (KnownNat n, AutoReg a) => AutoReg (Vec n a) where
  autoReg
    :: forall dom. (HasCallStack, KnownDomain dom)
    => Clock dom -> Reset dom -> Enable dom
    -> Vec n a -- ^ Reset value
    -> Signal dom (Vec n a)
    -> Signal dom (Vec n a)
  autoReg clk rst en initVal xs =
    -- 'smap' pairs each reset value with its index, giving a vector of
    -- per-element register functions; these are applied element-wise to the
    -- unbundled input and the results are bundled again.
    bundle $ smap go (lazyV initVal) <*> unbundle xs
   where
    -- Per-element register. The index @i@ is passed to 'suffixNameFromNatP'
    -- (presumably to name the register after its position in the vector).
    go :: forall (i :: Nat). SNat i -> a  -> Signal dom a -> Signal dom a
    go SNat = suffixNameFromNatP @i . autoReg clk rst en

-- Stores every element of the tree with its own 'autoReg'.
instance (KnownNat d, AutoReg a) => AutoReg (RTree d a) where
  autoReg clk rst en initVal xs =
    -- Map 'autoReg' over the reset values to get one register function per
    -- element, apply those element-wise to the unbundled input, and bundle
    -- the results again.
    bundle $ (autoReg clk rst en) <$> lazyT initVal <*> unbundle xs


-- | Decompose an applied type into its individual components. For example, this:
--
-- @
-- Either Int Char
-- @
--
-- would be unfolded to this:
--
-- @
-- ('ConT' ''Either, ['ConT' ''Int, 'ConT' ''Char])
-- @
--
-- This function ignores explicit parentheses and visible kind applications.
--
-- NOTE: Copied from "Control.Lens.Internal.TH".
-- TODO: Remove this function. Can be removed once we can upgrade to lens 4.18.
-- TODO: This is currently difficult due to issue with nix.
unfoldType :: Type -> (Type, [Type])
unfoldType = go []
  where
    -- Walk down the left spine of the type application, collecting the
    -- arguments in @acc@ (innermost application ends up first). Foralls, kind
    -- signatures, parentheses and kind applications are looked through.
    go :: [Type] -> Type -> (Type, [Type])
    go acc (ForallT _ _ ty) = go acc ty
    go acc (AppT ty1 ty2)   = go (ty2:acc) ty1
    go acc (SigT ty _)      = go acc ty
#if MIN_VERSION_template_haskell(2,11,0)
    go acc (ParensT ty)     = go acc ty
#endif
#if MIN_VERSION_template_haskell(2,15,0)
    go acc (AppKindT ty _)  = go acc ty
#endif
    -- Anything else is the head of the application.
    go acc ty               = (ty, acc)

-- | Automatically derives an 'AutoReg' instance for a product type
--
-- Usage:
--
-- > data Pair a b = MkPair { getA :: a, getB :: b } deriving (Generic, NFDataX)
-- > data Tup3 a b c = MkTup3 { getAB :: Pair a b, getC :: c } deriving (Generic, NFDataX)
-- > deriveAutoReg ''Pair
-- > deriveAutoReg ''Tup3
--
-- __NB__: Because of the way template haskell works the order here matters,
-- if you try to @deriveAutoReg ''Tup3@ before @Pair@ it will complain
-- about missing an @instance AutoReg (Pair a b)@.
deriveAutoReg :: Name -> DecsQ
deriveAutoReg tyNm = do
  tyInfo <- reifyDatatype tyNm
  -- Only types with exactly one constructor (product types) are supported;
  -- anything else makes the splice fail with an error message.
  case datatypeCons tyInfo of
    [] -> fail "Can't deriveAutoReg for empty types"
    [conInfo] -> deriveAutoRegProduct tyInfo conInfo
    _ -> fail "Can't deriveAutoReg for sum types"



{-
For a type like:
   data Product a b .. = MkProduct { getA :: a, getB :: b, .. }
This generates the following instance:

instance (AutoReg a, AutoReg b, ..) => AutoReg (Product a b ..) where
  autoReg clk rst en initVal input =
    MkProduct <$> sig0 <*> sig1 ...
    where
      field0 = (\(MkProduct x _ ...) -> x) <$> input
      field1 = (\(MkProduct _ x ...) -> x) <$> input
      ...
      MkProduct initVal0 initVal1 ... = initVal
      sig0 = suffixNameP @"getA" autoReg clk rst en initVal0 field0
      sig1 = suffixNameP @"getB" autoReg clk rst en initVal1 field1
      ...
-}
-- Builds the 'AutoReg' instance declaration for a single-constructor type.
--
-- The first argument describes the data type (name and type variables), the
-- second its only constructor (name, fields and, for records, field names).
-- The result is a single instance declaration whose context is computed by
-- 'calculateRequiredContext'.
deriveAutoRegProduct :: DatatypeInfo -> ConstructorInfo -> DecsQ
deriveAutoRegProduct tyInfo conInfo = go (constructorName conInfo) fieldInfos
 where
  tyNm = datatypeName tyInfo
  tyVarBndrs = datatypeVars tyInfo

#if MIN_VERSION_th_abstraction(0,3,0)
  toTyVar = VarT . bndrName
#else
  toTyVar t = case t of
    VarT _ -> t
    SigT t' _ -> toTyVar t'
    _ -> error "deriveAutoRegProduct.toTv"
#endif

  -- The instance head: the type constructor applied to all its type variables.
  tyVars = map toTyVar tyVarBndrs
  ty = conAppsT tyNm tyVars

  -- Pair every field type with its record field name, if it has one.
  fieldInfos =
    zip fieldNames (constructorFields conInfo)
   where
    fieldNames =
      case constructorVariant conInfo of
        RecordConstructor nms -> map Just nms
        _ -> repeat Nothing

  go :: Name -> [(Maybe Name,Type)] -> Q [Dec]
  go dcNm fields = do
    -- Fresh names for the five arguments of the generated 'autoReg'.
    args <- mapM newName ["clk", "rst", "en", "initVal", "input"]
    let
      [clkE, rstE, enE, initValE, inputE] = map varE args
      argsP = map varP args
      fieldNames = map fst fields

      -- "nm = (\(Dc _ _ .. x _ ..) -> x) <$> input", selecting field number nr
      field :: Name -> Int -> DecQ
      field nm nr =
        valD (varP nm) (normalB [| $fieldSel <$> $inputE |]) []
       where
        fieldSel = do
          xNm <- newName "x"
          let fieldP = [ if nr == n then varP xNm else wildP
                       | (n,_) <- zip [0..] fields]
          lamE [conP dcNm fieldP] (varE xNm)   -- "\(Dc _ _ .. x _ ..) -> x"

    parts <- generateNames "field" fields
    fieldDecls <- sequence $ zipWith field parts [0..]
    sigs <- generateNames "sig" fields
    initVals <- generateNames "initVal" fields
    -- "Dc initVal0 initVal1 .. = initVal"
    let initPat = conP dcNm (map varP initVals)
    initDecl <- valD initPat (normalB initValE) []

    let
      -- "sigN = suffixNameP @"fieldName" autoReg clk rst en initValN fieldN",
      -- or with 'id' in place of 'suffixNameP' when the field has no name.
      genAutoRegDecl :: PatQ -> ExpQ -> ExpQ -> Maybe Name -> DecsQ
      genAutoRegDecl s v i nameM =
        [d| $s = $nameMe autoReg $clkE $rstE $enE $i $v |]
       where
        nameMe = case nameM of
          Nothing -> [| id |]
          Just nm -> let nmSym = litT $ strTyLit (nameBase nm)
                     in [| suffixNameP @($nmSym) |]

    partDecls <- concat <$> (sequence $ zipWith4 genAutoRegDecl
                                                 (varP <$> sigs)
                                                 (varE <$> parts)
                                                 (varE <$> initVals)
                                                 (fieldNames)
                            )
    let
        decls :: [DecQ]
        decls = map pure (initDecl : fieldDecls ++ partDecls)
        tyConE = conE dcNm
        -- "Dc <$> sig0 <*> sig1 <*> .."
        body =
          case map varE sigs of
            (sig0:rest) -> foldl
                             (\acc sigN -> [| $acc <*> $sigN |])
                             [| $tyConE <$> $sig0 |]
                             rest
            -- A constructor without fields has nothing to register. The method
            -- must still return a 'Signal', so lift the constructor with
            -- 'pure'; the bare constructor would not typecheck.
            [] -> [| pure $tyConE |]

    autoRegDec <- funD 'autoReg [clause argsP (normalB body) decls]
    ctx <- calculateRequiredContext conInfo
    return [InstanceD Nothing ctx (AppT (ConT ''AutoReg) ty) [autoRegDec]]

-- Calculate the required constraint to call autoReg on all the fields of a
-- given constructor
calculateRequiredContext :: ConstructorInfo -> Q Cxt
calculateRequiredContext conInfo = do
  let fieldTys = constructorFields conInfo
  -- Ask for the 'AutoReg' constraints of every distinct field type ...
  wantedInstances <- mapM (\ty -> constraintsWantedFor ''AutoReg [ty]) (nub fieldTys)
  -- ... and merge them, dropping duplicate constraints.
  return $ nub (concat wantedInstances)

-- Computes the constraints that have to be put in an instance context to be
-- able to use class @clsNm@ at the argument types @tys@.
--
--   * A bare type variable yields the constraint @clsNm var@.
--   * A bare type constructor yields no constraint.
--   * For any other single type the matching instance is looked up with
--     'reifyInstances'; its context is specialised to the type at hand,
--     constraints on type variables are kept and constraints on applied
--     types are resolved recursively.
--   * @KnownNat@ and multi-parameter constraints are passed on unchanged.
--
-- Fails in 'Q' when no instance, more than one instance, or an instance of an
-- unexpected shape is found.
constraintsWantedFor :: Name -> [Type] -> Q Cxt
constraintsWantedFor clsNm tys
  | show clsNm == "GHC.TypeNats.KnownNat" = do
  -- KnownNat is special, you can't just lookup instances with reifyInstances.
  -- So we just pass KnownNat constraints.
  -- This will most likely require UndecidableInstances.
    return [conAppsT clsNm tys]

constraintsWantedFor clsNm [ty] = case ty of
  VarT _ -> return [AppT (ConT clsNm) ty]
  ConT _ -> return []
  _ -> do
    insts <- reifyInstances clsNm [ty]
    case insts of
      [InstanceD _ cxtInst (AppT autoRegCls instTy) _]
        | autoRegCls == ConT clsNm -> do
          -- Specialise the instance context to the concrete type 'ty'.
          let substs = findTyVarSubsts instTy ty
              cxt2 = map (applyTyVarSubsts substs) cxtInst
              okCxt = filter isOk cxt2
              recurseCxt = filter needRecurse cxt2
          recursed <- mapM recurse recurseCxt
          return (okCxt ++ concat recursed)
      []      -> fail $ "Missing instance " ++ show clsNm ++ " (" ++ pprint ty ++ ")"
      (_:_:_) -> fail $ "There are multiple " ++ show clsNm ++ " instances for "
                     ++ pprint ty ++ ":\n" ++ pprint insts
      _ -> fail $ "Got unexpected instance: " ++ pprint insts
 where
  -- Constraints that can be copied into the resulting context as they are.
  isOk :: Type -> Bool
  isOk (unfoldType -> (_cls,tys)) =
    case tys of
      [VarT _] -> True
      [_] -> False
      _ -> True -- see [NOTE: MultiParamTypeClasses]
  -- Constraints on an applied type, which have to be resolved further.
  needRecurse :: Type -> Bool
  needRecurse (unfoldType -> (cls,tys)) =
    case tys of
      [VarT _] -> False
      [ConT _] -> False  -- we can just drop constraints like: "AutoReg Bool => ..."
      [AppT _ _] -> True
      [_] -> error ( "Error while deriveAutoReg: don't know how to handle: "
                  ++ pprint cls ++ " (" ++ pprint tys ++ ")" )
      _ -> False  -- see [NOTE: MultiParamTypeClasses]

  recurse :: Type -> Q Cxt
  recurse (unfoldType -> (ConT cls,tys)) = constraintsWantedFor cls tys
  recurse t =
    fail ("Expected a class applied to some arguments but got " ++ pprint t)

constraintsWantedFor clsNm tys =
  return [conAppsT clsNm tys] -- see [NOTE: MultiParamTypeClasses]

-- [NOTE: MultiParamTypeClasses]
-- The constraint calculation code doesn't handle MultiParamTypeClasses
-- "properly", but it will try to pass them on, so the resulting instance should
-- still compile with UndecidableInstances enabled.


-- | Find tyVar substitutions between a general type and a second possibly less
-- general type. For example:
--
-- @
-- findTyVarSubsts "Either a b" "Either c [Bool]"
--   == "[(a,c), (b,[Bool])]"
-- @
findTyVarSubsts :: Type -> Type -> [(Name,Type)]
findTyVarSubsts = go
 where
  -- Walk both types in lockstep. A type variable on the left is bound to
  -- whatever is at the same position on the right (unless it is that same
  -- variable); matching constructors are traversed or skipped. The result is
  -- not deduplicated. Types whose shapes do not line up are a fatal 'error'.
  go ty1 ty2 = case (ty1,ty2) of
    (VarT nm1       , VarT nm2) | nm1 == nm2 -> []
    (VarT nm        , t)                     -> [(nm,t)]
    (ConT _         , ConT _)                -> []
    (AppT x1 y1     , AppT x2 y2)            -> go x1 x2 ++ go y1 y2
    (SigT t1 k1     , SigT t2 k2)            -> go t1 t2 ++ go k1 k2
    (InfixT x1 _ y1 , InfixT x2 _ y2)        -> go x1 x2 ++ go y1 y2
    (UInfixT x1 _ y1, UInfixT x2 _ y2)       -> go x1 x2 ++ go y1 y2
    (ParensT x1     , ParensT x2)            -> go x1 x2

#if __GLASGOW_HASKELL__ >= 808
    (AppKindT t1 k1     , AppKindT t2 k2)      -> go t1 t2 ++ go k1 k2
    (ImplicitParamT _ x1, ImplicitParamT _ x2) -> go x1 x2
#endif

    -- Leaves without type variables: nothing to substitute.
    (PromotedT _          , PromotedT _          ) -> []
    (TupleT _             , TupleT _             ) -> []
    (UnboxedTupleT _      , UnboxedTupleT _      ) -> []
    (UnboxedSumT _        , UnboxedSumT _        ) -> []
    (ArrowT               , ArrowT               ) -> []
    (EqualityT            , EqualityT            ) -> []
    (ListT                , ListT                ) -> []
    (PromotedTupleT _     , PromotedTupleT _     ) -> []
    (PromotedNilT         , PromotedNilT         ) -> []
    (PromotedConsT        , PromotedConsT        ) -> []
    (StarT                , StarT                ) -> []
    (ConstraintT          , ConstraintT          ) -> []
    (LitT _               , LitT _               ) -> []
    (WildCardT            , WildCardT            ) -> []
    _ -> error $ unlines [ "findTyVarSubsts: Unexpected types"
                         , "ty1:", pprint ty1,"ty2:", pprint ty2]

-- | Apply a type-variable substitution to a type.
--
-- Every 'VarT' whose name has an entry in the substitution list is replaced
-- by the associated type (the first matching entry wins, as with 'lookup');
-- variables without an entry are left untouched. The replacement type is
-- inserted as-is, it is not traversed again.
--
-- Applications, kind signatures, infix applications and parenthesised types
-- are traversed recursively. Leaf constructors that cannot contain a type
-- variable (constructors, tuples, arrows, lists, literals, ...) are returned
-- unchanged. The set of supported constructors mirrors the ones that
-- 'findTyVarSubsts' accepts on all supported compilers.
--
-- Calls 'error' for any remaining constructor (e.g. 'ForallT'), because
-- substituting under binders would require handling of name shadowing.
applyTyVarSubsts :: [(Name,Type)] -> Type -> Type
applyTyVarSubsts substs = go
  where
    go ty' = case ty' of
      -- The only place where a substitution actually happens.
      VarT n            -> fromMaybe ty' (lookup n substs)

      -- Compound types: rebuild the node around the substituted children.
      AppT ty1 ty2      -> AppT (go ty1) (go ty2)
      SigT t k          -> SigT (go t) (go k)
      InfixT l op r     -> InfixT (go l) op (go r)
      UInfixT l op r    -> UInfixT (go l) op (go r)
      ParensT t         -> ParensT (go t)

      -- Leaves: these never contain type variables.
      ConT _            -> ty'
      PromotedT _       -> ty'
      TupleT _          -> ty'
      UnboxedTupleT _   -> ty'
      UnboxedSumT _     -> ty'
      ArrowT            -> ty'
      EqualityT         -> ty'
      ListT             -> ty'
      PromotedTupleT _  -> ty'
      PromotedNilT      -> ty'
      PromotedConsT     -> ty'
      StarT             -> ty'
      ConstraintT       -> ty'
      LitT _            -> ty'
      WildCardT         -> ty'

      -- Binders and anything else are not supported (yet).
      _ -> error $ "TODO applyTyVarSubsts: " ++ show ty'


-- | Create one fresh 'Name' per element of the given list. The elements
-- themselves are ignored, only their count matters. The name for the element
-- at (zero-based) position @i@ is requested from 'newName' using the prefix
-- followed by @i@:
-- prefix0_.., prefix1_.., prefix2_.., ..
generateNames :: String -> [a] -> Q [Name]
generateNames prefix xs = traverse freshFor indices
  where
    -- One index per element of @xs@, counting up from zero.
    indices = zipWith const [0 :: Int ..] xs
    freshFor i = newName (prefix ++ show i)

-- | Run 'deriveAutoRegTuple' for every tuple size in the given list, in
-- order, and return all resulting declarations as one flat list.
deriveAutoRegTuples :: [Int] -> DecsQ
deriveAutoRegTuples sizes = do
  decsPerSize <- traverse deriveAutoRegTuple sizes
  pure (concat decsPerSize)

-- | Call 'deriveAutoReg' on the tuple type constructor of the given size.
--
-- The constructor is referred to by its textual name, i.e. an opening
-- parenthesis, @n-1@ commas and a closing parenthesis. Sizes smaller than 2
-- are rejected with 'fail'.
deriveAutoRegTuple :: Int -> DecsQ
deriveAutoRegTuple n
  | n >= 2    = deriveAutoReg tupleName
  | otherwise =
      fail ("deriveAutoRegTuple doesn't work for " ++ show n ++ "-tuples")
  where
    -- E.g. "(,,)" for n = 3.
    tupleName = mkName (concat ["(", replicate (n - 1) ',', ")"])