#!/usr/bin/python3
#
# Parse table comments and create equations.

from optparse import OptionParser
import re
from shutil import copyfile

#--------------------------------------------------------------------------------------------------
# Initialize

# Column type codes stored per column when a table's 't' line is parsed.
TYPE_INPUT, TYPE_OUTPUT, TYPE_SKIP = 0, 1, 99

# Accumulators filled while the input file is scanned and parsed.
lines, tableMatches, tableNames, tableLines = [], [], [], []
tables = dict()

# Behaviour defaults; the file and flag settings can be changed from the command line.
failOnError = True

inFile = 'test.vhdl'
outFileExt = 'vtable'
backupExt = 'orig'

overwrite = True
backup = True

noisy = False
quiet = False
verilog = False

#--------------------------------------------------------------------------------------------------
# Handle command line

usage = 'vtable [options] inFile'

parser = OptionParser(usage)
# The outfile default depends on the overwrite setting (see the outFile selection below).
parser.add_option('-f', '--outfile', dest='outFile', help='output file, default=inFile when overwriting, else [inFile].' + outFileExt)
parser.add_option('-o', '--overwrite', dest='overwrite', help='overwrite inFile, default=' + str(overwrite))
parser.add_option('-b', '--backup', dest='backup', help='backup original file, default=' + str(backup))
parser.add_option('-q', '--quiet', dest='quiet', action='store_true', help='quiet messages, default=' + str(quiet))
parser.add_option('-n', '--noisy', dest='noisy', action='store_true', help='noisy messages, default=' + str(noisy))
parser.add_option('-V', '--verilog', dest='verilog', action='store_true', help='source is verilog, default=' + str(verilog))

options, args = parser.parse_args()

# parser.error() prints the usage plus the message and exits the process,
# so nothing after it runs and no explicit quit is needed.
if len(args) != 1:
    parser.error(usage)
inFile = args[0]

# -o / -b take an explicit '0' or '1'; when omitted the module defaults stay in effect.
if options.overwrite == '0':
    overwrite = False
elif options.overwrite == '1':
    overwrite = True
    if options.outFile is not None:
        parser.error('Can\'t specify outfile and overrite!')
elif options.overwrite is not None:
    parser.error('overwrite: 0|1')

# store_true options are None unless the flag was given
if options.quiet:
    quiet = True

if options.noisy:
    noisy = True

if options.verilog:
    verilog = True

if options.backup == '0':
    backup = False
elif options.backup == '1':
    backup = True
elif options.backup is not None:
    parser.error('backup: 0|1')

# An explicit outfile wins; otherwise write in place or next to the input.
if options.outFile is not None:
    outFile = options.outFile
elif overwrite:
    outFile = inFile
else:
    outFile = inFile + '.' + outFileExt

backupFile = inFile + '.' + backupExt

#--------------------------------------------------------------------------------------------------
# Objects

class Signal:
    """Plain record pairing a signal name with its type code."""

    def __init__(self, name, type):
        # store both constructor arguments unchanged
        self.name, self.type = name, type

class Table:
    """One truth table taken from the source comments, plus its generated equations.

    Columns are identified by integer position. The script fills ``source``,
    ``signalsByCol``, ``typesByCol`` and ``specs`` after construction and then
    calls ``validate()`` and ``makeRTL()``.
    """

    def __init__(self, name):
        self.name = name
        self.source = []        # raw source lines the table was read from
        self.signals = {}       # not used by the methods of this class
        self.signalsByCol = {}  # column -> signal name
        self.typesByCol = {}    # column -> TYPE_* code, or None when no type was given
        self.specs = [] # list of specsByCol
        self.equations = []     # one assignment string per output, built by makeRTL()
        self.added = False      # flag set by the caller once the equations were emitted

    def validate(self):
        """Report (via error()) every signal column that lacks a usable type."""
        for col in self.signalsByCol:
            if col not in self.typesByCol:
                error('Table ' + self.name + ': no signal type for ' + self.signalsByCol[col])
            elif self.typesByCol[col] is None:
                error('Table ' + self.name + ': bad signal type (' + str(self.typesByCol[col]) + ') for ' + str(self.signalsByCol[col]))

    def makeRTL(self, form=None):
        """Build one sum-of-products assignment per output column.

        Every spec row whose output value is '1' contributes one parenthesised
        product of its input columns ('0' -> negated signal, '1' -> signal);
        the products are joined with the OR operator. An output with no
        contributing product is assigned the constant zero.

        The equation list is rebuilt on every call, so calling this more than
        once (printv() does) does not duplicate equations. ``form`` is
        accepted for interface compatibility and is not used.

        Returns the list of equation strings, also stored in ``self.equations``.
        """
        self.equations = []

        # collect output columns in column order
        outputsByCol = {}
        for col in sorted(self.typesByCol):
            if self.typesByCol[col] != TYPE_OUTPUT:
                continue
            if col in self.signalsByCol:
                outputsByCol[col] = self.signalsByCol[col]
            else:
                print(self.signalsByCol)
                print(self.typesByCol)
                error('Table ' + self.name + ': output is specified in col ' + str(col) + ' but no signal exists')

        for sigCol in sorted(outputsByCol):
            sig = outputsByCol[sigCol]
            if verilog:
                line = 'assign ' + sig + ' = '
            else:
                line = sig + ' <= '
            nonzero = False  # becomes True once any product term has been emitted

            for specsByCol in self.specs:
                # a missing output value is a don't-care and '0' adds no product
                if specsByCol.get(sigCol) != '1':
                    continue

                terms = []
                for col, val in specsByCol.items():
                    if col not in self.typesByCol:
                        if noisy:
                            error('Table ' + self.name +': unexpected value in spec column ' + str(col) + ' (' + str(val) + ') - no associated signal', False) #wtf UNTIL CAN HANDLE COMMENTS AT END!!!!!!!!!!!!!!!!!!!
                        continue
                    if self.typesByCol[col] != TYPE_INPUT:
                        continue

                    if val == '0':
                        term = opNot + self.signalsByCol[col]
                    elif val == '1':
                        term = self.signalsByCol[col]
                    else:
                        error('Table ' + self.name +': unexpected value in spec column ' + str(col) + ' (' + str(val) + ')')
                        continue

                    terms.append(term)
                    if len(terms) == 1:
                        # first literal of this row: close the previous product (if any) and open a new one
                        if nonzero:
                            line = line + ') ' + opOr + '\n  ('
                        else:
                            line = line + '\n  ('
                    nonzero = True

                if terms:
                    line = line + (' ' + opAnd + ' ').join(terms)

            if nonzero:
                line = line + ');'
            else:
                line = line + zero + ";"
            self.equations.append(line)

        return self.equations

    def printv(self):
        """Regenerate the equations and print them, one per line."""
        self.makeRTL()
        print('\n'.join(self.equations))

    def printinfo(self):
        """Print the table name, its source lines and the signal assigned to each column."""
        print('Table: ' + self.name)
        print()
        for l in self.source:
            print(l)
        print()
        print('Signals by column:')
        for col in sorted(self.signalsByCol):
            print('{0:>3}. {1:} ({2:}) '.format(col, self.signalsByCol[col], 'in' if self.typesByCol[col] == TYPE_INPUT else 'out'))


#--------------------------------------------------------------------------------------------------
# Functions

def error(msg, quitOverride=None):
    """Print an error message and optionally terminate the script.

    msg          -- text to print; it is prefixed with '*** '.
    quitOverride -- None: exit only when the module-level failOnError is set;
                    a true value: always exit; a false value: never exit.

    Raises SystemExit with code -10 when the script is terminated.
    """
    # sys.exit is used instead of quit(): quit() is installed by the site
    # module for interactive use and is not guaranteed to exist.
    import sys

    print('*** ' + msg)

    if quitOverride is None:
        shouldQuit = failOnError
    else:
        shouldQuit = bool(quitOverride)

    if shouldQuit:
        sys.exit(-10)

#--------------------------------------------------------------------------------------------------
# Main script: select language syntax, parse the tables, regenerate equations, write the output

# Characters accepted in a signal name on a table's name line; identical for both languages.
namePattern = re.compile(r'([a-zA-Z\d_\(\)\.\[\]]+)')

# Operators, bit-select brackets and comment-tag patterns for the selected language.
if not verilog:
    openBracket = '('
    closeBracket = ')'
    opAnd = 'and'
    opOr = 'or'
    opNot = 'not '
    zero = "'0'"
    tablePattern = re.compile(r'^\s*?--tbl(?:\s+([^\s]+).*$|\s*$)')
    tableGenPattern = re.compile(r'^\s*?--vtable(?:\s+([^\s]+).*$)')
    commentPattern = re.compile(r'^\s*?(--.*$|\s*$)')
    tableLinePattern = re.compile(r'^.*?--(.*)')
else:
    openBracket = '['
    closeBracket = ']'
    opAnd = '&'
    # Bitwise OR: '+' is arithmetic addition in Verilog, so two product terms
    # that are true at the same time would wrap to 0 on a one-bit result.
    opOr = '|'
    opNot = '~'
    zero = "'b0"
    tablePattern = re.compile(r'^\s*?\/\/tbl(?:\s+([^\s]+).*$|\s*$)')
    tableGenPattern = re.compile(r'^\s*?\/\/vtable(?:\s+([^\s]+).*$)')
    commentPattern = re.compile(r'^\s*?(\/\/.*$|\s*$)')
    tableLinePattern = re.compile(r'^.*?\/\/(.*)')

# Read the whole input file into `lines` (newlines removed) and record the
# zero-based index of every line that carries a table tag.
try:
    # the with-block closes the file even if reading fails part-way
    with open(inFile) as inf:
        for i, line in enumerate(inf):
            lines.append(line.strip('\n'))
            # the pattern is anchored at the line start, so a line matches at most once
            if tablePattern.search(line) is not None:
                tableMatches.append(i)
except Exception as e:
    error('Error opening input file ' + inFile + '\n' + str(e), True)

# Table tags must come in pairs. Each pair delimits one table, which may only
# contain comment or empty lines; the opening tag may carry the table name.

for i in range(0, len(tableMatches), 2):

    # an opening tag without a partner means the tags are unbalanced
    if i + 1 >= len(tableMatches):
        error('Mismatched table tags.\nFound so far: ' + ', '.join(tableNames), True)

    openIdx, closeIdx = tableMatches[i], tableMatches[i + 1]

    # unnamed tables are named after the (1-based) line of their opening tag
    tName = tablePattern.match(lines[openIdx]).group(1)
    if tName is None:
        tName = 'noname_' + str(openIdx + 1)

    tLines = lines[openIdx:closeIdx + 1]
    tableLines.append(tLines)
    tableNames.append(tName)

    for line in tLines:
        if commentPattern.match(line) is None:
            error('Found noncomment, nonempty line in table ' + tName + ':\n' + line, True)

print('Found tables: ' + ', '.join(tableNames))

# build table objects

for table, tName in zip(tableLines, tableNames):
    print('Parsing ' + tName + '...')
    namesByCol = {}
    colsByName = {}
    bitsByCol = {}
    typesByCol = {}
    specs = []

    # Parse the table body. The first character after the comment marker
    # selects the line kind: n=names, b=bit numbers, t=types, s=spec row.
    # Column 0 is the first character after that selector.
    tLines = table[1:-1]     # exclude the opening and closing table tags
    for line in tLines:
        if line.strip() == '':
            continue

        lineMatch = re.search(tableLinePattern, line)
        if lineMatch is None:
            error('Problem parsing table line:\n' + line, True)
        spec = lineMatch.groups()[0]
        if len(spec) == 0:
            continue

        kind = spec[0]
        body = spec[1:]

        if kind == 'n':
            for match in re.finditer(namePattern, body):
                # a name belongs to the column where it starts
                namesByCol[match.start()] = match.groups()[0]
                colsByName[match.groups()[0]] = match.start()

        elif kind == 'b':
            for i, c in enumerate(body):
                if c == ' ' or c == '|':
                    continue
                try:
                    bit = int(c)
                except ValueError:
                    error('Unexpected char in bit line at position ' + str(i) + ' (' + c + ')\n' + line)
                    bit = None
                # a digit in a column that already has a number extends it
                # (earlier 'b' lines hold the more significant digits)
                if i in bitsByCol and bitsByCol[i] is not None:
                    bitsByCol[i] = bitsByCol[i]*10+bit
                else:
                    bitsByCol[i] = bit

        elif kind == 't':
            for i, c in enumerate(body):
                if c.lower() == 'i':
                    typesByCol[i] = TYPE_INPUT
                elif c.lower() == 'o':
                    typesByCol[i] = TYPE_OUTPUT
                elif c.lower() == '*':
                    typesByCol[i] = TYPE_SKIP
                elif c != ' ':
                    error('Unexpected char in type line at position ' + str(i) + ' (' + c + ')\n' + line)
                    typesByCol[i] = None
                else:
                    typesByCol[i] = None

        elif kind == 's':
            # only '0' and '1' are recorded; every other character is a don't-care
            specsByCol = {}
            for i, c in enumerate(body):
                if c == '0' or c == '1':
                    specsByCol[i] = c
            specs.append(specsByCol)

        # any other comment line inside the table is ignored

    # create table object

    # add strand to name where defined; don't combine for now into vector
    # consecutive strands belong to the last defined name
    lastName = None
    lastCol = 0
    signalsByCol = {}

    for col, name in namesByCol.items():     # load with unstranded names
        signalsByCol[col] = name

    # sort by col so consecutive columns can be easily tracked
    for col in sorted(bitsByCol):            # update with stranded names
        val = bitsByCol[col]

        # a gap in the columns, or a missing bit number, ends the current run
        if col > lastCol + 1:
            lastName = None
        if val is None:
            lastName = None

        if col in namesByCol:
            if val is None:
                signalsByCol[col] = namesByCol[col]
            else:
                lastName = namesByCol[col]
                signalsByCol[col] = lastName + openBracket + str(val) + closeBracket
        elif lastName is not None:
            signalsByCol[col] = lastName + openBracket + str(val) + closeBracket
        else:
            error('Can\'t associate bit number ' + str(val) + ' in column ' + str(col) + ' with a signal name.')
        lastCol = col

    t = Table(tName)
    t.source = table
    t.signalsByCol = signalsByCol
    t.typesByCol = typesByCol
    t.specs = specs

    tables[tName] = t

# check every parsed table, then build its equations
for name, t in tables.items():
    t.validate()
    t.makeRTL()

print('\nResults:')

# Copy the input to outLines, replacing whatever sits between each pair of
# vtable tags with the freshly generated equations of the named table.
outLines = []
inTable = False
for i, line in enumerate(lines):
    match = re.search(tableGenPattern, line)

    if match is None:
        # ordinary line: kept outside a generated region, dropped inside it
        # (lines inside are the previously generated equations)
        if not inTable:
            outLines.append(line)
        continue

    # the pattern only matches when a name follows the tag, so group 1 is always set
    tagName = match.group(1)

    if not inTable:
        tName = tagName
        if tName not in tables:
            error('Found vtable start for \'' + tName + '\' but didn\'t generate that table: line ' + str(i+1) + '\n' + line, True)
        outLines.append(line)
        outLines += tables[tName].equations
        tables[tName].added = True
        inTable = True
    else:
        if tagName != tName:
            error('Found vtable end for \'' + tagName + '\' but started table \'' + tName + '\': line ' + str(i+1) + '\n' + line, True)
        outLines.append(line)
        inTable = False

# Without a closing tag every following line would have been dropped above;
# stop here instead of writing a truncated file.
if inTable:
    error('Found vtable start for \'' + tName + '\' but no matching vtable end before end of file', True)

# Save a copy of the input before anything is written.
if backup:
    try:
        copyfile(inFile, backupFile)
    except Exception as e:
        error('Error creating backup file!\n' + str(e), True)

# Write the regenerated source; the with-block flushes and closes the file
# even when a write fails.
try:
    with open(outFile, 'w') as of:
        for line in outLines:
            of.write("%s\n" % line)
except Exception as e:
    error('Error writing output file ' + outFile + '!\n' + str(e), True)

print('Generated ' + str(len(tables)) + ' tables: ' + ', '.join(tableNames))
print('Output file: ' + outFile)
if backup:
    print('Backup file: ' + backupFile)

# Tables with no matching vtable tags were generated but never placed in the output.
notAdded = [table for table in tables if not tables[table].added]
if len(notAdded) != 0:
    error('Tables generated but not added to file! ' + ', '.join(notAdded))