import random
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mut_prob", help="Mutation Probability", type=float, default=0.6)
args = parser.parse_args()
# A probability outside [0, 1] (or NaN) makes the mutation step meaningless.
if not 0.0 <= args.mut_prob <= 1.0:
    parser.error("--mut_prob must be between 0 and 1")

print(args.mut_prob)

MAX_GEN = 1000  # number of generations to run
POP_SIZE = 50  # individuals per generation
IND_LEN = 50  # genes (bits) per individual
CX_PROB = 0.8  # crossover probability
# Probability that an individual is mutated at all. It is taken from the
# command line; it used to be hard-coded to 0.6, which silently ignored
# --mut_prob.
MUT_PROB = args.mut_prob
MUT_FLIP_PROB = 1/IND_LEN  # per-gene flip probability inside a mutated individual

# onemax problem
def fitness(ind):
    """Return the OneMax score of an individual.

    The score is the sum of the gene values, i.e. the number of 1-bits for a
    0/1 individual. Higher is better.
    """
    return sum(gene for gene in ind)

def create_random_population(pop_size=None, ind_len=None):
    """Create an initial population of random bit-string individuals.

    Args:
        pop_size: number of individuals; defaults to the module-level POP_SIZE.
        ind_len: genes per individual; defaults to the module-level IND_LEN.

    Returns:
        A list of ``pop_size`` independent lists. Each list holds ``ind_len``
        genes drawn uniformly from {0, 1}.
    """
    if pop_size is None:
        pop_size = POP_SIZE
    if ind_len is None:
        ind_len = IND_LEN
    return [
        [random.randint(0, 1) for _ in range(ind_len)]
        for _ in range(pop_size)
    ]

def select(pop, fits, k=None):
    """Draw a mating pool by fitness-proportionate (roulette-wheel) selection.

    Individuals are drawn with replacement, with probability proportional to
    their fitness. If the total fitness is not positive (e.g. every individual
    scores 0, or ``pop`` is empty), ``random.choices`` would raise. In that
    case all individuals are equally (un)fit and selection falls back to
    uniform sampling.

    Args:
        pop: list of individuals.
        fits: non-negative fitness values aligned with ``pop``.
        k: size of the mating pool; defaults to ``len(pop)``.

    Returns:
        A list of ``k`` entries that reference (do not copy) members of
        ``pop``. The same individual may appear several times.
    """
    if k is None:
        k = len(pop)
    if sum(fits) <= 0:
        return random.choices(pop, k=k)
    return random.choices(pop, weights=fits, k=k)

def crossover(pop, cx_prob=None):
    """Apply one-point crossover to consecutive pairs of a mating pool.

    Individuals are paired as (pop[0], pop[1]), (pop[2], pop[3]), and so on.
    With probability ``cx_prob`` a pair swaps tails at a cut point chosen
    uniformly from 1..len-1, so each child carries genes from both parents.
    Otherwise the parents are copied unchanged. If ``pop`` has odd length, the
    unpaired last individual is copied through, so the output always has the
    same size as the input.

    Args:
        pop: mating pool, a list of equal-length gene lists.
        cx_prob: per-pair crossover probability; defaults to CX_PROB.

    Returns:
        A new list of ``len(pop)`` freshly allocated offspring lists. The
        inputs are not modified.
    """
    if cx_prob is None:
        cx_prob = CX_PROB
    off = []
    for p1, p2 in zip(pop[::2], pop[1::2]):
        # A cut at 0 or at len(p1) would only swap or copy the parents, so a
        # real crossover needs at least two genes.
        if random.random() < cx_prob and len(p1) > 1:
            point = random.randrange(1, len(p1))
            off.append(p1[:point] + p2[point:])
            off.append(p2[:point] + p1[point:])
        else:
            off.append(p1[:])
            off.append(p2[:])
    if len(pop) % 2:
        # zip() above drops the unpaired last parent; keep it.
        off.append(pop[-1][:])
    return off

def mutation(pop, mut_prob=None, flip_prob=None):
    """Apply two-level bit-flip mutation to a population.

    Each individual is chosen for mutation with probability ``mut_prob``. In
    a chosen individual, every gene is flipped (0 <-> 1) independently with
    probability ``flip_prob``. Individuals that are not chosen are copied
    unchanged.

    Args:
        pop: list of 0/1 gene lists.
        mut_prob: per-individual mutation probability; defaults to MUT_PROB.
        flip_prob: per-gene flip probability; defaults to MUT_FLIP_PROB.

    Returns:
        A new list of ``len(pop)`` freshly allocated lists. ``pop`` is not
        modified.
    """
    if mut_prob is None:
        mut_prob = MUT_PROB
    if flip_prob is None:
        flip_prob = MUT_FLIP_PROB
    off = []
    for p in pop:
        if random.random() < mut_prob:
            off.append([1 - g if random.random() < flip_prob else g for g in p])
        else:
            off.append(p[:])
    return off

def evolution():
    """Run the generational GA on OneMax for MAX_GEN generations.

    Each generation proceeds as follows:
      1. Evaluate the population and record its best fitness.
      2. Build a mating pool with ``select``.
      3. Apply ``crossover`` and ``mutation``.
      4. Copy the current best individual into slot 0 of the next
         population (elitism). The best fitness in ``log`` therefore never
         decreases.

    Returns:
        ``(pop, log)``: the final population, and the best fitness of each
        of the MAX_GEN evaluated generations. The final population is
        returned but not logged.
    """
    log = []
    pop = create_random_population()
    for _ in range(MAX_GEN):
        fits = [fitness(ind) for ind in pop]
        best_fit = max(fits)
        log.append(best_fit)
        # Reuse the fitnesses computed above rather than re-evaluating the
        # whole population. fits.index picks the first maximum, the same
        # individual max(pop, key=fitness) would return. Copy it so no list
        # object is shared between generations.
        elite = pop[fits.index(best_fit)][:]
        mating_pool = select(pop, fits)
        off = mutation(crossover(mating_pool))
        off[0] = elite
        pop = off
    return pop, log

if __name__ == "__main__":
    # Run only when executed as a script, not when the module is imported.
    pop, log = evolution()
    # Best individual of the final population, then the best fitness of
    # each generation.
    print(max(pop, key=fitness))
    print(log)

    # Imported lazily so the GA functions do not depend on matplotlib.
    import matplotlib.pyplot as plt

    plt.plot(log)
    plt.show()