{-|
Module      : Nauty.Digraph6.Internal
Description : Internal functions.
Copyright   : (c) Marcelo Garlet Milani, 2026
License     : MIT
Maintainer  : mgmilani@pm.me
Stability   : unstable

This module contains internal functions used by the "Nauty.Digraph6" module.
Except for test cases, you should not import this module.
-}

{-# LANGUAGE OverloadedStrings #-}

module Nauty.Digraph6.Internal where

import           Control.Monad.Trans.Class
import           Control.Monad.Trans.State
import           Data.Bits
import           Data.Word
import           Nauty.Internal.Encoding
import           Nauty.Internal.Parsing
import qualified Data.Array.Unboxed as A
import qualified Data.ByteString.Lazy as B
import qualified Data.Text.Lazy as T
import qualified Data.Text.Lazy.Encoding as T

-- | A digraph represented as an adjacency matrix.
data AdjacencyMatrix = AdjacencyMatrix
  { numberOfVertices :: Word64
  -- ^ The number of vertices @n@; vertices are numbered @0@ to @n - 1@.
  -- | The adjacency matrix, represented as a bit string, stored row-by-row:
  -- bit @u * n + v@ is set iff there is an arc from @u@ (tail) to @v@ (head).
  -- Bits are packed eight per byte, most significant bit first; unused bits
  -- of the final byte are padding. The empty digraph built by 'digraph'
  -- uses a single zero byte.
  , adjacency :: A.UArray Word64 Word8
  } deriving (Eq, Show)
  
-- | Whether there is an arc from the first vertex to the second.
--
-- Looks up bit @u * n + v@ of the row-major bit string, where the most
-- significant bit of each byte comes first. The vertices are not
-- range-checked.
arcExists :: AdjacencyMatrix
          -> Word64 -- ^ Tail
          -> Word64 -- ^ Head
          -> Bool
arcExists m u v = testBit byte (7 - offset)
  where
    (byteIdx, bitIdx) = (u * numberOfVertices m + v) `divMod` 8
    byte = adjacency m A.! byteIdx
    offset = fromIntegral bitIdx :: Int

-- | Encode a digraph in @digraph6@ format.
--
-- The result is the character @&@, followed by the encoded number of
-- vertices and then the encoded adjacency matrix.
encode :: AdjacencyMatrix -> T.Text
encode m = T.concat [ "&", encodeNumber (numberOfVertices m), encodeMatrix m ]

-- | Encode the adjacency matrix (the part following the vertex count).
--
-- Each value produced by 'encodeVector' (presumably one six-bit group,
-- given the argument @6@) is offset by 63 to obtain a printable character.
encodeMatrix :: AdjacencyMatrix -> T.Text
encodeMatrix m
  -- The empty digraph has an empty bit vector, so nothing follows the
  -- vertex count. Without this guard, the one-byte array used to represent
  -- it (see 'digraph') would be encoded as if it held eight valid bits.
  | n == 0 = T.empty
  | otherwise =
      T.pack $ map toChar $ encodeVector lastValidBits (A.elems $ adjacency m) 0 6
  where
    n = numberOfVertices m
    bits = n * n
    -- Number of meaningful bits in the final byte of the matrix.
    lastValidBits = case bits `mod` 8 of
      0 -> 8
      r -> r
    toChar x = toEnum (fromIntegral (x + 63))

-- | Create an adjacency matrix from a list of arcs.
-- Vertices need to be in the range from @0@ to @n - 1@.
--
-- Each arc is a @(tail, head)@ pair. Bits are combined with '.|.', so
-- duplicate arcs are harmless. Vertex ranges are not checked: an
-- out-of-range vertex either triggers an array index error or sets an
-- unrelated bit. The empty digraph (@n = 0@) is represented by a single
-- zero byte.
fromArcList :: Word64 -- ^ Number of vertices
            -> [(Word64, Word64)] -- ^ List of arcs
            -> AdjacencyMatrix
fromArcList n es =
  AdjacencyMatrix
  { numberOfVertices = n
  , adjacency = A.accumArray (.|.) 0 (0, lastByte)
      [ (bitI `div` 8, shiftL 1 (7 - fromIntegral (bitI `mod` 8)))
      | (t, h) <- es
      , let bitI = t * n + h
      ]
  }
  where
    -- Index of the last byte needed to hold @n * n@ bits. For @n = 0@ the
    -- expression @n * n - 1@ would wrap around (Word64 is unsigned, so a
    -- @max 0@ cannot catch it) and request a gigantic array.
    lastByte
      | n == 0 = 0
      | otherwise = (n * n - 1) `div` 8

-- | The number of vertices of a digraph together with a list of its arcs.
--
-- Arcs are returned as @(tail, head)@ pairs in row-major order, i.e. sorted
-- by tail and then by head. Padding bits in the final byte are ignored.
toArcList :: AdjacencyMatrix -> (Word64, [(Word64, Word64)])
toArcList m = (n, arcs)
  where
    n = numberOfVertices m
    -- Every bit of the stored matrix, most significant bit of each byte first.
    bits = [ testBit b i | b <- A.elems (adjacency m), i <- [7, 6 .. 0 :: Int] ]
    -- The empty digraph has no arcs. Handling it separately also avoids the
    -- Word64 underflow of @n * n - 1@ and a division by zero.
    arcs
      | n == 0 = []
      | otherwise =
          [ bitIdx `divMod` n
          | (bitIdx, True) <- zip [0 .. n * n - 1] bits
          ]

-- | Parse all digraphs in the input text, one digraph per line.
--
-- The input is first passed through 'header' with the @>>digraph6<<@
-- marker (presumably to strip an optional header; see
-- "Nauty.Internal.Parsing"). Each line yields its own result, so a
-- malformed line does not prevent the others from being parsed.
parse :: T.Text -> [Either T.Text AdjacencyMatrix]
parse = map digraph . T.lines . header ">>digraph6<<"

-- | Parse a single digraph in @digraph6@ format.
--
-- The text must start with @&@, followed by the number of vertices and,
-- for a non-empty digraph, the adjacency matrix. The empty digraph is
-- represented by a matrix consisting of a single zero byte.
digraph :: T.Text -> Either T.Text AdjacencyMatrix
digraph t = evalStateT parser (T.encodeUtf8 t)
  where
    ampersand :: Word8
    ampersand = fromIntegral (fromEnum '&')

    parser = do
      h <- consume 1
      case B.unpack h of
        [c] | c == ampersand -> do
          n <- parseNumber
          if n == 0
            then return AdjacencyMatrix{numberOfVertices = 0, adjacency = A.array (0,0) [(0,0)]}
            else parseMatrix n
        found ->
          lift . Left $ "Expected '&', but found " <> T.pack (map (toEnum . fromIntegral) found)

-- | Parse the adjacency matrix of a digraph with @n@ vertices, i.e. a bit
-- vector of @n * n@ bits packed into bytes.
--
-- For @n = 0@ no input is consumed and the matrix is a single zero byte,
-- the same representation 'digraph' uses for the empty digraph. Computing
-- @n * n - 1@ in that case would wrap around (Word64 is unsigned) and
-- request a gigantic array.
parseMatrix :: Word64 -> StateT B.ByteString (Either T.Text) AdjacencyMatrix
parseMatrix 0 =
  return AdjacencyMatrix{numberOfVertices = 0, adjacency = A.array (0,0) [(0,0)]}
parseMatrix n = do
  v <- parseVector (n * n)
  return $ AdjacencyMatrix
    { numberOfVertices = n
    , adjacency = A.array (0, (n * n - 1) `div` 8)
                          $ zip [0..] $ B.unpack v
    }

