module ShamV where

import Hardware.Chalk
import Control.Applicative
import Control.Monad.State
import Data.Stream (Stream(..))
import Data.Maybe
import Fifo

-- | The four architectural registers of the machine.
data Reg
  = R0
  | R1
  | R2
  | R3
  deriving (Show, Eq)

-- | Register-file contents, one slot per register in the order
-- (R0, R1, R2, R3).
type Regs = (Int, Int, Int, Int)

-- | Operations carried by a 'Transaction'. ADD, SUB and INC are evaluated by
-- 'alu'; LOAD and STORE are handled (currently as a stub) by 'mem'.
data Cmd
  = ADD
  | SUB
  | INC
  | LOAD
  | STORE
  deriving (Show, Eq)

-- | A register together with its value, if the value is known yet.
type Operand = (Reg, Maybe Int)

-- | One instruction as it travels through the pipeline. Source values are
-- filled in by the register file and by forwarding; the destination value is
-- filled in by the ALU (see 'setDest').
data Transaction = Transaction
  { dest :: Operand    -- ^ destination register and its (eventual) value
  , cmd  :: Cmd        -- ^ operation to perform
  , srcs :: [Operand]  -- ^ source registers and their (eventual) values
  }
  deriving (Show, Eq)

-- | Store a value in the transaction's destination operand. The destination
-- register, the command and the sources are left as they were.
setDest :: Transaction -> Int -> Transaction
setDest t@Transaction{dest = (r, _)} v = t { dest = (r, Just v) }

-- | Per-cycle choice between a constant @x@ and the signal @xs@, steered by
-- @cs@ through Hardware.Chalk's 'select'. Presumably @x@ is chosen where @cs@
-- is 'True' (that is how 'sham3Trans' uses it to inject a bubble) — confirm
-- against 'select'.
--
-- NOTE(review): despite its name this introduces no delay; it is a purely
-- combinational multiplexer.
defaultDelay :: Signal Bool -> a -> Signal a -> Signal a
defaultDelay cs x = select cs (pure x)

-- | Per cycle, whether the transaction is a LOAD.
isLoad :: Signal Transaction -> Signal Bool
isLoad = fmap ((== LOAD) . cmd)

-- | Instruction replay buffer in front of the register-fetch stage.
--
-- Every cycle the incoming transaction is appended to an internal FIFO and
-- the oldest buffered transaction is emitted. While the stall signal is
-- 'True', the emitted transaction stays at the head of the buffer, so it is
-- emitted again on the next cycle. Otherwise it is removed. The emitted
-- value never depends on the same cycle's stall signal; only the next state
-- does.
--
-- Unlike the earlier version, an input arriving while older instructions
-- are buffered is never dropped. A stalled instruction taken from the
-- buffer is also kept for replay.
--
-- NOTE(review): assumes the Fifo module's 'enqueue'/'dequeue' give
-- first-in-first-out order — confirm.
instrCache ::
  Signal Bool -> Signal Transaction -> Signal Transaction
instrCache cs ts = loop (step <$> cs <*> ts) new
  where
  step :: Bool -> Transaction -> Fifo Transaction ->
    (Transaction, Fifo Transaction)
  step stall t fifo =
    let queued = enqueue t fifo
    in case dequeue queued of
         -- Only reachable if the Fifo loses the element just enqueued; fall
         -- back to passing the input straight through.
         Nothing -> (t, fifo)
         Just (x, rest) -> (x, if stall then queued else rest)


-- | Memory stage (stub).
--
-- TODO: decide how this should work — use a 'loop' and keep a local copy of
-- the memory, or revise the definition?
--
-- Current behaviour:
--
--   * LOAD keeps its destination register and sources, but the loaded value
--     is not implemented; forcing it raises an error.
--   * STORE is rewritten to target R0 with an unimplemented value and
--     unimplemented sources.
--   * every other transaction passes through unchanged.
--
-- NOTE(review): 'updateReg' pattern-matches on the destination value, so a
-- STORE (or an unresolved LOAD) that reaches write-back raises its error.
mem :: Signal Transaction -> Signal Transaction
mem ts = stage <$> ts
  where
  stage (Transaction (r, _) LOAD ss) =
    Transaction (r, error "mem: LOAD result is not implemented") LOAD ss
  stage (Transaction (_, _) STORE _) =
    Transaction
      (R0, error "mem: STORE destination value is not implemented")
      STORE
      (error "mem: STORE sources are not implemented")
  stage t = t

-- | Per cycle, 'True' when the destination register of the transaction in
-- the first signal is one of the source registers of the transaction in the
-- second signal.
transHazard :: Signal Transaction -> Signal Transaction -> Signal Bool
transHazard producers consumers = readsDest <$> producers <*> consumers
  where
  readsDest p c = any ((== reg (dest p)) . reg) (srcs c)

-- | Pipelined SHAM machine. Maps the input transaction stream to the
-- transactions leaving the memory stage, delayed by one cycle (@memOut'@).
--
-- Per cycle:
--
--   * 'instrCache' buffers the input and replays the current instruction
--     while 'loadHzd' is asserted;
--   * 'regFile' fills in source values; it is written by @memOut'@;
--   * on a load-use hazard a 'nop' is sent to the ALU instead of the
--     fetched instruction;
--   * ALU sources are forwarded from @memOut'@, then from @aluOut'@;
--     the memory-stage input is forwarded from @memOut'@.
--
-- Load-use hazard: a LOAD's value is produced by 'mem' while the LOAD sits
-- in @aluOut'@. Forwarding only reads @aluOut'@ and @memOut'@, so that value
-- cannot reach the ALU until the next cycle. An instruction that reads the
-- LOAD's destination in that cycle is stalled for one cycle.
sham3Trans :: Signal Transaction -> Signal Transaction
sham3Trans inputs = memOut'
  where
  -- register fetch stage --
  instr = instrCache loadHzd inputs
  readyInstr = regFile memOut' instr
  -- issue a bubble instead of the fetched instruction while it is stalled
  readyInstr' = defaultDelay loadHzd nop readyInstr

  -- ALU stage --
  aluIn, aluOut, aluOut' :: Signal Transaction
  aluIn = bypass (bypass readyInstr' memOut') aluOut'
  aluOut = alu aluIn
  aluOut' = delay nop aluOut

  -- memory stage --
  memIn, memOut, memOut' :: Signal Transaction
  memIn = bypass aluOut' memOut'
  memOut = mem memIn
  memOut' = delay nop memOut

  -- control logic --
  -- Stall when the LOAD now in the memory stage (aluOut') writes a register
  -- that the freshly fetched instruction (readyInstr) reads. The previous
  -- version read readyInstr', which is itself selected by loadHzd, so it
  -- formed a combinational cycle. It also tested the dependency in the wrong
  -- direction.
  loadHzd :: Signal Bool
  loadHzd = (&&) <$> isLoad aluOut' <*> transHazard aluOut' readyInstr

-- | Register-file contents at reset: every register holds zero.
initRegs :: Regs
initRegs = let zero = 0 in (zero, zero, zero, zero)

-- | Register file. Each cycle, 'regStep' fills in the source values of the
-- transaction on the second (read) signal from the current registers. It
-- then applies the destination of the transaction on the first (write)
-- signal to form the next registers. Reads therefore see writes one cycle
-- later. State starts at 'initRegs' and is threaded by Hardware.Chalk's
-- 'loop' (presumably cycle to cycle — confirm).
regFile :: Signal Transaction -> Signal Transaction -> Signal Transaction
regFile writePort readPort = loop stepper initRegs
  where
  stepper = regStep <$> writePort <*> readPort

-- | One register-file cycle.
--
-- Returns the read transaction with its source values taken from the
-- current registers @regs@. It also returns the next registers, which have
-- the write transaction's destination applied. The write is thus not
-- visible to the read in the same cycle.
regStep :: Transaction -> Transaction -> Regs -> (Transaction , Regs)
regStep (Transaction wrOp _ _) rd regs =
  let filled = updateTransaction regs rd
      next = updateReg wrOp regs
  in (filled, next)

-- | Apply a register write: store the operand's value into the named
-- register. An operand with no value ('Nothing') performs no write and
-- leaves the registers unchanged, rather than failing with a
-- pattern-match error.
--
-- 'regStep' uses it at @a = Int@, i.e. as @Operand -> Regs -> Regs@.
updateReg :: (Reg, Maybe a) -> (a, a, a, a) -> (a, a, a, a)
updateReg (_, Nothing) regs = regs
updateReg (R0, Just x) (_, b, c, d) = (x, b, c, d)
updateReg (R1, Just x) (a, _, c, d) = (a, x, c, d)
updateReg (R2, Just x) (a, b, _, d) = (a, b, x, d)
updateReg (R3, Just x) (a, b, c, _) = (a, b, c, x)

-- | Fill in every source operand of a transaction with its value from the
-- given registers. The destination and command are left unchanged.
updateTransaction :: Regs -> Transaction -> Transaction
updateTransaction regs t = t { srcs = fillIn <$> srcs t }
  where
  fillIn = updateOperand regs

-- | Replace an operand's value with the current contents of its register.
updateOperand :: (a, a, a, a) -> (Reg, b) -> (Reg, Maybe a)
updateOperand regs (r, _) = (r, current)
  where
  current = Just (lookupReg r regs)
-- | Read one register from the (R0, R1, R2, R3) register tuple.
lookupReg :: Reg -> (a, a, a, a) -> a
lookupReg r (v0, v1, v2, v3) =
  case r of
    R0 -> v0
    R1 -> v1
    R2 -> v2
    R3 -> v3

-- | ALU stage. Each transaction's destination value is set to the result of
-- its command applied to its source values:
--
--   * ADD @[x, y]@ gives @x + y@;
--   * SUB @[x, y]@ gives @x - y@;
--   * INC @[x]@ gives @x + 1@.
--
-- The result is computed lazily. Any other command or operand count (e.g.
-- LOAD, STORE), or a source without a value, yields a destination value
-- that raises a descriptive error only when forced. For LOAD, 'mem' later
-- replaces that value.
alu :: Signal Transaction -> Signal Transaction
alu cmds = interpret <$> cmds
  where
  interpret :: Transaction -> Transaction
  interpret trans@(Transaction _ op ops) =
    setDest trans (eval op (map operandValue ops))

  -- Sources must have been filled in by the register file or by forwarding.
  operandValue :: Operand -> Int
  operandValue (r, v) =
    fromMaybe (error ("alu: source operand " ++ show r ++ " has no value")) v

  eval :: Cmd -> [Int] -> Int
  eval ADD [x, y] = x + y
  eval SUB [x, y] = x - y
  eval INC [x] = x + 1
  eval c args =
    error ("alu: cannot evaluate " ++ show c ++ " with "
           ++ show (length args) ++ " operand(s)")

-- | Pipeline bubble: ADD into R0 with value 0 from two zero-valued R0
-- sources. Used as the reset value of the pipeline registers and as the
-- stall filler in 'sham3Trans'.
--
-- NOTE(review): this is not a true no-op. It has a real destination (R0).
-- Whatever value it carries when it reaches write-back is stored into R0
-- by 'regFile', and 'bypass' forwards it to later instructions that read
-- R0.
nop :: Transaction
nop = Transaction
  { dest = (R0, Just 0)
  , cmd = ADD
  , srcs = replicate 2 (R0, Just 0)
  }

-- | Operand forwarding, cycle by cycle. Each source operand of the
-- transaction in @ins@ whose register equals the destination register of
-- the transaction in @outs@ takes that destination's value. All other
-- source operands are kept as they are.
--
-- The previous merge also replaced the non-matching sources with the
-- forwarded operand. So for any instruction with a hazard, every source
-- became a copy of the forwarded register.
bypass :: Signal Transaction -> Signal Transaction -> Signal Transaction
bypass ins outs = forward <$> ins <*> outs
  where
  forward consumer producer =
    consumer { srcs = map (overrideWith (dest producer)) (srcs consumer) }
  overrideWith (pr, pv) op@(r, _)
    | pr == r = (r, pv)
    | otherwise = op

-- | The register named by an operand, ignoring its value.
reg :: Operand -> Reg
reg (r, _) = r
