from util import *

import numpy as np
import libnum
import sympy

# Sampling of the real-valued curve for plotting (used by EllipticCurve):
# width of the sampled x-window and number of sample points in it.
CURVE_LINSPACE_OFFSET = 10
CURVE_LINSPACE_NUM = 2000

# NOTE: ADD NEW CURVES HERE
# Every entry in this dict is automatically added to the GUI.
# Each value is an (a, b, mod) tuple, in the order elliptic_curve_factory
# unpacks it when given a tuple as `curve`.
# NOTE(review): the published secp384r1 parameters are a = -3 and
# b = 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef;
# the entry below only shares the prime, with (a, b) = (3, 0). Also note that
# EllipticCurveOverFiniteField enumerates every x below the modulus, so the two
# large-prime entries are only usable as real-valued curves — confirm intent.
DEFAULT_CURVES = {
	'Default': (3, 3, 11),
	'secp384r1 (NSA backdoor)': (3, 0, 2**384 - 2**128 - 2**96 + 2**32 - 1),
	'secp256k1 (BTC)': (0, 7, 2**256 - 2**32 - 2**9 - 2**8 - 2**7 - 2**6 - 2**4 - 1),
}

class EllipticCurve(Object):
	"""Elliptic curve y^2 = x^3 + a*x + b (short Weierstrass form) over the reals.

	On construction the curve is sampled for plotting: ``pp`` holds the upper
	branch (y >= 0) and ``np`` the lower branch (y <= 0), each a
	(2, CURVE_LINSPACE_NUM) array whose rows are x and y. Samples where the
	right-hand side is negative have no real y and are NaN.
	"""
	def __init__(self, a, b):
		self.a = a
		self.b = b
		# A real curve has no modulus; 0 is stored so code reading ``.mod``
		# works the same for real and finite-field curves.
		self.mod = 0
		self._points()
	def _points(self):
		"""Sample both branches of the curve into ``self.pp`` / ``self.np``.

		The window is CURVE_LINSPACE_OFFSET wide and starts at the leftmost real
		root of x^3 + a*x + b, so the bounded oval that exists when the cubic has
		three real roots is included (as long as it fits in the window).
		"""
		a = float(self.a)
		b = float(self.b)
		# Coefficients of x^3 + 0*x^2 + a*x + b; the missing x^2 term is what
		# makes the slope formulas in add()/double() valid.
		roots = np.roots([1.0, 0.0, a, b])
		# A real cubic always has a real root; tolerate tiny imaginary noise,
		# which can appear for (nearly) repeated roots.
		is_real = np.abs(roots.imag) <= 1e-6 * np.maximum(1.0, np.abs(roots))
		if np.any(is_real):
			start = roots.real[is_real].min()
		else:
			start = roots.real[np.argmin(np.abs(roots.imag))]
		xs = np.linspace(start, start + CURVE_LINSPACE_OFFSET, CURVE_LINSPACE_NUM)
		with np.errstate(invalid='ignore'):
			# Negative right-hand sides yield NaN (no real point), presumably
			# drawn as gaps by the plotting code.
			ys = np.sqrt(xs**3 + a * xs + b)
		self.pp = np.vstack((xs, ys))
		self.np = np.vstack((xs, -ys))
	def points(self):
		"""Return a (4, CURVE_LINSPACE_NUM) array: rows x, y_upper, x, y_lower."""
		return np.concatenate((self.pp, self.np), axis=0)
	def _cord_slope(self, x1, y1, x2, y2):
		"""Slope of the chord through two points (divides by zero if x1 == x2)."""
		return (y2 - y1) / (x2 - x1)
	def _tangent_slope(self, x, y):
		"""Slope of the tangent at (x, y), i.e. dy/dx = (3x^2 + a) / (2y)."""
		return (3 * x**2 + self.a) / (2 * y)
	def _add(self, s, x1, y1, x2, y2):
		"""Third intersection of the line of slope ``s`` through the points, mirrored in the x-axis."""
		x = s**2 - x1 - x2
		y = s * (x1 - x) - y1
		return (x, y)
	def add(self, x1, y1, x2, y2):
		"""Return P + Q for points with different x coordinates."""
		return self._add(self._cord_slope(x1, y1, x2, y2), x1, y1, x2, y2)
	def double(self, x, y):
		"""Return 2P for a point with y != 0."""
		return self._add(self._tangent_slope(x, y), x, y, x, y)
	def _add_points(self, p, q):
		"""Group addition of two points where ``None`` is the point at infinity."""
		if p is None:
			return q
		if q is None:
			return p
		(x1, y1), (x2, y2) = p, q
		if x1 == x2:
			# Same x: mirror images (vertical line -> infinity) or the same point.
			if y1 == -y2:
				return None
			return self.double(x1, y1)
		return self.add(x1, y1, x2, y2)
	def scalar_multiply(self, point, n):
		"""Return ``n * point`` using double-and-add.

		``point`` is an (x, y) pair (or ``None`` for the point at infinity) and
		``n`` an integer; a negative ``n`` multiplies the mirrored point (x, -y).
		The point at infinity, e.g. ``0 * point``, is returned as ``None``.
		Floating-point error grows with every step, so large ``n`` is approximate.
		"""
		if point is None:
			return None
		x, y = point
		if n < 0:
			n, y = -n, -y
		result = None
		addend = (x, y)
		while n:
			if n & 1:
				result = self._add_points(result, addend)
			n >>= 1
			if n:
				addend = self._add_points(addend, addend)
		return result

class EllipticCurveOverFiniteField(Object):
	"""Elliptic curve y^2 = x^3 + a*x + b (short Weierstrass form) over GF(mod).

	All affine points are enumerated on construction into the parallel lists
	``xs``/``ys``; together with the inverse table this costs O(mod) time and
	memory, so only small prime moduli are practical.
	"""
	def __init__(self, a, b, mod):
		"""Validate ``mod``, enumerate the curve's points and precompute inverses.

		Raises:
			ValueError: if ``mod`` is smaller than 3 or not prime.
		"""
		if mod < 3:
			raise ValueError("The modulus has to be over 2.")
		if not sympy.isprime(mod):
			raise ValueError("The modulus has to be a prime.")
		self.a = a
		self.b = b
		self.mod = mod
		self._points()
		# inverse_table[i] is the inverse of i modulo ``mod`` (from the util
		# helper); 0 has no inverse, hence the None placeholder.
		self.inverse_table = [None] + [multiplicative_inverse_over_prime_finite_field(i, self.mod) for i in range(1, self.mod)]
	def _points(self):
		"""Collect every (x, y) with y^2 == x^3 + a*x + b (mod p) into xs/ys."""
		self.xs = []
		self.ys = []
		for x in range(self.mod):
			rhs = (x**3 + self.a * x + self.b) % self.mod
			if not libnum.has_sqrtmod_prime_power(rhs, self.mod, 1):
				continue
			for root in libnum.sqrtmod_prime_power(rhs, self.mod, 1):
				self.xs.append(x)
				self.ys.append(root)
	def _line_slope(self, x1, y1, x2, y2):
		"""Chord slope (y2 - y1) / (x2 - x1) mod p; requires x1 != x2 (mod p)."""
		return ((y2 - y1) * self.inverse_table[(x2 - x1) % self.mod]) % self.mod
	def is_point_out_of_bounds(self, x, y):
		"""True if (x, y) lies outside the square [0, mod] x [0, mod] (inclusive)."""
		return x > self.mod or x < 0 or y > self.mod or y < 0
	def bound_intersections(self, x1, y1, x2, y2):
		"""Return (xa, ya, xb, yb): where the line through two points crosses the [0, mod]^2 box.

		Each border is given as a segment (x1, y1, x2, y2); ``intersection`` is
		assumed to return the (x, y) crossing point of the two lines — it comes
		from util and is not visible here.
		"""
		left_border   = (0, 0, 0, self.mod)
		right_border  = (self.mod, 0, self.mod, self.mod)
		bottom_border = (0, 0, self.mod, 0)
		top_border    = (0, self.mod, self.mod, self.mod)
		# Order the points left-to-right so only the slope sign matters below.
		if x1 > x2:
			x1, y1, x2, y2 = x2, y2, x1, y1
		if y1 > y2:
			# Falling line: enters via left (or top), leaves via bottom (or right).
			p1 = intersection(x1, y1, x2, y2, *left_border)
			if self.is_point_out_of_bounds(*p1):
				p1 = intersection(x1, y1, x2, y2, *top_border)
			p2 = intersection(x1, y1, x2, y2, *bottom_border)
			if self.is_point_out_of_bounds(*p2):
				p2 = intersection(x1, y1, x2, y2, *right_border)
		else:
			# Rising line: enters via bottom (or left), leaves via right (or top).
			p1 = intersection(x1, y1, x2, y2, *bottom_border)
			if self.is_point_out_of_bounds(*p1):
				p1 = intersection(x1, y1, x2, y2, *left_border)
			p2 = intersection(x1, y1, x2, y2, *right_border)
			if self.is_point_out_of_bounds(*p2):
				p2 = intersection(x1, y1, x2, y2, *top_border)
		return *p1, *p2
	def points(self):
		"""Return the parallel lists (xs, ys) of all affine curve points."""
		return self.xs, self.ys
	def add(self, x1, y1, x2, y2):
		"""Return P + Q (mod p) for two distinct points P=(x1, y1), Q=(x2, y2).

		Raises:
			ValueError: if the points are equal (use ``double``) or share an x
				coordinate, in which case their sum is the point at infinity.
		"""
		if (x1, y1) == (x2, y2):
			raise ValueError("Adding a point to itself is not allowed, please use double.")
		if (x2 - x1) % self.mod == 0:
			raise ValueError("The points lie on a vertical line; their sum is the point at infinity.")
		s = self._line_slope(x1, y1, x2, y2)
		x = (s**2 - x1 - x2) % self.mod
		y = (s * ((x1 - x) % self.mod) - y1) % self.mod
		return (x, y)
	def double(self, x, y):
		"""Return 2P (mod p) for P=(x, y) using the tangent slope (3x^2 + a) / (2y).

		Raises:
			ValueError: if y == 0 (mod p); the tangent is vertical and 2P is the
				point at infinity.
		"""
		if y % self.mod == 0:
			raise ValueError("The tangent is vertical (y == 0); doubling gives the point at infinity.")
		# mod is an odd prime, so 2y != 0 (mod p) here and has an inverse.
		s = ((3 * x * x + self.a) * self.inverse_table[(2 * y) % self.mod]) % self.mod
		x3 = (s * s - 2 * x) % self.mod
		y3 = (s * (x - x3) - y) % self.mod
		return (x3, y3)
	def _add_points(self, p, q):
		"""Group addition of two points where ``None`` is the point at infinity."""
		if p is None:
			return q
		if q is None:
			return p
		(x1, y1), (x2, y2) = p, q
		if (x1 - x2) % self.mod == 0:
			# Same x: mirror images (sum is infinity) or the same point.
			if (y1 + y2) % self.mod == 0:
				return None
			return self.double(x1, y1)
		return self.add(x1, y1, x2, y2)
	def scalar_multiply(self, point, n):
		"""Return ``n * point`` (mod p) using double-and-add.

		``point`` is an (x, y) pair (or ``None`` for the point at infinity) and
		``n`` an integer; a negative ``n`` multiplies the inverse point (x, -y).
		The point at infinity, e.g. ``0 * point``, is returned as ``None``.
		"""
		if point is None:
			return None
		x, y = point
		if n < 0:
			n, y = -n, (-y) % self.mod
		result = None
		addend = (x, y)
		while n:
			if n & 1:
				result = self._add_points(result, addend)
			n >>= 1
			if n:
				addend = self._add_points(addend, addend)
		return result

def elliptic_curve_factory(is_finite, a=None, b=None, mod=None, curve=None):
	"""Build an EllipticCurveOverFiniteField or a real-valued EllipticCurve.

	Parameters come from ``curve`` when it is given, otherwise from ``a``, ``b``
	and ``mod``. ``curve`` may be an ``(a, b, mod)`` tuple (e.g. a value of
	DEFAULT_CURVES) or an existing curve object whose ``a``, ``b`` and — if it
	has one — ``mod`` attribute are copied; without a ``mod`` attribute the
	``mod`` argument is used instead.

	Args:
		is_finite: build a curve over GF(mod) when true, otherwise a curve over
			the reals (``mod`` is then not needed and ignored).

	Raises:
		ValueError: if ``a`` or ``b`` is missing, or ``mod`` is missing for a
			finite-field curve (the curve constructors may raise further
			ValueErrors for an invalid modulus).
	"""
	if curve is not None:
		if isinstance(curve, tuple):
			a, b, mod = curve
		else:
			a = curve.a
			b = curve.b
			mod = getattr(curve, 'mod', mod)
	# The modulus only matters for curves over a finite field.
	if a is None or b is None or (is_finite and mod is None):
		raise ValueError("Insufficent information passed")
	if is_finite:
		return EllipticCurveOverFiniteField(a, b, mod)
	return EllipticCurve(a, b)