checks.spr: add SPRCheck.

main
Jean-François Nguyen 2 years ago
parent ca66e3a45e
commit 2988ffc617

@ -1,4 +1,5 @@
from .cr import *
from .gpr import *
from .ia_fwd import *
from .spr import *
from .unique import *

@ -0,0 +1,115 @@
from amaranth import *
from amaranth.asserts import *

from .. import pfv


__all__ = ["SPRCheck", "SPRCover"]


class SPRCheck(Elaboratable):
"""Special Purpose Registers check.

Checks that reads from supported SPRs are consistent with the last value that was written to
them.
"""
def __init__(self):
self.pre = Signal()
self.post = Signal()
self.pfv = pfv.Interface()

def elaborate(self, platform):
m = Module()

spec_order = AnyConst(self.pfv.order.width)

lr_written = Signal()
lr_shadow = Signal(64)
ctr_written = Signal()
ctr_shadow = Signal(64)
xer_written = Signal()
xer_shadow = Signal(64)
tar_written = Signal()
tar_shadow = Signal(64)

with m.If(self.pfv.stb & (self.pfv.order <= spec_order)):
with m.If(self.pfv.lr.w_stb):
m.d.sync += [
lr_written.eq(1),
lr_shadow .eq(self.pfv.lr.w_data),
]
with m.If(self.pfv.ctr.w_stb):
m.d.sync += [
ctr_written.eq(1),
ctr_shadow .eq(self.pfv.ctr.w_data),
]
with m.If(self.pfv.xer.w_stb):
m.d.sync += [
xer_written.eq(1),
xer_shadow .eq(self.pfv.xer.w_data),
]
with m.If(self.pfv.tar.w_stb):
m.d.sync += [
tar_written.eq(1),
tar_shadow .eq(self.pfv.tar.w_data),
]

with m.If(self.post):
m.d.sync += [
Assume(Past(self.pfv.stb)),
Assume(Past(self.pfv.order) == spec_order),
]
with m.If(Past(self.pfv.lr.r_stb)):
m.d.sync += Assert(Past(lr_shadow) == Past(self.pfv.lr.r_data))
with m.If(Past(self.pfv.ctr.r_stb)):
m.d.sync += Assert(Past(ctr_shadow) == Past(self.pfv.ctr.r_data))
with m.If(Past(self.pfv.xer.r_stb)):
m.d.sync += Assert(Past(xer_shadow) == Past(self.pfv.xer.r_data))
with m.If(Past(self.pfv.tar.r_stb)):
m.d.sync += Assert(Past(tar_shadow) == Past(self.pfv.tar.r_data))

return m


class SPRCover(Elaboratable):
def __init__(self):
self.pre = Signal()
self.post = Signal()
self.pfv = pfv.Interface()

def elaborate(self, platform):
m = Module()

insn_count = Signal(range(4))
lr_written = Signal()
ctr_written = Signal()
xer_written = Signal()
tar_written = Signal()

with m.If(self.pfv.stb):
m.d.sync += [
insn_count .eq(insn_count + 1),
lr_written .eq(self.pfv.lr .w_stb),
ctr_written.eq(self.pfv.ctr.w_stb),
xer_written.eq(self.pfv.xer.w_stb),
tar_written.eq(self.pfv.tar.w_stb),
]

cover_1 = Signal()
cover_2 = Signal()
cover_3 = Signal()
cover_4 = Signal()

m.d.comb += [
cover_1.eq((insn_count > 1) & lr_written),
cover_2.eq((insn_count > 1) & ctr_written),
cover_3.eq((insn_count > 1) & xer_written),
cover_4.eq((insn_count > 1) & tar_written),

Cover(cover_1),
Cover(cover_2),
Cover(cover_3),
Cover(cover_4),
]

return m
Loading…
Cancel
Save