import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import csv

# Load the CSV. Every column except the last is read as a numeric attribute;
# the last column is the class label string.
# newline='' is what the csv module asks for when it is handed a file object,
# and the with-block guarantees the handle is closed once the rows are read.
with open('iris.csv', 'r', newline='') as dataFile:
    dataset = csv.reader(dataFile)
    #skip first row which contains csv header; its width gives the attribute count
    nrAttributes = len(next(dataset)) - 1
    # Materialise the remaining rows before the file closes. Blank lines are
    # yielded by csv.reader as empty lists and would break row[-1] below.
    dataset = [row for row in dataset if row]
nrInstances = len(dataset)

instances = np.empty([nrInstances, nrAttributes])
labelStrings = [None] * nrInstances

for idx, row in enumerate(dataset):
    # csv yields strings; convert the attribute columns to floats explicitly
    instances[idx] = np.array(row[:nrAttributes], dtype=float)
    labelStrings[idx] = row[-1]

# Map each distinct label string to an integer class index. Sorting makes the
# mapping deterministic across runs.
uniqueLabelStrings = sorted(set(labelStrings))
labelDict = {
    labelString: classIdx
    for classIdx, labelString in enumerate(uniqueLabelStrings)
}

# Kept as a float array (as before); converted to integer class indices below.
labels = np.array([labelDict[labelString] for labelString in labelStrings],
                  dtype=float)

#shuffle data (same permutation applied to instances and labels so they stay aligned)
randomIdx = np.random.permutation(len(instances))
instances = instances[randomIdx]
labels = labels[randomIdx]

inputs = torch.Tensor(instances)
# .long() because the labels are used as class indices for the loss
targets = torch.Tensor(labels).long()


class SimpleNN(nn.Module):
    """Small fully connected classifier: Linear -> ReLU -> Linear.

    ``forward`` returns the raw output of the second linear layer; no softmax
    is applied. The training loop feeds these scores to
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, inputSize, hiddenSize, outputSize):
        """Create the two linear layers.

        inputSize:  number of features per sample.
        hiddenSize: number of units in the hidden layer.
        outputSize: number of output scores (one per class).
        """
        super(SimpleNN, self).__init__()

        self.fc1 = torch.nn.Linear(inputSize, hiddenSize)
        self.fc2 = torch.nn.Linear(hiddenSize, outputSize)

    def forward(self, input):
        """Return the output scores for ``input`` (last dim = inputSize)."""
        output = self.fc1(input)
        output = F.relu(output)
        output = self.fc2(output)
        return output

    def train(self, inputs=True, targets=None, nrEpochs=10, learnRate=0.01):
        """Train the network, or switch training mode when no targets are given.

        This method has the same name as ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls as ``self.train(False)``. To keep ``eval()``
        and ``train()`` / ``train(False)`` working, a call without ``targets``
        is forwarded to the base class and returns ``self``.

        With ``targets`` the network is optimised with plain SGD, one sample
        per step, and one progress line is printed per epoch.

        inputs:    samples to train on, iterated along the first dimension
                   (or a bool mode flag when ``targets`` is None).
        targets:   class index for each sample, aligned with ``inputs``.
        nrEpochs:  number of passes over the data.
        learnRate: SGD learning rate.

        Returns None after training; returns ``self`` for a mode switch.
        """
        if targets is None:
            # Mode switch requested (e.g. by nn.Module.eval()).
            return super(SimpleNN, self).train(inputs)

        lossFunc = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.parameters(), learnRate)

        for epoch in range(nrEpochs):
            nrCorrect = 0
            lossSum = 0.0

            for sample, target in zip(inputs, targets):
                optimizer.zero_grad()
                # unsqueeze(0) adds the batch dimension the loss expects;
                # self(...) instead of self.forward(...) so module hooks run.
                predicted = self(sample.unsqueeze(0))
                loss = lossFunc(predicted, target.unsqueeze(0))
                loss.backward()
                optimizer.step()
                lossSum += loss.item()
                nrCorrect += int(predicted.argmax().item() == target.item())

            # Report the mean loss over the epoch rather than the loss of
            # whichever sample happened to come last.
            accuracy = nrCorrect / len(inputs)
            print('Epoch', epoch, 'loss', lossSum / len(inputs), 'accuracy', accuracy)
            

# Size the network from the loaded data instead of hard-coding the Iris shape:
# one input per attribute column, one output per distinct label. 5 is the
# hidden-layer width.
myNet = SimpleNN(nrAttributes, 5, len(uniqueLabelStrings))
myNet.train(inputs, targets)

        



    



