import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Root folder of the dataset. Expected layout: one sub-folder per class whose
# name is the integer class label, each containing that class's image files.
datasetDir = 'images/'

# Only directories are treated as classes, so stray files in the dataset root
# (e.g. '.DS_Store') are neither counted nor parsed as labels.
classDirs = [
    entry for entry in os.listdir(datasetDir)
    if os.path.isdir(os.path.join(datasetDir, entry))
]
nrClasses = len(classDirs)

# Collect everything in Python lists and build the arrays once at the end.
# Growing a NumPy array inside the loop re-copies every previously loaded
# image on each iteration, which makes loading quadratic in the dataset size.
imageList = []
labelList = []
for classDir in classDirs:
    label = int(classDir)  # raises ValueError for a non-numeric folder name
    imgDir = os.path.join(datasetDir, classDir)
    for imgFile in os.listdir(imgDir):
        imageList.append(mpimg.imread(os.path.join(imgDir, imgFile)))
        labelList.append(label)

if not imageList:
    raise RuntimeError('No images found in ' + datasetDir)

# np.stack requires all images to share one shape and raises ValueError
# otherwise. NOTE(review): the rest of the script assumes 2-D (single-channel)
# images -- confirm for the dataset in use.
images = np.stack(imageList)
labels = np.array(labelList)
nrImages = images.shape[0]

# Image arrays are indexed (row, column): axis 0 of an image is its height and
# axis 1 its width.
imgHeight, imgWidth = images.shape[1:3]

# Shuffle images and labels with one shared permutation so that every image
# keeps its own label.
randomIdx = np.random.permutation(len(images))
images = images[randomIdx]
labels = labels[randomIdx]

# Convert to tensors. The image stack is (N, rows, cols); a channel axis of
# size 1 is inserted to obtain the (N, C, H, W) layout that nn.Conv2d expects.
# The spatial dimensions must be listed in their existing order: view() only
# reinterprets the underlying buffer, so swapping them would scramble the
# pixels of non-square images instead of transposing them.
images = torch.Tensor(images)
images = images.view([images.shape[0], 1, images.shape[1], images.shape[2]])
# CrossEntropyLoss needs integer (long) class indices.
labels = torch.Tensor(labels).long()

class SimpleCNN(nn.Module):
    """Minimal CNN classifier: convolution -> max-pool -> ReLU -> linear.

    Takes single-channel input of shape (batch, 1, H, W) and returns one raw
    score (logit) per class for every sample.
    """

    def __init__(self, imgWidth, imgHeight, nrClasses=10):
        """Create the layers for images of the given size.

        Args:
            imgWidth: Size of one spatial dimension of the input images.
            imgHeight: Size of the other spatial dimension. The two sizes are
                used symmetrically, so their order does not matter here.
            nrClasses: Number of output scores. Defaults to 10.

        Raises:
            ValueError: If the image is too small for the convolution and
                pooling layers to produce at least one output value.
        """
        super().__init__()

        nrConvFilters = 3
        convFilterSize = 5
        poolSize = 2

        self.convLayer = nn.Conv2d(1, nrConvFilters, convFilterSize)
        self.poolLayer = nn.MaxPool2d(poolSize)

        # An unpadded stride-1 convolution shrinks each side by (kernel - 1);
        # max-pooling then divides each side by poolSize, rounding down. The
        # two sides must be handled separately so that non-square images and
        # odd intermediate sizes give the right feature count.
        pooledWidth = (imgWidth - convFilterSize + 1) // poolSize
        pooledHeight = (imgHeight - convFilterSize + 1) // poolSize
        if pooledWidth < 1 or pooledHeight < 1:
            raise ValueError(
                'Image of size {}x{} is too small for this network'.format(
                    imgWidth, imgHeight))
        fcInputSize = pooledWidth * pooledHeight * nrConvFilters
        self.fcLayer = nn.Linear(fcInputSize, nrClasses)

    def forward(self, input):
        """Compute class scores.

        Args:
            input: Tensor of shape (batch, 1, H, W). A single unbatched
                (1, H, W) image is treated as a batch of one.

        Returns:
            Tensor of shape (batch, nrClasses) with unnormalized scores.
        """
        output = self.convLayer(input)
        output = self.poolLayer(output)
        output = F.relu(output)
        if output.dim() == 3:
            # Unbatched input: add the batch axis so the flatten below works.
            output = output.unsqueeze(0)
        # Flatten each sample separately, keeping the batch dimension.
        output = output.view(output.shape[0], -1)
        output = self.fcLayer(output)
        return output

    def train(self, images=True, labels=None, nrEpochs=10, learnRate=0.01):
        """Fit the network with per-sample SGD, or switch the training mode.

        This method shadows nn.Module.train(mode), which nn.Module.eval()
        calls internally. To keep that machinery working, a call without
        labels is forwarded to nn.Module.train with the first argument
        interpreted as the boolean mode.

        Args:
            images: Tensor of shape (N, 1, H, W) with the training images, or
                a bool mode when labels is omitted.
            labels: Tensor of N integer class indices, or None to only set
                the training mode.
            nrEpochs: Number of passes over the data. Defaults to 10.
            learnRate: SGD learning rate. Defaults to 0.01.

        Returns:
            None after fitting; self when only the mode is switched (same as
            nn.Module.train).
        """
        if labels is None:
            return super().train(images)

        lossFunc = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.parameters(), learnRate)

        for epoch in range(nrEpochs):
            loss = None
            for image, label in zip(images, labels):
                optimizer.zero_grad()
                # Each sample is processed as a batch of one.
                predicted = self(image.unsqueeze(0))
                loss = lossFunc(predicted, label.unsqueeze(0))
                loss.backward()
                optimizer.step()

            # The reported value is the loss of the last sample of the epoch;
            # nothing is printed when there were no samples.
            if loss is not None:
                print('Epoch', epoch, 'loss', loss.item())

# Build the network for the loaded image dimensions, then fit it on the
# shuffled dataset.
myCNN = SimpleCNN(imgWidth=imgWidth, imgHeight=imgHeight)
myCNN.train(images=images, labels=labels)
