module ShamII where

import Hardware.Chalk hiding (zip, unzip)
import qualified Hardware.Chalk.Circuit as Chalk (zip,unzip)
import Control.Applicative
import Control.Monad.State
import Data.Stream (Stream(..))

-- | Names of the four registers held in the register file.
data Reg
  = R0
  | R1
  | R2
  | R3
  deriving (Eq, Show)

-- | Contents of the register file: one 'Int' per register, stored in the
-- order 'R0', 'R1', 'R2', 'R3'.
type Regs = (Int, Int, Int, Int) 

-- | Operations understood by the ALU.
data Cmd
  = ADD  -- ^ first operand plus the second
  | SUB  -- ^ first operand minus the second
  | INC  -- ^ first operand plus one
  deriving (Eq, Show)

-- | Initial register file contents: every register holds zero.
initRegs :: Regs
initRegs = (0,0,0,0)

-- | Register file with one write port and two read ports.
--
-- The four input signals are combined pointwise into a signal of 'regStep'
-- transitions, which is handed to 'loop' together with 'initRegs' as the
-- starting register contents. Only the pair of values read is exposed in the
-- result; the register contents themselves are not part of the output.
--
-- NOTE(review): 'loop' is presumed to feed the state returned by each
-- transition into the next one -- confirm against Hardware.Chalk.
regFile ::
  Signal Reg                    -- register selected by the write port
  -> Signal Int                 -- value presented to the write port
  -> Signal Reg                 -- register selected by the first read port
  -> Signal Reg                 -- register selected by the second read port
  -> Signal (Int, Int)          -- values on the first and second read ports
regFile wr val rd1 rd2 = loop transitions initRegs
  where
    -- One state-transition function per input sample.
    transitions = regStep <$> wr <*> val <*> rd1 <*> rd2

-- | One register-file transition: perform the write, then serve both reads.
--
-- Both reads are taken from the already-updated registers, so a value written
-- to a register is visible to a read of that same register within the same
-- step. The updated registers are returned as the next state.
regStep :: Reg -> Int -> Reg -> Reg -> Regs -> ((Int, Int), Regs)
regStep wr x rd1 rd2 regs = ((readPort rd1, readPort rd2), written)
  where
    written = updateReg (wr, x) regs
    readPort r = lookupReg r written

-- | Replace the contents of one register, leaving the other three untouched.
--
-- The first argument pairs the register to overwrite with its new value; the
-- second argument is the current register contents in the order
-- 'R0', 'R1', 'R2', 'R3'.
updateReg :: (Reg, a) -> (a, a, a, a) -> (a, a, a, a)
updateReg (R0, x) (_, b, c, d) = (x, b, c, d)
updateReg (R1, x) (a, _, c, d) = (a, x, c, d)
updateReg (R2, x) (a, b, _, d) = (a, b, x, d)
updateReg (R3, x) (a, b, c, _) = (a, b, c, x)

-- | Read the contents of one register from a register tuple laid out in the
-- order 'R0', 'R1', 'R2', 'R3'.
lookupReg :: Reg -> (a, a, a, a) -> a
lookupReg R0 (a, _, _, _) = a
lookupReg R1 (_, b, _, _) = b
lookupReg R2 (_, _, c, _) = c
lookupReg R3 (_, _, _, d) = d

-- | Pointwise ALU: at every sample, applies the operation carried by the
-- command signal to the operand pair carried by the second signal.
alu :: Signal Cmd -> Signal (Int, Int) -> Signal Int
alu cmds xys = execute <$> cmds <*> xys
  where
    -- The command is inspected before the operand pair.
    execute op operands = case (op, operands) of
      (ADD, (x, y)) -> x + y
      (SUB, (x, y)) -> x - y
      (INC, (x, _)) -> x + 1  -- the second operand is ignored

-- | Pipelined datapath: register read, then ALU (with a bypass for
-- back-to-back dependencies), then register write-back.
--
-- The arguments are the command, the destination register and the two source
-- registers of each instruction. The result pairs every ALU result
-- (@aluOutput'@) with the destination register of the instruction that
-- produced it (@destReg''@); the same pair drives the register file's write
-- port.
--
-- NOTE(review): 'delay' is assumed to emit its first argument initially and
-- then the input signal one sample late, and 'select' is assumed to pick its
-- second argument where the condition is 'True' and its third otherwise --
-- confirm against Hardware.Chalk.
--
-- NOTE(review): @noHazard@ never requests a bypass when the previous
-- destination is 'R0', yet the register file does store writes to 'R0';
-- confirm whether 'R0' is meant to behave as a constant register.
sham :: Signal Cmd -> Signal Reg -> Signal Reg -> Signal Reg -> (Signal Reg, Signal Int)
sham cmd dest srcA srcB = (destReg'', aluOutput')
  where
    -- Register read. The write port is indexed by destReg'' (not dest') so
    -- that the register written belongs to the same instruction as the value
    -- aluOutput' being written.
    valueAB = regFile destReg'' aluOutput' srcA srcB
    -- The operands are delayed alongside the command and the validity flags
    -- so that all three describe the same instruction when they meet.
    valueAB' = delay (0, 0) valueAB
    (valueA, valueB) = Chalk.unzip valueAB'

    -- Destination register, one and two samples late.
    dest' = delay R0 dest
    destReg'' = delay R0 dest'

    -- A primed name is required: binding @cmd@ itself would shadow the
    -- argument and define the signal in terms of itself, discarding the
    -- caller's commands.
    cmd' = delay ADD cmd

    -- Bypass: use the register-file value when it is valid, otherwise the
    -- previous ALU result.
    aluInputA = select validA valueA aluOutput'
    aluInputB = select validB valueB aluOutput'

    -- ALU outputs
    aluOutput = alu cmd' (Chalk.zip (aluInputA, aluInputB))
    aluOutput' = delay 0 aluOutput

    -- Control logic: a source is free of hazards when it differs from the
    -- previous instruction's destination, or when that destination is R0.
    validA = delay True (noHazard srcA)
    validB = delay True (noHazard srcB)
    noHazard :: Signal Reg -> Signal Bool
    noHazard src = pure (\s d -> s /= d || d == R0) <*> src <*> dest'
