from scanner import scan
from scanner import Token
from scanner import TokenType

from ast import DrawingNode
from ast import RowNode
from ast import RepeatNode
from ast import ChunkNode

from enum import Enum

# ---------------------------------
# non-terminals
# ---------------------------------

class NonTerminal(Enum):
    """Non-terminal symbols of the grammar.

    Each member is used as the first element of a ``parse_table`` key and is
    expanded on the parse stack according to the current lookahead token.
    """

    Drawing = 0
    Row = 1
    Repeat = 2
    Chunk = 3
    Chunks = 4

# ---------------------------------
# semantic actions
# ---------------------------------

class AstAction(Enum):
    """Semantic actions that appear on the parse stack.

    When one of these reaches the top of the parse stack, the parser runs the
    matching builder from ``action_table`` against the semantic stack.
    """

    MakeDrawing = 0
    MakeRow = 1
    MakeRepeat = 2
    MakeChunk = 3

# ---------------------------------
# stack operators for grammar rules
# ---------------------------------

def top(stack):
    """Return the most recently pushed element of *stack* without removing it.

    Raises IndexError if *stack* is empty.
    """
    last_index = len(stack) - 1
    return stack[last_index]

def pop(stack):
    """Remove and return the most recently pushed element of *stack*."""
    removed = stack.pop()
    return removed

def push(element, stack):
    """Place *element* on top of *stack*, mutating it in place (returns None)."""
    stack.append(element)

def push_rule(lst, stack):
    """Push a rule's right-hand side so that its leftmost symbol ends on top.

    The symbols in *lst* are pushed last-to-first, so the first symbol of the
    rule is the next one popped from *stack*.
    """
    stack.extend(reversed(lst))

# ---------------------------------
# semantic action table
# ---------------------------------

def make_repeat_node(ast_stack):
    """Replace the value on top of *ast_stack* with a ``RepeatNode`` wrapping it.

    Runs for the ``Repeat -> int_token MakeRepeat`` rule, so the popped value is
    the repeat count just matched.
    """
    repeat_count = pop(ast_stack)
    push(RepeatNode(repeat_count), ast_stack)

def make_chunk_node(ast_stack):
    """Replace the top two values of *ast_stack* with a ``ChunkNode``.

    Runs for the ``Chunk -> int_token str_token MakeChunk`` rule: the string
    value is on top and the count is beneath it, so they are popped in that
    order and passed to ``ChunkNode(count, string)``.
    """
    text, count = pop(ast_stack), pop(ast_stack)
    push(ChunkNode(count, text), ast_stack)

def make_row_node(ast_stack):
    """Reduce a completed row on *ast_stack* into a ``RowNode``.

    Pops every consecutive ``ChunkNode`` from the top of the stack, then pops
    one more item: the repeat value that the ``Row`` rule placed before its
    chunks. It then pushes ``RowNode(repeat_value, chunks)`` with the chunks in
    their original left-to-right order.

    Raises IndexError if the stack runs out. The grammar in ``parse_table``
    should prevent that.
    """
    list_of_chunks = []
    while isinstance(top(ast_stack), ChunkNode):
        list_of_chunks.append(pop(ast_stack))
    # Chunks come off the stack newest-first; one reverse restores source
    # order in O(n) instead of the O(n^2) cost of repeated insert(0, ...).
    list_of_chunks.reverse()
    repeat_value = pop(ast_stack)
    push(RowNode(repeat_value, list_of_chunks), ast_stack)

def make_drawing_node(ast_stack):
    """Reduce the whole semantic stack into a single ``DrawingNode``.

    Every item on *ast_stack* becomes an element of the drawing's row list, in
    bottom-to-top order. The stack is emptied in place and left holding only
    the new ``DrawingNode``.
    """
    # Copying bottom-to-top gives the same order as popping everything and
    # prepending each item. It is O(n) instead of O(n^2).
    list_of_rows = list(ast_stack)
    # Clear in place rather than rebinding, so the caller's list is mutated.
    ast_stack.clear()
    push(DrawingNode(list_of_rows), ast_stack)

# Dispatch table for semantic actions: when an AstAction reaches the top of the
# parse stack, parse() calls the matching builder with the semantic stack.
action_table = {}
action_table[AstAction.MakeDrawing] = make_drawing_node
action_table[AstAction.MakeRow]     = make_row_node
action_table[AstAction.MakeRepeat]  = make_repeat_node
action_table[AstAction.MakeChunk]   = make_chunk_node

# ---------------------------------
# parse table
# ---------------------------------

# LL(1) parse table. Each key is (non-terminal on top of the parse stack, type
# of the current lookahead token). Each value is the right-hand side that
# replaces that non-terminal. parse() pushes the value with push_rule, so its
# first symbol is processed first. AstAction entries are semantic actions that
# build AST nodes once the symbols before them have been matched.
#
# Grammar encoded below ({X} = semantic action AstAction.X):
#   Drawing -> Row Drawing                          (lookahead int_token)
#   Drawing -> {MakeDrawing}                        (lookahead EndOfStream)
#   Row     -> Repeat Chunks terminator {MakeRow}   (lookahead int_token)
#   Chunks  -> Chunk Chunks                         (lookahead int_token)
#   Chunks  -> <empty>                              (lookahead terminator)
#   Repeat  -> int_token {MakeRepeat}               (lookahead int_token)
#   Chunk   -> int_token str_token {MakeChunk}      (lookahead int_token)
#
# parse() reports a "cannot expand" error for any pair not listed here.
parse_table = {
    (NonTerminal.Drawing, TokenType.int_token)  : [ NonTerminal.Row,
                                                    NonTerminal.Drawing ],
    (NonTerminal.Row,     TokenType.int_token)  : [ NonTerminal.Repeat,
                                                    NonTerminal.Chunks,
                                                    TokenType.terminator,
                                                    AstAction.MakeRow ] ,
    (NonTerminal.Chunks,  TokenType.int_token)  : [ NonTerminal.Chunk,
                                                    NonTerminal.Chunks ] ,
    (NonTerminal.Repeat,  TokenType.int_token)  : [ TokenType.int_token,
                                                    AstAction.MakeRepeat ] ,
    (NonTerminal.Chunk,   TokenType.int_token)  : [ TokenType.int_token,
                                                    TokenType.str_token,
                                                    AstAction.MakeChunk ],
    (NonTerminal.Drawing, TokenType.EndOfStream): [ AstAction.MakeDrawing ],
    # Empty rule: a row's chunk list ends at the terminator.
    (NonTerminal.Chunks,  TokenType.terminator) : []
}

# ---------------------------------
# parser
# ---------------------------------

def parse(token_stream):
    """Parse *token_stream* into an AST using the table-driven LL(1) parser.

    The parse stack holds three kinds of items:

    - terminals (``TokenType``), which must match the current token;
    - non-terminals (``NonTerminal``), which are expanded via ``parse_table``;
    - semantic actions (``AstAction``), which run a builder from
      ``action_table`` against a separate semantic stack.

    Args:
        token_stream: indexable sequence of tokens (``len``, indexing and
            slicing are used). Each token is read through ``token_type``,
            ``isInteger()``, ``isString()`` and ``value()``. A successful parse
            requires the last token to be ``TokenType.EndOfStream``, which sits
            at the bottom of the parse stack.

    Returns:
        On success, the single node left on the semantic stack. On any parse
        error, ``False``, after a diagnostic has been printed.
    """
    parse_stack    = [ TokenType.EndOfStream, NonTerminal.Drawing ]
    semantic_stack = []

    pos = 0
    while parse_stack and pos < len(token_stream):
        symbol = top(parse_stack)
        token  = token_stream[pos]

        if isinstance( symbol, TokenType ):
            if symbol != token.token_type:
                print('Parse error -- token mismatch:', symbol, token)
                return False
            pop(parse_stack)
            # Integer and string tokens carry values that semantic actions
            # consume later. Other matched terminals only advance the input.
            if token.isInteger() or token.isString():
                push(token.value(), semantic_stack)
            pos += 1
        elif isinstance( symbol, NonTerminal ):
            rule = parse_table.get( (symbol, token.token_type) )
            if rule is None:
                print('Parse error -- cannot expand', symbol, 'on', token)
                return False
            pop(parse_stack)
            push_rule(rule, parse_stack)
        elif isinstance( symbol, AstAction ):
            # Actions consume no input; they only reduce the semantic stack.
            action_table[symbol](semantic_stack)
            pop(parse_stack)
        else:
            print('Parse error -- invalid item on stack', symbol)
            return False

    if pos < len(token_stream):
        print('Parse error -- unexpected tokens at end:', token_stream[pos:])
        return False
    if parse_stack:
        # Report the symbol the parser still expects. The loop-local `symbol`
        # may be stale here, or unbound if token_stream was empty.
        print('Parse error -- unexpected end of stream:', top(parse_stack))
        return False
    if len(semantic_stack) != 1:
        print('Parse error -- unexpected number of AST nodes:', semantic_stack)
        return False
    return top(semantic_stack)

if __name__ == "__main__":
    # Demo: scan and parse a sample drawing of three terminated rows, then
    # print the resulting AST (or False if parsing failed).
    sample = '3 9 x;\n6 3 b 3 X 3 b;\n3 9 x;'
    print(parse(scan(sample)), end='')
