{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Clash.Primitives.Util
  ( generatePrimMap
  , hashCompiledPrimMap
  , constantArgs
  , decodeOrErr
  , getFunctionPlurality
  ) where
import           Control.DeepSeq        (force)
import           Control.Monad          (join)
import           Data.Aeson.Extra       (decodeOrErr)
import qualified Data.ByteString.Lazy   as LZ
import qualified Data.HashMap.Lazy      as HashMap
import qualified Data.HashMap.Strict    as HashMapStrict
import qualified Data.Set               as Set
import           Data.Hashable          (hash)
import           Data.List              (find, foldl', isSuffixOf, sort)
import           Data.Maybe             (fromMaybe)
import qualified Data.Text              as TS
import           Data.Text.Lazy         (Text)
import qualified Data.Text.Lazy.IO      as T
import           GHC.Stack              (HasCallStack)
import qualified System.Directory       as Directory
import qualified System.FilePath        as FilePath
import           System.IO.Error        (tryIOError)

import           Clash.Annotations.Primitive
  ( PrimitiveGuard(HasBlackBox, WarnNonSynthesizable, WarnAlways, DontTranslate)
  , extractPrim)
import           Clash.Core.Term        (Term)
import           Clash.Core.Type        (Type)
import           Clash.Primitives.Types
  ( Primitive(BlackBox), CompiledPrimitive, ResolvedPrimitive, ResolvedPrimMap
  , includes, template, TemplateSource(TFile, TInline), Primitive(..)
  , UnresolvedPrimitive, CompiledPrimMap, GuardedResolvedPrimitive)
import           Clash.Netlist.Types    (BlackBox(..), NetlistMonad)
import           Clash.Netlist.Util     (preserveState)
import           Clash.Netlist.BlackBox.Util
  (walkElement)
import           Clash.Netlist.BlackBox.Types
  (Element(Const, Lit), BlackBoxMeta(..))
-- | Compute a hash for a single compiled primitive.
--
-- * 'Primitive': hash of its @name@ and @primSort@.
-- * 'BlackBoxHaskell': the 'Int' stored as the first component of its
--   @function@ field, used as-is.
-- * 'BlackBox': hash of @name@, @kind@, @outputReg@, @libraries@, @imports@,
--   @includes@ and @template@. Fields that are not listed here do not
--   influence the hash.
hashCompiledPrimitive :: CompiledPrimitive -> Int
hashCompiledPrimitive prim = case prim of
  Primitive {name, primSort} ->
    hash (name, primSort)
  BlackBoxHaskell {function} ->
    fst function
  BlackBox {name, kind, outputReg, libraries, imports, includes, template} ->
    let includeHashes = [ (nms, blackBoxHash bb) | (nms, bb) <- includes ]
    in  hash ( name, kind, outputReg, libraries, imports
             , includeHashes, blackBoxHash template )
 where
  -- A template is hashed directly; a blackbox function is identified by its
  -- name together with its stored hash (the function value itself is ignored).
  blackBoxHash (BBTemplate bbTemplate)            = hash bbTemplate
  blackBoxHash (BBFunction bbName bbHash _bbFunc) = hash (bbName, bbHash)
-- | Hash a complete primitive map.
--
-- The (guarded) primitives are visited in ascending key order before their
-- individual hashes are combined, so the outcome is independent of the order
-- in which the hash map happens to enumerate its entries.
hashCompiledPrimMap :: CompiledPrimMap -> Int
hashCompiledPrimMap cpm =
  hash [ hashCompiledPrimitive <$> (cpm HashMapStrict.! key)
       | key <- sort (HashMap.keys cpm)
       ]
-- | Obtain the text of a template source.
--
-- An inline source ('TInline') is returned unchanged. For a 'TFile' the file
-- name component of the meta path is replaced by the given path and the
-- resulting file is read. An 'IOError' caught by 'tryIOError' is turned into
-- a call to 'error'; that call sits inside the returned value, so it is only
-- raised once the resulting text is demanded.
resolveTemplateSource
  :: HasCallStack
  => FilePath
  -- ^ Meta path; its directory is the base for 'TFile' paths
  -> TemplateSource
  -- ^ Source to resolve
  -> IO Text
resolveTemplateSource metaPath source = case source of
  TInline text ->
    return text
  TFile path -> do
    let templatePath = FilePath.replaceFileName metaPath path
    contents <- tryIOError (T.readFile templatePath)
    return (either (error . show) id contents)
-- | Resolve every template source inside one primitive and return the result,
-- wrapped in a guard, together with the primitive's name.
--
-- * 'Primitive' carries no template sources; it is rebuilt unchanged and
--   guarded with 'HasBlackBox'.
-- * 'BlackBox': the sources in @includes@, @resultName@, @resultInit@ and
--   @template@ are resolved (in that order) with 'resolveTemplateSource'.
--   If the @warning@ field is set the result is guarded with
--   'WarnNonSynthesizable' carrying that warning, otherwise with
--   'HasBlackBox'. The resolved primitive stores @()@ in place of the warning.
-- * 'BlackBoxHaskell': its optional template source is resolved and the
--   result is guarded with 'HasBlackBox'.
resolvePrimitive'
  :: HasCallStack
  => FilePath
  -- ^ Meta path handed to 'resolveTemplateSource'
  -> UnresolvedPrimitive
  -> IO (TS.Text, GuardedResolvedPrimitive)
resolvePrimitive' _metaPath (Primitive name wf primType) =
  pure (name, HasBlackBox (Primitive name wf primType))
resolvePrimitive' metaPath BlackBox{template=tmpl, includes=incs, resultName=resNm, resultInit=resInit, ..} = do
  let resolveNested = traverse (traverse (resolveTemplateSource metaPath))
  incs'    <- mapM (traverse resolveNested) incs
  resNm'   <- traverse resolveNested resNm
  resInit' <- traverse resolveNested resInit
  tmpl'    <- resolveNested tmpl
  let bb = BlackBox name workInfo renderVoid kind () outputReg libraries imports
                    functionPlurality incs' resNm' resInit' tmpl'
  case warning of
    Nothing  -> pure (name, HasBlackBox bb)
    Just msg -> pure (name, WarnNonSynthesizable (TS.unpack msg) bb)
resolvePrimitive' metaPath (BlackBoxHaskell bbName wf usedArgs funcName tmplSource) = do
  resolved <- mapM (resolveTemplateSource metaPath) tmplSource
  pure (bbName, HasBlackBox (BlackBoxHaskell bbName wf usedArgs funcName resolved))
-- | Read the given file, decode its contents with 'decodeOrErr' into a list
-- of unresolved primitives, and resolve each of them with
-- 'resolvePrimitive''. The file's own path serves as the meta path.
resolvePrimitive
  :: HasCallStack
  => FilePath
  -> IO [(TS.Text, GuardedResolvedPrimitive)]
resolvePrimitive fileName =
  LZ.readFile fileName
    >>= traverse (resolvePrimitive' fileName) . decodeOrErr fileName
-- | Attach guard annotations to the primitives in a map.
--
-- The @(name, guard)@ pairs are processed left to right; for each pair the
-- entry stored under @name@ is replaced:
--
-- * For 'HasBlackBox', 'WarnNonSynthesizable' and 'WarnAlways' the map must
--   already hold a primitive for @name@ (as reported by 'extractPrim'). That
--   primitive is re-wrapped in the given guard. If there is none, 'error' is
--   called.
-- * For 'DontTranslate' the map must /not/ hold a primitive for @name@; the
--   entry becomes 'DontTranslate'. If a primitive is present, 'error' is
--   called.
--
-- The fold is strict ('foldl''), so the map is updated step by step instead
-- of accumulating a chain of unevaluated inserts over the whole guard list.
addGuards
  :: ResolvedPrimMap
  -- ^ Map to extend
  -> [(TS.Text, PrimitiveGuard ())]
  -- ^ Guards to attach, keyed by primitive name
  -> ResolvedPrimMap
addGuards = foldl' go
 where
  -- The primitive stored under a name, if the entry exists and its guard
  -- actually carries one.
  lookupPrim :: TS.Text -> ResolvedPrimMap -> Maybe ResolvedPrimitive
  lookupPrim nm primMap = join (extractPrim <$> HashMapStrict.lookup nm primMap)

  go primMap (nm, guard) = HashMapStrict.insert nm guarded primMap
   where
    guarded = case (lookupPrim nm primMap, guard) of
      (Nothing, HasBlackBox _) ->
        error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
             ++ " though this value was annotated with 'HasBlackBox'."
      (Nothing, WarnNonSynthesizable _ _) ->
        error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
             ++ " though this value was annotated with 'WarnNonSynthesizable'"
             ++ ", implying it has a BlackBox."
      (Nothing, WarnAlways _ _) ->
        error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
             ++ " though this value was annotated with 'WarnAlways'"
             ++ ", implying it has a BlackBox."
      (Just _, DontTranslate) ->
        error (TS.unpack nm ++ " was annotated with DontTranslate, but a "
                               ++ "BlackBox definition was found anyway.")
      (Nothing, DontTranslate) -> DontTranslate
      -- Keep the shape of the annotated guard, but carry the real primitive.
      (Just p, g) -> p <$ g
-- | Build the map of resolved, guarded primitives.
--
-- Primitives come from two places: every entry ending in @.json@ that sits
-- directly inside one of the given directories (paths that are not existing
-- directories are silently skipped), and the primitives passed in directly.
-- Both lists are merged into one map, with the directly passed primitives
-- placed last in the list handed to 'HashMap.fromList'. The guards are then
-- applied with 'addGuards' and the final map is fully evaluated with 'force'
-- before it is returned.
generatePrimMap
  :: HasCallStack
  => [UnresolvedPrimitive]
  -- ^ Primitives supplied directly; each primitive's name doubles as the
  -- meta path when its template sources are resolved
  -> [(TS.Text, PrimitiveGuard ())]
  -- ^ Guards to attach, keyed by primitive name
  -> [FilePath]
  -- ^ Directories to scan for @.json@ primitive files
  -> IO ResolvedPrimMap
generatePrimMap unresolvedPrims primGuards filePaths = do
  primitiveFiles <- concat <$> mapM jsonFilesIn filePaths
  fromFiles      <- concat <$> mapM resolvePrimitive primitiveFiles
  fromArguments  <- traverse resolveByName unresolvedPrims
  let primMap = HashMap.fromList (fromFiles ++ fromArguments)
  return (force (addGuards primMap primGuards))
 where
  -- Full paths of all @.json@ entries of a directory; nothing when the path
  -- is not an existing directory.
  jsonFilesIn dir = do
    isDirectory <- Directory.doesDirectoryExist dir
    if isDirectory
      then do
        entries <- Directory.getDirectoryContents dir
        return [ FilePath.combine dir entry
               | entry <- entries
               , ".json" `isSuffixOf` entry
               ]
      else return []

  -- Resolve a directly supplied primitive, using its name as meta path.
  resolveByName prim = resolvePrimitive' (TS.unpack (name prim)) prim
{-# SCC generatePrimMap #-}
-- | Determine which argument positions of a primitive count as constants.
--
-- Only a 'BlackBox' whose @template@ is a 'BBTemplate' can produce a
-- non-empty set. For such a primitive the set consists of:
--
-- * the indices of all 'Lit' and 'Const' elements found (via 'walkElement')
--   in the template and, when present, in the @resultInit@ template;
-- * a fixed list of positions for a handful of primitives, selected by name.
--
-- Every other primitive yields the empty set.
constantArgs :: TS.Text -> CompiledPrimitive -> Set.Set Int
constantArgs nm BlackBox {template = templ@(BBTemplate _), resultInit = tRIM} =
  Set.fromList
    (forcedByName ++ foldMap templateConstants tRIM ++ templateConstants templ)
 where
  -- Indices mentioned by 'Lit' / 'Const' elements of a template; a blackbox
  -- that is not a template contributes nothing.
  templateConstants bb = case bb of
    BBTemplate elems -> concatMap (walkElement constantIndex) elems
    _                -> []

  constantIndex el = case el of
    Lit i   -> Just i
    Const i -> Just i
    _       -> Nothing

  -- Positions that are always included for specific primitives.
  forcedByName = fromMaybe [] (lookup nm forcedPositions)

  forcedPositions :: [(TS.Text, [Int])]
  forcedPositions =
    [ ("Clash.Sized.Internal.BitVector.fromInteger#",  [2])
    , ("Clash.Sized.Internal.BitVector.fromInteger##", [0,1])
    , ("Clash.Sized.Internal.Index.fromInteger#",      [1])
    , ("Clash.Sized.Internal.Signed.fromInteger#",     [1])
    , ("Clash.Sized.Internal.Unsigned.fromInteger#",   [1])
    , ("Clash.Sized.Vector.index_int",                 [1,2])
    , ("Clash.Sized.Vector.replace_int",               [1,2])
    ]
constantArgs _ _ = Set.empty
-- | Look up @n@ in a list of @(key, plurality)@ pairs. The first pair whose
-- key equals @n@ determines the result; when no pair matches the plurality
-- defaults to 1.
getFunctionPlurality' :: [(Int, Int)] -> Int -> Int
getFunctionPlurality' functionPlurality n = fromMaybe 1 plurality
 where
  plurality         = snd <$> find isEntryForN functionPlurality
  isEntryForN entry = fst entry == n
-- | Determine the function plurality recorded for key @n@ of a primitive.
--
-- * 'Primitive': always 1.
-- * 'BlackBox': looked up in its @functionPlurality@ field.
-- * 'BlackBoxHaskell': the blackbox function (second component of the
--   @function@ field) is run inside 'preserveState' with the given arguments
--   and result type; the plurality is looked up in the
--   @bbFunctionPlurality@ field of the metadata it returns. If the function
--   returns an error message instead, 'error' is called.
--
-- A key without an entry has plurality 1 (see 'getFunctionPlurality'').
getFunctionPlurality
  :: HasCallStack
  => CompiledPrimitive
  -> [Either Term Type]
  -- ^ Arguments handed to the blackbox function of a 'BlackBoxHaskell'
  -> Type
  -- ^ Result type handed to the blackbox function of a 'BlackBoxHaskell'
  -> Int
  -- ^ Key to look up in the plurality table (presumably the position of the
  -- function argument; confirm against callers)
  -> NetlistMonad Int
getFunctionPlurality prim args resTy n = case prim of
  Primitive {} ->
    pure 1
  BlackBox {functionPlurality} ->
    pure (getFunctionPlurality' functionPlurality n)
  BlackBoxHaskell {name, function, functionName} -> do
    outcome <- preserveState (snd function False name args resTy)
    case outcome of
      Right (BlackBoxMeta {bbFunctionPlurality}, _bb) ->
        pure (getFunctionPlurality' bbFunctionPlurality n)
      Left err ->
        error ( "Tried to determine function plurality for "
             ++ TS.unpack name ++ " by quering " ++ show functionName
             ++ ". Function returned an error message instead:\n\n"
             ++ err )