{-# LANGUAGE ConstraintKinds#-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module What4.BaseTypes
  ( 
    type BaseType
    
  , BaseBoolType
  , BaseIntegerType
  , BaseNatType
  , BaseRealType
  , BaseStringType
  , BaseBVType
  , BaseFloatType
  , BaseComplexType
  , BaseStructType
  , BaseArrayType
    
  , StringInfo
    
  , Char8
  , Char16
  , Unicode
    
  , type FloatPrecision
  , type FloatPrecisionBits
    
  , FloatingPointPrecision
    
  , Prec16
  , Prec32
  , Prec64
  , Prec80
  , Prec128
    
  , BaseTypeRepr(..)
  , FloatPrecisionRepr(..)
  , StringInfoRepr(..)
  , arrayTypeIndices
  , arrayTypeResult
  , floatPrecisionToBVType
  , lemmaFloatPrecisionIsPos
  , module Data.Parameterized.NatRepr
    
  , KnownRepr(..)  
  , KnownCtx
  ) where
import           Data.Hashable
import           Data.Kind
import           Data.Parameterized.Classes
import qualified Data.Parameterized.Context as Ctx
import           Data.Parameterized.NatRepr
import           Data.Parameterized.TH.GADT
import           GHC.TypeNats as TypeNats
import           Text.PrettyPrint.ANSI.Leijen
-- | @KnownCtx f ctx@ is shorthand for @KnownRepr (Ctx.Assignment f) ctx@:
-- the constraint under which a whole @Ctx.Assignment f ctx@ can be obtained
-- from 'knownRepr'.
--
-- The synonym takes only @f@, so it may be used partially applied
-- (as @KnownCtx f@).
type KnownCtx (f :: k -> Type) = KnownRepr (Ctx.Assignment f)
-- | Character-set tag for string types.  Used, promoted, as the parameter of
-- 'BaseStringType' and as the index of 'StringInfoRepr'.
data StringInfo where
  -- | Tag for 8-bit character strings.
  Char8 :: StringInfo
  -- | Tag for 16-bit character strings.
  Char16 :: StringInfo
  -- | Tag for Unicode strings.
  Unicode :: StringInfo

-- | Type-level alias for the promoted 'Char8' constructor.
type Char8   = 'Char8
-- | Type-level alias for the promoted 'Char16' constructor.
type Char16  = 'Char16
-- | Type-level alias for the promoted 'Unicode' constructor.
type Unicode = 'Unicode
-- | The base types.  Values of this type are used at the type level
-- (promoted via DataKinds) as the index of 'BaseTypeRepr'.
data BaseType where
  -- | Booleans.
  BaseBoolType :: BaseType
  -- | Natural numbers.
  BaseNatType :: BaseType
  -- | Integers.
  BaseIntegerType :: BaseType
  -- | Real numbers.
  BaseRealType :: BaseType
  -- | Bitvectors indexed by their width.  'BaseBVRepr' only admits widths of
  -- at least 1.
  BaseBVType :: TypeNats.Nat -> BaseType
  -- | Floating-point values of the given precision.
  BaseFloatType :: FloatPrecision -> BaseType
  -- | Strings over the given character-set tag.
  BaseStringType :: StringInfo -> BaseType
  -- | Complex numbers.
  BaseComplexType :: BaseType
  -- | Structures whose fields have the base types listed in the context.
  BaseStructType :: Ctx.Ctx BaseType -> BaseType
  -- | Arrays: the context lists the index types and the second argument is
  -- the result type (see 'arrayTypeIndices' and 'arrayTypeResult').
  -- 'BaseArrayRepr' requires at least one index type.
  BaseArrayType :: Ctx.Ctx BaseType -> BaseType -> BaseType

-- Type-level aliases for the promoted 'BaseType' constructors, so use sites
-- need no tick.
type BaseBoolType    = 'BaseBoolType
type BaseIntegerType = 'BaseIntegerType
type BaseNatType     = 'BaseNatType
type BaseRealType    = 'BaseRealType
type BaseBVType      = 'BaseBVType
type BaseFloatType   = 'BaseFloatType
type BaseStringType  = 'BaseStringType
type BaseComplexType = 'BaseComplexType
type BaseStructType  = 'BaseStructType
type BaseArrayType   = 'BaseArrayType
-- | A floating-point precision: an exponent width @eb@ followed by a
-- significand width @sb@.  Used, promoted, as the index of 'BaseFloatType'
-- and 'FloatPrecisionRepr'.
data FloatPrecision
  = FloatingPointPrecision
      TypeNats.Nat -- exponent width (eb)
      TypeNats.Nat -- significand width (sb)

-- | Type-level alias for the promoted 'FloatingPointPrecision' constructor.
type FloatingPointPrecision = 'FloatingPointPrecision

-- | Total bit count of a floating-point precision: @eb + sb@.
type family FloatPrecisionBits (fpp :: FloatPrecision) :: TypeNats.Nat where
  FloatPrecisionBits ('FloatingPointPrecision eb sb) = eb + sb

-- Named precisions.  In each @PrecNN@ the two widths sum to @NN@.

-- | 16-bit precision: 5 exponent bits, 11 significand bits.
type Prec16  = 'FloatingPointPrecision  5  11
-- | 32-bit precision: 8 exponent bits, 24 significand bits.
type Prec32  = 'FloatingPointPrecision  8  24
-- | 64-bit precision: 11 exponent bits, 53 significand bits.
type Prec64  = 'FloatingPointPrecision 11  53
-- | 80-bit precision: 15 exponent bits, 65 significand bits.
type Prec80  = 'FloatingPointPrecision 15  65
-- | 128-bit precision: 15 exponent bits, 113 significand bits.
type Prec128 = 'FloatingPointPrecision 15 113
-- | Value-level representative of a 'BaseType'.  Pattern matching on a
-- 'BaseTypeRepr' refines its index @bt@ to the corresponding base type.
--
-- Every constructor that carries a payload stores it strictly.  The
-- constructor order is significant: the Template Haskell generated
-- 'compareF', 'hashWithSalt' and 'showsPrec' instances follow it.
data BaseTypeRepr (bt::BaseType) :: Type where
   BaseBoolRepr    :: BaseTypeRepr BaseBoolType
   -- | Bitvectors; the width must be at least 1.
   BaseBVRepr      :: (1 <= w) => !(NatRepr w) -> BaseTypeRepr (BaseBVType w)
   BaseNatRepr     :: BaseTypeRepr BaseNatType
   BaseIntegerRepr :: BaseTypeRepr BaseIntegerType
   BaseRealRepr    :: BaseTypeRepr BaseRealType
   BaseFloatRepr   :: !(FloatPrecisionRepr fpp) -> BaseTypeRepr (BaseFloatType fpp)
   -- Strict, like every other payload field of this type (previously the only
   -- lazy one).
   BaseStringRepr  :: !(StringInfoRepr si) -> BaseTypeRepr (BaseStringType si)
   BaseComplexRepr :: BaseTypeRepr BaseComplexType
   -- | Structures: one repr per field.
   BaseStructRepr :: !(Ctx.Assignment BaseTypeRepr ctx)
                  -> BaseTypeRepr (BaseStructType ctx)
   -- | Arrays: the index reprs (at least one, enforced by the @::>@ in the
   -- index) and the result repr.
   BaseArrayRepr :: !(Ctx.Assignment BaseTypeRepr (idx Ctx.::> tp))
                 -> !(BaseTypeRepr xs)
                 -> BaseTypeRepr (BaseArrayType (idx Ctx.::> tp) xs)
-- | Value-level representative of a 'FloatPrecision'.  Both the exponent width
-- and the significand width must be at least 2.
data FloatPrecisionRepr (fpp :: FloatPrecision) where
  FloatingPointPrecisionRepr
    :: forall eb sb
     . (2 <= eb, 2 <= sb)
    => !(NatRepr eb)   -- exponent width
    -> !(NatRepr sb)   -- significand width
    -> FloatPrecisionRepr ('FloatingPointPrecision eb sb)
-- | Value-level representative of a 'StringInfo' tag; one constructor per tag,
-- in the same order as 'StringInfo'.
data StringInfoRepr (si::StringInfo) where
  Char8Repr   :: StringInfoRepr 'Char8
  Char16Repr  :: StringInfoRepr 'Char16
  UnicodeRepr :: StringInfoRepr 'Unicode
-- | The reprs of the index types of an array type.
arrayTypeIndices :: BaseTypeRepr (BaseArrayType idx tp)
                 -> Ctx.Assignment BaseTypeRepr idx
arrayTypeIndices repr =
  case repr of
    BaseArrayRepr indexReprs _ -> indexReprs
-- | The repr of the result type of an array type.
arrayTypeResult :: BaseTypeRepr (BaseArrayType idx tp) -> BaseTypeRepr tp
arrayTypeResult repr =
  case repr of
    BaseArrayRepr _ resultRepr -> resultRepr
-- | The bitvector type whose width is the total width @eb + sb@ of the given
-- floating-point precision.
floatPrecisionToBVType
  :: FloatPrecisionRepr (FloatingPointPrecision eb sb)
  -> BaseTypeRepr (BaseBVType (eb + sb))
floatPrecisionToBVType fpp@(FloatingPointPrecisionRepr expW sigW) =
  -- 'BaseBVRepr' needs @1 <= eb + sb@; the lemma provides it.
  case lemmaFloatPrecisionIsPos fpp of
    LeqProof -> BaseBVRepr (addNat expW sigW)
-- | Proof that a floating-point precision has a positive total width.
--
-- 'FloatingPointPrecisionRepr' carries @2 <= eb'@ and @2 <= sb'@.  Chaining
-- each with @1 <= 2@ gives @1 <= eb'@ and @1 <= sb'@, and 'leqAddPos' turns
-- those into @1 <= eb' + sb'@.
lemmaFloatPrecisionIsPos
  :: forall eb' sb'
   . FloatPrecisionRepr (FloatingPointPrecision eb' sb')
  -> LeqProof 1 (eb' + sb')
lemmaFloatPrecisionIsPos (FloatingPointPrecisionRepr expW sigW) =
  case leqTrans oneLeqTwo (LeqProof @2 @eb') of
    LeqProof ->
      case leqTrans oneLeqTwo (LeqProof @2 @sb') of
        LeqProof -> leqAddPos expW sigW
  where
    oneLeqTwo :: LeqProof 1 2
    oneLeqTwo = LeqProof
-- 'KnownRepr' instances for 'BaseTypeRepr', one per base type, listed in the
-- order of the 'BaseTypeRepr' constructors.  Parameterised base types also
-- require their parameters to be known.

instance KnownRepr BaseTypeRepr BaseBoolType where
  knownRepr = BaseBoolRepr

instance (1 <= w, KnownNat w)
      => KnownRepr BaseTypeRepr (BaseBVType w) where
  knownRepr = BaseBVRepr knownNat

instance KnownRepr BaseTypeRepr BaseNatType where
  knownRepr = BaseNatRepr

instance KnownRepr BaseTypeRepr BaseIntegerType where
  knownRepr = BaseIntegerRepr

instance KnownRepr BaseTypeRepr BaseRealType where
  knownRepr = BaseRealRepr

instance KnownRepr FloatPrecisionRepr fpp
      => KnownRepr BaseTypeRepr (BaseFloatType fpp) where
  knownRepr = BaseFloatRepr knownRepr

instance KnownRepr StringInfoRepr si
      => KnownRepr BaseTypeRepr (BaseStringType si) where
  knownRepr = BaseStringRepr knownRepr

instance KnownRepr BaseTypeRepr BaseComplexType where
  knownRepr = BaseComplexRepr

instance KnownRepr (Ctx.Assignment BaseTypeRepr) ctx
      => KnownRepr BaseTypeRepr (BaseStructType ctx) where
  knownRepr = BaseStructRepr knownRepr

-- Only arrays with at least one index type are covered, matching the shape
-- accepted by 'BaseArrayRepr'.
instance ( KnownRepr (Ctx.Assignment BaseTypeRepr) idx
         , KnownRepr BaseTypeRepr tp
         , KnownRepr BaseTypeRepr res
         )
      => KnownRepr BaseTypeRepr (BaseArrayType (idx Ctx.::> tp) res) where
  knownRepr = BaseArrayRepr knownRepr knownRepr
-- A 'FloatPrecisionRepr' is known when both widths are known naturals of at
-- least 2 (the constraints required by 'FloatingPointPrecisionRepr').
instance ( 2 <= eb
         , 2 <= sb
         , KnownNat eb
         , KnownNat sb
         )
      => KnownRepr FloatPrecisionRepr (FloatingPointPrecision eb sb) where
  knownRepr = FloatingPointPrecisionRepr knownNat knownNat
-- Each 'StringInfo' tag has a statically known representative.
instance KnownRepr StringInfoRepr 'Char8 where
  knownRepr = Char8Repr

instance KnownRepr StringInfoRepr 'Char16 where
  knownRepr = Char16Repr

instance KnownRepr StringInfoRepr 'Unicode where
  knownRepr = UnicodeRepr
-- Template Haskell: this empty declaration splice ends the current
-- declaration group.  The 'structural*' splices in the instances below can
-- then reify the data types declared above.
$(return [])
-- Hashing.  Each 'Hashable' instance is generated structurally by Template
-- Haskell, and the matching 'HashableF' instance reuses it.

instance Hashable (BaseTypeRepr bt) where
  hashWithSalt = $(structuralHashWithSalt [t|BaseTypeRepr|] [])
instance HashableF BaseTypeRepr where
  hashWithSaltF = hashWithSalt

instance Hashable (FloatPrecisionRepr fpp) where
  hashWithSalt = $(structuralHashWithSalt [t|FloatPrecisionRepr|] [])
instance HashableF FloatPrecisionRepr where
  hashWithSaltF = hashWithSalt

instance Hashable (StringInfoRepr si) where
  hashWithSalt = $(structuralHashWithSalt [t|StringInfoRepr|] [])
instance HashableF StringInfoRepr where
  hashWithSaltF = hashWithSalt
-- Rendering.  'Show' instances are generated structurally by Template
-- Haskell.  'ShowF' uses its default methods, and 'Pretty' renders the 'show'
-- output as text.

instance Show (BaseTypeRepr bt) where
  showsPrec = $(structuralShowsPrec [t|BaseTypeRepr|])
instance ShowF BaseTypeRepr
instance Pretty (BaseTypeRepr bt) where
  pretty repr = text (show repr)

instance Show (FloatPrecisionRepr fpp) where
  showsPrec = $(structuralShowsPrec [t|FloatPrecisionRepr|])
instance ShowF FloatPrecisionRepr
instance Pretty (FloatPrecisionRepr fpp) where
  pretty repr = text (show repr)

instance Show (StringInfoRepr si) where
  showsPrec = $(structuralShowsPrec [t|StringInfoRepr|])
instance ShowF StringInfoRepr
instance Pretty (StringInfoRepr si) where
  pretty repr = text (show repr)
-- Equality and ordering for 'BaseTypeRepr', generated by Template Haskell.
-- Each pair maps a field-type pattern to the function that compares fields of
-- that type: widths ('NatRepr'), precisions, string tags, nested reprs, and
-- 'Ctx.Assignment's of reprs (struct fields and array indices).

instance TestEquality BaseTypeRepr where
  testEquality =
    $(structuralTypeEquality [t|BaseTypeRepr|]
        [ (TypeApp (ConType [t|NatRepr|])            AnyType, [|testEquality|])
        , (TypeApp (ConType [t|FloatPrecisionRepr|]) AnyType, [|testEquality|])
        , (TypeApp (ConType [t|StringInfoRepr|])     AnyType, [|testEquality|])
        , (TypeApp (ConType [t|BaseTypeRepr|])       AnyType, [|testEquality|])
        , ( TypeApp (TypeApp (ConType [t|Ctx.Assignment|]) AnyType) AnyType
          , [|testEquality|]
          )
        ])

instance OrdF BaseTypeRepr where
  compareF =
    $(structuralTypeOrd [t|BaseTypeRepr|]
        [ (TypeApp (ConType [t|NatRepr|])            AnyType, [|compareF|])
        , (TypeApp (ConType [t|FloatPrecisionRepr|]) AnyType, [|compareF|])
        , (TypeApp (ConType [t|StringInfoRepr|])     AnyType, [|compareF|])
        , (TypeApp (ConType [t|BaseTypeRepr|])       AnyType, [|compareF|])
        , ( TypeApp (TypeApp (ConType [t|Ctx.Assignment|]) AnyType) AnyType
          , [|compareF|]
          )
        ])
-- Equality and ordering for 'FloatPrecisionRepr', generated by Template
-- Haskell.  Both width fields are 'NatRepr's and are compared with
-- 'testEquality' / 'compareF'.
instance TestEquality FloatPrecisionRepr where
  testEquality =
    $(structuralTypeEquality [t|FloatPrecisionRepr|]
        [ (TypeApp (ConType [t|NatRepr|]) AnyType, [|testEquality|]) ])

instance OrdF FloatPrecisionRepr where
  compareF =
    $(structuralTypeOrd [t|FloatPrecisionRepr|]
        [ (TypeApp (ConType [t|NatRepr|]) AnyType, [|compareF|]) ])
-- Equality and ordering for 'StringInfoRepr', generated by Template Haskell.
-- All constructors are nullary, so no field comparators are supplied.
instance TestEquality StringInfoRepr where
  testEquality =
    $(structuralTypeEquality [t|StringInfoRepr|] [])

instance OrdF StringInfoRepr where
  compareF =
    $(structuralTypeOrd [t|StringInfoRepr|] [])