from util import Object
from elliptic_curve import *
import elliptic_curve_display

import random
import sympy
import tkinter as tk
from tkinter import ttk
import matplotlib.pyplot as plt
from matplotlib.backend_bases import key_press_handler
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
											   NavigationToolbar2Tk)

# Shared application state (a util.Object). init() replaces it with a fresh
# Object; the functions below read and populate it via ``global state``.
state = None

# Largest modulus kept in the finite-field view: _update_curve clears the
# is_finite toggle when the modulus is larger than this.
MAX_FINITE = 100
# Moduli selectable with the modulus slider (the slider moves over indices of
# this list): 0 followed by every prime p with 2 <= p < MAX_FINITE.
# NOTE(review): the meaning of modulus 0 is decided by elliptic_curve_factory
# (defined elsewhere) -- confirm.
prime_list = [0] + list(sympy.primerange(2, MAX_FINITE))

def random_point(curve):
	"""Return one point of ``curve`` chosen uniformly at random.

	``curve.points()`` must return two parallel sequences (xs first, ys
	second); the result is the pair ``(xs[k], ys[k])`` for a random index k.

	Raises ValueError when the curve reports no points.
	"""
	coords = curve.points()
	xs, ys = coords[0], coords[1]
	k = random.randrange(len(xs))
	return xs[k], ys[k]

def new_random_points():
	"""Pick new random points ``state.curve.point1`` and ``state.curve.point2``.

	Both points come from ``random_point``. ``point2`` is resampled until it
	differs from ``point1``. If the curve reports fewer than two distinct
	points, ``point2`` is left equal to ``point1``; the old code looped
	forever in that case and froze the GUI.

	Raises ValueError (from ``random_point``) when the curve has no points.
	"""
	global state
	curve = state.curve
	curve.point1 = random_point(curve)
	curve.point2 = curve.point1

	coords = curve.points()
	if len(set(zip(coords[0], coords[1]))) < 2:
		# No second distinct point exists, so resampling could never finish.
		# NOTE(review): confirm elliptic_curve_display.addition copes with P + P.
		return
	while curve.point2 == curve.point1:
		curve.point2 = random_point(curve)

def display():
	"""Redraw the plot for the current curve and its two selected points.

	Clears the previous drawing via ``elliptic_curve_display.clear()``, draws
	``state.curve`` on ``state.ax``, passes ``point1`` and ``point2`` (as
	x1, y1, x2, y2) to ``elliptic_curve_display.addition``, then refreshes
	the Tk canvas.
	"""
	global state
	elliptic_curve_display.clear()
	ax = state.ax
	curve = state.curve
	elliptic_curve_display.display(ax, curve)
	x1, y1 = curve.point1
	x2, y2 = curve.point2
	elliptic_curve_display.addition(ax, curve, x1, y1, x2, y2)
	state.canvas.draw()

def _update_curve():
	"""Rebuild ``state.curve`` from the a/b/mod string variables (no redraw).

	Steps:
	- Clears the finite-field toggle when the modulus exceeds MAX_FINITE.
	- Builds the curve with ``elliptic_curve_factory``.
	- Picks new random points.
	- Copies the curve's own a, b and mod back into the string variables.
	  Presumably this shows any normalization done by the factory, which is
	  defined elsewhere.

	Raises ValueError if a variable does not hold an integer literal.
	"""
	global state
	a, b, mod = (
		int(var.get())
		for var in (state.a_strvar, state.b_strvar, state.mod_strvar)
	)
	if mod > MAX_FINITE:
		# Moduli above MAX_FINITE are not shown in the finite-field view.
		state.is_finite.set(0)
	curve = elliptic_curve_factory(state.is_finite.get(), a, b, mod)
	state.curve = curve
	new_random_points()
	for var, value in (
		(state.a_strvar, curve.a),
		(state.b_strvar, curve.b),
		(state.mod_strvar, curve.mod),
	):
		var.set(value)

def update_curve():
	"""Rebuild the curve from the current inputs, then redraw the plot.

	Used as the Tk callback for the finite-field toggle, the sliders and the
	"New random points" button.
	"""
	for step in (_update_curve, display):
		step()

def set_curve(curve : (int, int, int)):
	"""Load an (a, b, mod) triple into the input variables, then rebuild and redraw.

	Raises ValueError if ``curve`` does not unpack into exactly three values.
	"""
	global state
	a, b, mod = curve
	targets = (state.a_strvar, state.b_strvar, state.mod_strvar)
	for var, value in zip(targets, (a, b, mod)):
		var.set(value)
	update_curve()

def init(curve : (int, int, int)):
	"""Create a fresh application state and window for the starting (a, b, mod) ``curve``.

	Call ``run()`` afterwards to enter the Tk main loop.
	"""
	global state
	state = Object()
	# Order matters:
	# 1. tk_init creates the Tk variables that _update_curve reads.
	# 2. _update_curve builds the curve and rewrites those variables.
	# 3. tk_fill reads the rewritten variables to set up the sliders.
	# 4. display draws the result.
	tk_init(curve)
	for step in (_update_curve, tk_fill, display):
		step()

def tk_init(curve : (int, int, int)):
	"""Create the root window, embedded matplotlib canvas, toolbar and Tk variables.

	``curve`` is the starting (a, b, mod) triple copied into the string
	variables. The finite-field view starts enabled. Widgets are only created
	here; ``tk_fill`` packs them.
	"""
	global state
	a, b, mod = curve
	state.root = root = tk.Tk()
	root.wm_title("Elliptic curves")
	# Start maximized. "-zoomed" is an X11 wm attribute. Windows (and,
	# depending on the Tk build, macOS) reject it with TclError and use the
	# "zoomed" window state instead. Maximizing is cosmetic, so if neither
	# form is supported the default window size is kept.
	try:
		root.attributes('-zoomed', True)
	except tk.TclError:
		try:
			root.state('zoomed')
		except tk.TclError:
			pass

	# NOTE(review): matplotlib recommends matplotlib.figure.Figure rather than
	# pyplot when embedding in a GUI. pyplot is kept in case
	# elliptic_curve_display relies on pyplot's current figure; confirm
	# before switching.
	state.figure = plt.figure(figsize=(5, 4), dpi=150)
	state.canvas = FigureCanvasTkAgg(state.figure, master=root)
	state.controls = ttk.Frame(root)
	state.toolbar = NavigationToolbar2Tk(state.canvas, root, pack_toolbar=False)
	state.toolbar.update()
	state.a_strvar   = tk.StringVar(value=a)
	state.b_strvar   = tk.StringVar(value=b)
	state.mod_strvar = tk.StringVar(value=mod)
	state.is_finite  = tk.IntVar(value=1)

def tk_fill():
	"""Build the control panel and the plot axes, then pack the window layout.

	Expects ``tk_init`` to have created ``state.controls``, the Tk variables,
	the figure, canvas and toolbar. Expects ``_update_curve`` to have filled
	the a/b/mod string variables. Adds, top to bottom:
	- the finite-field toggle,
	- the live equation,
	- one slider per parameter (a, b, modulus),
	- one button per entry of ``DEFAULT_CURVES``,
	- the "New random points" button.

	Raises ValueError if the current modulus is not an entry of ``prime_list``.
	"""
	global state
	# Finite field view toggle
	tk.Checkbutton(state.controls, text="Finite field view",
		variable=state.is_finite, command=update_curve,
	).pack()

	# Horizontal separator
	tk.Frame(state.controls, height=2, bg="black").pack(fill=tk.X, pady=10)

	# Equation -- y^2 = x^3 + a * x^2 + b (coefficient labels follow the StringVars)
	state.curve_equation = ttk.Frame(state.controls)
	equation_widgets = [
		ttk.Label(state.curve_equation, text="y² ≡ x³ + "),
		(curve_equation_a := ttk.Label(state.curve_equation, textvariable=state.a_strvar, foreground="blue")),
		ttk.Label(state.curve_equation, text="x² + "),
		(curve_equation_b := ttk.Label(state.curve_equation, textvariable=state.b_strvar, foreground="red")),
		ttk.Label(state.curve_equation, text=" mod "),
		(curve_equation_mod := ttk.Label(state.curve_equation, textvariable=state.mod_strvar, foreground="magenta")),
	]
	state.curve_equation_a = curve_equation_a
	state.curve_equation_b = curve_equation_b
	state.curve_equation_mod = curve_equation_mod
	state.curve_equation.pack()
	for column, widget in enumerate(equation_widgets):
		widget.grid(row=0, column=column)

	# Input --- a b mod
	def slider_command(strvar, to_param, position):
		"""Return a Scale callback that stores the slider's value in ``strvar``.

		_update_curve only reads the StringVars, so the new value has to be
		written there before calling update_curve(). Otherwise moving a slider
		would rebuild the unchanged curve.
		"""
		def command(value):
			nonlocal position
			# Tk passes the new position as a string such as "42".
			new_position = int(float(value))
			# Tk also fires the command once after the initial ``set`` even
			# though the user moved nothing; skip positions already applied.
			if new_position == position:
				return
			position = new_position
			strvar.set(to_param(position))
			update_curve()
		return command

	a, b = state.a_strvar.get(), state.b_strvar.get()
	# The modulus slider moves over indices of prime_list, not raw moduli.
	mod_index = prime_list.index(int(state.mod_strvar.get()))
	sliders = [
		# (label, color, initial position, state key, max position, target var, position -> value)
		("a",       "blue",    a,         "a_input",   100,                 state.a_strvar,   int),
		("b",       "red",     b,         "b_input",   100,                 state.b_strvar,   int),
		("modulos", "magenta", mod_index, "mod_input", len(prime_list) - 1, state.mod_strvar, prime_list.__getitem__),
	]
	(f := ttk.Frame(state.controls)).pack()
	for row, (name, color, initial, key, maximum, strvar, to_param) in enumerate(sliders):
		ttk.Label(f, text=name, foreground=color).grid(row=row, column=0)
		w = state[key] = tk.Scale(f,
			from_=0, to=maximum,
			orient=tk.HORIZONTAL,
			length=200,
			showvalue=False,
		)
		w.set(initial)
		w.config(command=slider_command(strvar, to_param, int(float(w.get()))))
		w.grid(row=row, column=1)

	# Preset buttons. ``preset=preset`` binds each button to its own curve; a
	# plain closure would late-bind, so every button would load the last preset.
	for name, preset in DEFAULT_CURVES.items():
		ttk.Button(state.controls, text=name, command=lambda preset=preset: set_curve(preset)).pack()

	# Horizontal separator
	tk.Frame(state.controls, height=2, bg="black").pack(fill=tk.X, pady=10)

	# Operator controls
	ttk.Button(state.controls, text="New random points", command=update_curve).pack()

	# Matplotlib init
	state.ax = ax = state.figure.add_subplot()
	ax.grid()
	ax.axhline(0, color='black', linewidth=1.5)
	ax.axvline(0, color='black', linewidth=1.5)
	ax.set_aspect('equal')

	# Extra Packing
	state.controls.pack(side=tk.RIGHT)
	state.toolbar.pack(side=tk.BOTTOM, fill=tk.X)
	state.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

def run():
	"""Enter the Tk main loop of the window created by ``init()``."""
	root = state.root
	root.mainloop()