{-# LANGUAGE AllowAmbiguousTypes       #-}
{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE DerivingStrategies        #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE MagicHash                 #-}
{-# LANGUAGE TypeInType                #-}
{-# LANGUAGE UndecidableInstances      #-}
module Membrain.Memory
       ( 
         Memory (..)
       , memory
       , toMemory
       , showMemory
       , readMemory
         
       , toBits
       , toRat
       , floor
         
       , memoryMul
       , memoryDiff
       , memoryPlus
       , memoryDiv
         
         
       , AnyMemory (..)
       ) where
import Prelude hiding (floor)
import Data.Char (isDigit, isSpace)
import Data.Coerce (coerce)
import Data.Foldable (foldl')
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Ratio (Ratio, (%))
import Data.Semigroup (Semigroup (..))
import GHC.Exts (Proxy#, proxy#)
import GHC.Generics (Generic)
import GHC.TypeNats (KnownNat, Nat, natVal')
import Numeric.Natural (Natural)
import Membrain.Units (KnownUnitSymbol, unitSymbol)
import qualified Prelude
-- | An amount of memory, stored internally as a raw bit count.
--
-- The type-level natural @mem@ is the number of bits in one unit of the
-- value's unit: 'memory' multiplies by it when building a value, while
-- 'showMemory' and 'toRat' divide by it. It is a phantom parameter, so the
-- stored 'Natural' (returned unchanged by 'toBits') never depends on it.
newtype Memory (mem :: Nat) = Memory { unMemory :: Natural }
    deriving newtype (Eq, Ord)
    deriving stock (Generic, Read, Show)
-- | Values of the same unit combine by adding their bit counts.
instance Semigroup (Memory (mem :: Nat)) where
    (<>) :: Memory mem -> Memory mem -> Memory mem
    Memory x <> Memory y = Memory (x + y)
    {-# INLINE (<>) #-}

    -- Strict left fold seeded with zero bits, so no chain of addition
    -- thunks is built up.
    sconcat :: NonEmpty (Memory mem) -> Memory mem
    sconcat = foldl' (\acc next -> acc <> next) (Memory 0)
    {-# INLINE sconcat #-}

    -- One multiplication instead of repeated addition. The multiplier is
    -- converted to 'Natural' first, so a negative one is expected to throw.
    stimes :: Integral b => b -> Memory mem -> Memory mem
    stimes times (Memory x) = Memory (fromIntegral times * x)
    {-# INLINE stimes #-}
-- | The identity element holds zero bits.
instance Monoid (Memory (mem :: Nat)) where
    mempty :: Memory mem
    mempty = coerce (0 :: Natural)
    {-# INLINE mempty #-}

    -- Kept in canonical form so GHC does not flag a noncanonical instance.
    mappend :: Memory mem -> Memory mem -> Memory mem
    mappend = (<>)
    {-# INLINE mappend #-}

    -- Strict left fold, so summing a long list stays in constant space.
    mconcat :: [Memory mem] -> Memory mem
    mconcat = foldl' mappend mempty
    {-# INLINE mconcat #-}
-- | Renders a value in its own unit @mem@, followed by the unit symbol.
--
-- The whole number of units is printed first. Any remaining bits are expanded
-- as a decimal fraction by long division, one digit at a time, until the
-- remainder reaches zero. As a result there are no trailing zeros, and a
-- whole amount has no decimal point.
--
-- NOTE(review): the expansion only terminates when the unit size divides
-- some power of ten (i.e. it has no prime factors other than 2 and 5).
-- Presumably every 'KnownUnitSymbol' unit satisfies this; confirm in
-- "Membrain.Units".
showMemory
    :: forall mem
     . (KnownNat mem, KnownUnitSymbol mem)
    => Memory mem
    -> String
showMemory (Memory bits) = case bits `divMod` unitSize of
    (whole, 0)    -> show whole ++ symbol
    (whole, rest) -> show whole ++ "." ++ fractionDigits rest ++ symbol
  where
    unitSize :: Natural
    unitSize = nat @mem

    symbol :: String
    symbol = unitSymbol @mem

    -- Each step multiplies the remainder by ten; the quotient is the next
    -- decimal digit (always < 10, since the remainder is < unitSize).
    fractionDigits :: Natural -> String
    fractionDigits 0 = ""
    fractionDigits remainder =
        let (digit, remainder') = (remainder * 10) `divMod` unitSize
        in show digit ++ fractionDigits remainder'
-- | Parses a value written in unit @mem@, in the form
-- @\<digits\>[.\<digits\>]\<symbol\>@.
--
-- Leading whitespace is skipped. After that the input must contain one or
-- more ASCII digits, optionally followed by a @.@ and one or more further
-- digits. Everything after the digits must be exactly 'unitSymbol' for
-- @mem@; no whitespace is skipped there.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the input does not have that shape;
-- * the symbol does not match;
-- * the decimal amount does not correspond to a whole number of bits.
readMemory
    :: forall (mem :: Nat)
     . (KnownUnitSymbol mem, KnownNat mem)
    => String
    -> Maybe (Memory mem)
readMemory input = case span isDigit (dropWhile isSpace input) of
    ([], _) -> Nothing  -- no integral digits at all
    (_, []) -> Nothing  -- only digits, no unit symbol
    (intDigits, '.' : afterPoint) -> case span isDigit afterPoint of
        ([], _)              -> Nothing  -- a point must be followed by digits
        (fracDigits, symbol) -> fromDecimal intDigits fracDigits symbol
    (intDigits, symbol) -> fromDecimal intDigits "0" symbol
  where
    -- Exact conversion without floating point. The decimal is scaled by
    -- 10^k (k = number of fractional digits) to an integer and converted to
    -- bits. The result is accepted only when dividing back by 10^k leaves
    -- no remainder.
    fromDecimal :: String -> String -> String -> Maybe (Memory mem)
    fromDecimal wholePart fractionPart suffix
        | suffix /= unitSymbol @mem = Nothing
        | leftover /= 0             = Nothing
        | otherwise                 = Just (Memory bitCount)
      where
        scale :: Natural
        scale = 10 ^ length fractionPart

        -- Both strings are non-empty runs of ASCII digits, so 'read' cannot
        -- fail here.
        scaledValue :: Natural
        scaledValue =
            read @Natural wholePart * scale + read @Natural fractionPart

        (bitCount, leftover) = (scaledValue * nat @mem) `divMod` scale
-- | Builds a value from a count of whole @mem@ units. The stored bit count
-- is that count multiplied by the unit size.
memory :: forall (mem :: Nat) . KnownNat mem => Natural -> Memory mem
memory units = Memory (units * nat @mem)
{-# INLINE memory #-}
-- | Re-tags a value with another unit. The stored bit count is unchanged;
-- only the phantom unit differs. The target unit is the first type
-- variable, so it can be picked with a type application.
toMemory :: forall (to :: Nat) (from :: Nat) . Memory from -> Memory to
toMemory (Memory bits) = Memory bits
{-# INLINE toMemory #-}
-- | Returns the raw bit count stored in a value, whatever its unit.
toBits :: Memory mem -> Natural
toBits = unMemory
{-# INLINE toBits #-}
-- | Expresses a value as an exact, reduced ratio of @mem@ units: the bit
-- count divided by the unit size.
toRat :: forall (mem :: Nat) . KnownNat mem => Memory mem -> Ratio Natural
toRat (Memory bits) = bits % unitBits
  where
    unitBits = nat @mem
{-# INLINE toRat #-}
-- | Rounds a value down to a whole number of @mem@ units, returned in any
-- 'Integral' type @n@. @n@ is the first type variable, so the result type can
-- be selected with a type application (e.g. @floor \@Int@).
floor
    :: forall (n :: Type) (mem :: Nat) .
       (Integral n, KnownNat mem)
    => Memory mem
    -> n
floor amount = Prelude.floor (toRat amount)
{-# INLINE floor #-}
{-# SPECIALIZE floor :: KnownNat mem => Memory mem -> Int     #-}
{-# SPECIALIZE floor :: KnownNat mem => Memory mem -> Word    #-}
{-# SPECIALIZE floor :: KnownNat mem => Memory mem -> Integer #-}
{-# SPECIALIZE floor :: KnownNat mem => Memory mem -> Natural #-}
-- | Scales a value by a natural factor; equivalent to 'stimes'.
memoryMul :: Natural -> Memory mem -> Memory mem
memoryMul factor amount = stimes factor amount
{-# INLINE memoryMul #-}
-- | Compares two values of the same unit. Returns how the first relates to
-- the second, paired with the absolute difference. The larger value is
-- always the minuend, so the 'Natural' subtraction never underflows.
memoryDiff :: Memory mem -> Memory mem -> (Ordering, Memory mem)
memoryDiff (Memory a) (Memory b)
    | a < b     = (LT, Memory (b - a))
    | a > b     = (GT, Memory (a - b))
    | otherwise = (EQ, Memory 0)
{-# INLINE memoryDiff #-}
-- | Adds two values that may carry different units. The result takes the
-- unit of the second argument. Both store raw bits, so this is plain
-- addition of the bit counts.
memoryPlus :: Memory mem1 -> Memory mem2 -> Memory mem2
memoryPlus (Memory a) (Memory b) = Memory (a + b)
{-# INLINE memoryPlus #-}
-- | Divides the bit count of the first value by that of the second. The
-- result is an exact, reduced ratio and is independent of either unit.
--
-- Partial: when the divisor holds zero bits, '%' throws a zero-denominator
-- arithmetic exception.
memoryDiv :: Memory mem1 -> Memory mem2 -> Ratio Natural
memoryDiv dividend divisor = unMemory dividend % unMemory divisor
{-# INLINE memoryDiv #-}
-- | A 'Memory' value whose unit has been hidden. The constructor packs the
-- 'KnownNat' and 'KnownUnitSymbol' evidence for that unit alongside the value,
-- so it can still be rendered with 'showMemory' after pattern matching.
data AnyMemory =
    forall (mem :: Nat)
    . (KnownNat mem, KnownUnitSymbol mem)
    => MkAnyMemory (Memory mem)
-- | Shows the wrapped value with 'showMemory', i.e. in its own hidden unit
-- followed by that unit's symbol.
instance Show AnyMemory where
    show wrapped = case wrapped of
        MkAnyMemory amount -> showMemory amount
-- | Reflects the type-level natural @mem@ as a term-level 'Natural'.
--
-- @mem@ only appears in the constraint (hence AllowAmbiguousTypes), so callers
-- must choose it with a type application, e.g. @nat \@mem@. In this module it
-- gives the number of bits in one @mem@ unit.
nat :: forall (mem :: Nat) . KnownNat mem => Natural
nat = natVal' (proxy# :: Proxy# mem)
{-# INLINE nat #-}