from util import Object

import numpy as np
# XXX: remove this
import libnum

# Sampling window used when plotting real curves: EllipticCurve samples
# CURVE_LINSPACE_NUM x values over a range of width CURVE_LINSPACE_OFFSET.
CURVE_LINSPACE_OFFSET = 10
CURVE_LINSPACE_NUM = 2000

# NOTE: add new curves here; every entry is picked up by the GUI automatically.
# Values appear to be (a, b, mod) tuples, matching elliptic_curve_factory's
# parameters -- confirm against the GUI code.
# NOTE(review): the published SEC 2 secp384r1 parameters use a = -3 (mod p)
# and a nonzero b; confirm the (3, 0) coefficients below are intentional.
DEFAULT_CURVES = {
	'Default': (3, 3, 11),
	'secp384r1 (NSA backdoor)': (
		3,
		0,
		2**384 - 2**128 - 2**96 + 2**32 - 1,
	),
	'secp256k1 (BTC)': (
		0,
		7,
		# Equal to 2**256 - 2**32 - 2**9 - 2**8 - 2**7 - 2**6 - 2**4 - 1.
		2**256 - 2**32 - 977,
	),
}

def line_slope(x1, y1, x2, y2):
	"""Return the slope (rise over run) of the line through (x1, y1) and (x2, y2).

	The division is done directly, so for plain Python numbers a vertical
	line (x1 == x2) raises ZeroDivisionError.
	"""
	rise = y2 - y1
	run = x2 - x1
	return rise / run

def intersection(x1, y1, x2, y2, x3, y3, x4, y4):
	"""Return the (x, y) point where two lines cross.

	The first line passes through (x1, y1) and (x2, y2), the second through
	(x3, y3) and (x4, y4). A line whose two points share an x coordinate is
	treated as the vertical line x = const.

	Raises:
		ValueError: if the lines are parallel (equal slopes, or both vertical).
	"""
	# Slopes are computed here rather than via line_slope() so that vertical
	# lines can be recognised before dividing by a zero run.
	vertical1 = x1 == x2
	vertical2 = x3 == x4
	if vertical1 and vertical2:
		raise ValueError("passed lines are parallel")
	if vertical1:
		s2 = (y4 - y3) / (x4 - x3)
		return x1, y3 + s2 * (x1 - x3)
	if vertical2:
		s1 = (y2 - y1) / (x2 - x1)
		return x3, y1 + s1 * (x3 - x1)
	s1 = (y2 - y1) / (x2 - x1)
	s2 = (y4 - y3) / (x4 - x3)
	if s1 == s2:
		raise ValueError("passed lines are parallel")
	# y-intercepts of y = s * x + c for each line.
	c1 = y1 - s1 * x1
	c2 = y3 - s2 * x3
	x = (c2 - c1) / (s1 - s2)
	y = s1 * x + c1
	return x, y

class EllipticCurve(Object):
	"""Elliptic curve y**2 = x**3 + a*x + b over the real numbers.

	The chord/tangent formulas used by ``add`` and ``double`` are those of the
	short Weierstrass form, so the sampled points use the same equation.

	Attributes:
		a, b: curve coefficients.
		pp: (2, CURVE_LINSPACE_NUM) array; row 0 holds the sampled x values,
			row 1 the upper branch +sqrt(x**3 + a*x + b).
		np: same layout for the lower branch -sqrt(x**3 + a*x + b).
			Samples where the right-hand side is negative are NaN in both.
	"""

	def __init__(self, a, b):
		self.a = a
		self.b = b
		self._points()

	def _points(self):
		"""Sample the curve into ``self.pp`` (upper) and ``self.np`` (lower branch)."""
		def rhs(x):
			# Right-hand side of the curve equation; works on scalars and arrays.
			return x**3 + self.a * x + self.b
		# Step left from x = 0 until the right-hand side is negative; the
		# sampling window of width CURVE_LINSPACE_OFFSET starts there.
		start = 0
		step = 0.1
		while rhs(start) >= 0:
			start -= step
		xs = np.linspace(start, start + CURVE_LINSPACE_OFFSET, CURVE_LINSPACE_NUM)
		# Negative right-hand sides have no real square root and yield NaN;
		# that is expected here, so silence numpy's invalid-value warning.
		with np.errstate(invalid='ignore'):
			ys = np.sqrt(rhs(xs))
		self.pp = np.vstack((xs, ys))
		self.np = np.vstack((xs, -ys))

	def points(self):
		"""Return a (4, N) array: the rows of ``pp`` stacked on top of ``np``."""
		return np.concatenate((self.pp, self.np), axis=0)

	def _cord_slope(self, x1, y1, x2, y2):
		"""Slope of the chord through two points (divides by x2 - x1)."""
		return (y2 - y1) / (x2 - x1)

	def _tangent_slope(self, x, y):
		"""Slope of the tangent at (x, y): dy/dx = (3*x**2 + a) / (2*y)."""
		return (3 * x**2 + self.a) / (2 * y)

	def _add(self, s, x1, y1, x2, y2):
		"""Return the sum point for the line of slope ``s`` through (x1, y1) and (x2, y2)."""
		x = s**2 - x1 - x2
		y = s * (x1 - x) - y1
		return (x, y)

	def add(self, x1, y1, x2, y2):
		"""Return P + Q for P = (x1, y1), Q = (x2, y2); requires x1 != x2."""
		return self._add(self._cord_slope(x1, y1, x2, y2), x1, y1, x2, y2)

	def double(self, x, y):
		"""Return 2P for P = (x, y) using the tangent slope; requires y != 0."""
		return self._add(self._tangent_slope(x, y), x, y, x, y)

	def scalar_multiply(self, point, n):
		"""Placeholder for n * point; not implemented yet and returns None."""
		return None

class EllipticCurveOverFiniteField(Object):
	"""Elliptic curve y**2 = x**3 + a*x + b over the field of integers modulo ``mod``.

	Uses the short Weierstrass form, the same as the real-number curve class.
	``mod`` is assumed to be prime: square roots come from
	``libnum.sqrtmod_prime_power(..., mod, 1)`` and inverses from Fermat's
	little theorem. TODO confirm callers never pass a composite modulus.

	Attributes:
		a, b, mod: curve coefficients and field modulus.
		xs, ys: parallel lists of the affine points found by ``_points``.
	"""

	def __init__(self, a, b, mod):
		self.a = a
		self.b = b
		self.mod = mod
		self._points()

	def _points(self):
		"""Collect every affine point by trying each x in range(mod).

		Enumerating the whole field makes this practical only for small moduli.
		"""
		self.xs = []
		self.ys = []
		for x in range(self.mod):
			rhs = (x**3 + self.a * x + self.b) % self.mod
			if not libnum.has_sqrtmod_prime_power(rhs, self.mod, 1):
				continue
			for root in libnum.sqrtmod_prime_power(rhs, self.mod, 1):
				self.xs.append(x)
				self.ys.append(root)

	def points(self):
		"""Return the ``(xs, ys)`` lists of curve points."""
		return self.xs, self.ys

	def _inverse(self, value):
		"""Return the inverse of ``value`` modulo the (prime) ``mod``.

		Raises:
			ZeroDivisionError: if ``value`` is 0 modulo ``mod``.
		"""
		value %= self.mod
		if value == 0:
			raise ZeroDivisionError("no modular inverse of 0")
		# Fermat's little theorem: value**(p - 2) is value**-1 modulo a prime p.
		return pow(value, self.mod - 2, self.mod)

	def _add(self, s, x1, y1, x2, y2):
		"""Return the sum point for the line of slope ``s`` through (x1, y1) and (x2, y2)."""
		x = (s * s - x1 - x2) % self.mod
		y = (s * (x1 - x) - y1) % self.mod
		return (x, y)

	def add(self, x1, y1, x2, y2):
		"""Return P + Q for P = (x1, y1), Q = (x2, y2).

		Raises ZeroDivisionError when x1 == x2 (mod ``mod``).
		"""
		s = (y2 - y1) * self._inverse(x2 - x1) % self.mod
		return self._add(s, x1, y1, x2, y2)

	def double(self, x, y):
		"""Return 2P for P = (x, y) using the tangent slope.

		Raises ZeroDivisionError when 2*y == 0 (mod ``mod``).
		"""
		s = (3 * x * x + self.a) * self._inverse(2 * y) % self.mod
		return self._add(s, x, y, x, y)

	def scalar_multiply(self, point, n):
		"""Placeholder for n * point; not implemented yet and returns None."""
		return None

def elliptic_curve_factory(is_finite, a, b, mod, curve = None):
	"""Build a finite-field or real elliptic curve.

	Args:
		is_finite: if true, return an EllipticCurveOverFiniteField, otherwise
			an EllipticCurve.
		a, b: curve coefficients.
		mod: field modulus; only used for finite-field curves.
		curve: optional existing curve whose ``a`` and ``b`` (and ``mod``, when
			it has one) replace the explicit arguments.
	"""
	if curve is not None:
		a = curve.a
		b = curve.b
		# Real-number curves have no ``mod``; keep the caller's value then.
		mod = getattr(curve, 'mod', mod)
	if is_finite:
		return EllipticCurveOverFiniteField(a, b, mod)
	return EllipticCurve(a, b)