commit f8c8f7ef0c
parent 4b6bf0f208
Author: anon
Date:   2024-10-09 15:16:53 +02:00

    inching closer to the truth

9 changed files with 82 additions and 51 deletions

data.py

@@ -7,20 +7,24 @@ from sys import argv
 from config import *
 import tard_wrangler
-MAX_DATA_LIMIT = sys.maxsize
+#MAX_DATA_LIMIT = sys.maxsize
+MAX_DATA_LIMIT = 1000
-def get_source(path : str) -> [str]:
+DATASET_FILE = "training_set/dataset-linux.pkl"
+def get_source(path : str, normpath : str) -> [str]:
     '''returns source file in $SOURCE_LINE_BATCH_SIZE line batches'''
     r = []
     # read data
-    with open(path, 'r') as file: lines = [line[:-1] for line in file]
+    with open(path, 'r') as f: lines = [line[:-1] for line in f]
+    with open(normpath, 'r') as f: normlines = [line[:-1] for line in f]
     # pad with empty lines
     for i in range(int((SOURCE_LINE_BATCH_SIZE-1)/2)):
         lines.insert(0, "")
         lines.append("")
+        normlines.append("")
     # batch
-    for i in range(len(lines)-2):
-        r.append(lines[i:i+SOURCE_LINE_BATCH_SIZE])
+    for i in range(len(lines)-1):
+        r.append([lines[i]] + normlines[i:i+SOURCE_LINE_BATCH_SIZE-1])
     return r
 
 def source_to_np_array(source_batches : []) -> np.array:
@@ -44,7 +48,8 @@ def read_acc(path : str) -> [[int]]:
         for line in file:
             try:
                 l = eval(line)
-                l = l + [0] * (MAX_SHIMS - len(l))
+                if len(l) < MAX_SHIMS: l = l + [0] * (MAX_SHIMS - len(l))
+                else: l = l[:MAX_SHIMS]
                 r.append(l)
             except: pass
     return r
@@ -54,27 +59,28 @@ def whitespace_to_np_array(spaces : []) -> np.array:
     r = np.array(r).reshape(len(spaces), -1)
     return r
-def compile_data():
+def compile_data(from_dir : str) -> {}:
     r = {'in': [], 'out': [], 'src': []}
-    for n, path in enumerate(glob(COMPILE_INPUT_DIRECTORY + "/*.c")):
-        if n > MAX_DATA_LIMIT: break # XXX
+    for n, path in enumerate(glob(from_dir + "/*.c")):
+        if n > MAX_DATA_LIMIT: break
         acc_path = path + ".acc"
         norm_path = path + ".norm"
-        r['src'].append(path)
-        source_batches = get_source(norm_path)
+        source_batches = get_source(path, norm_path)
         accumulation = read_acc(acc_path)
-        assert len(source_batches) == len(accumulation), (
-            f"Some retard fucked up strings in {path}."
-        )
+        if len(source_batches) != len(accumulation):
+            print(f"WARNING: Some retard fucked up strings in {path}")
+            continue
+        r['src'].append(path)
         r['in'] += source_batches
         r['out'] += accumulation
         print(f"INFO: Read data from ({n}) {path}")
     r['in'] = source_to_np_array(r['in'])
     r['out'] = whitespace_to_np_array(r['out'])
     return r
 
-def get_data():
-    r = []
-    with open('dataset-linux.pkl', 'rb') as f: r = pickle.load(f)
+def get_data(dataset_file : str) -> {}:
+    r = {}
+    with open(dataset_file, 'rb') as f: r = pickle.load(f)
     assert len(r['in']) == len(r['out']), (
         "data in and out sizes were inconsistent ("
         + str(r['in'].shape)
@@ -86,7 +92,6 @@ def get_data():
 if __name__ == "__main__":
     if len(argv) == 2 and argv[1] == 'c': # clean compile
-        with open('dataset-linux.pkl', 'wb') as f: pickle.dump(compile_data(), f)
-    dataset = get_data()
-    print(dataset)
+        with open(DATASET_FILE, 'wb') as f: pickle.dump(compile_data(COMPILE_INPUT_DIRECTORY), f)
+    dataset = get_data(DATASET_FILE)
+    print(dataset['in'].shape, dataset['out'].shape)
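
For reference, a minimal standalone sketch of the new batching scheme in get_source, with the file I/O replaced by inline lists and SOURCE_LINE_BATCH_SIZE assumed to be 3 (both are stand-ins; the real code reads path and its .norm sibling from disk and takes the batch size from config):

# sketch: mirrors the padding and batching logic added in this commit
SOURCE_LINE_BATCH_SIZE = 3  # assumed value; the real one comes from config

def get_source_sketch(lines, normlines):
    # pad: raw lines get (BATCH-1)/2 blanks on both ends,
    # normalized lines only get trailing blanks
    for i in range((SOURCE_LINE_BATCH_SIZE - 1) // 2):
        lines.insert(0, "")
        lines.append("")
        normlines.append("")
    # each batch pairs one raw line with the next
    # SOURCE_LINE_BATCH_SIZE-1 normalized lines
    r = []
    for i in range(len(lines) - 1):
        r.append([lines[i]] + normlines[i:i + SOURCE_LINE_BATCH_SIZE - 1])
    return r

print(get_source_sketch(["int a;", "int b;"], ["int A;", "int B;"]))
# [['', 'int A;', 'int B;'], ['int a;', 'int B;', ''], ['int b;', '']]

Note the final batch comes out one element short, since normlines is only padded at the tail.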