import torch 
import torch.nn as nn
import threading
from torchvision import transforms
from facenet_pytorch import MTCNN
import tkinter as tk
from tkinter import ttk
import cv2
import time

class App:
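    """Tkinter front end: Start launches the webcam capture loop on a
    worker thread; Stop ends it and shows a summary of recorded emotions."""
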
    def __init__(self):
        self.running = False

        self.root = tk.Tk()
        self.root.title("Emotion Detector")
        self.root.geometry("400x150")
        self.root.configure(bg="#74605c")  # Muted brown background

        # Style configuration
        self.style = ttk.Style()
        self.style.theme_use("clam")  # Modern theme
        self.style.configure("TButton", font=("Helvetica", 12), padding=10, relief="flat")
        self.style.map("TButton",
                       background=[("active", "#6a9074")],  # Muted green on hover
                       foreground=[("active", "#FFFFFF")])
        
        # A Thread object can only be started once, so spawn a fresh worker
        # thread on each Start click instead of reusing a single instance.
        self.start_button = ttk.Button(
            self.root, text="Start",
            command=lambda: threading.Thread(target=self.runProgram, daemon=True).start())
        self.start_button.pack(side="left", padx=35)

        # Stopping is quick and touches Tk widgets, so run it directly on the
        # main thread; Tkinter is not safe to call from worker threads.
        self.stop_button = ttk.Button(self.root, text="Stop", command=self.killProgram)
        self.stop_button.pack(side="left", padx=35)

        self.observer = EmotionObserver()

        self.runTime = 0

    def runProgram(self):
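        """Worker-thread loop: grab a frame, classify it, record the
        result, then sleep for roughly one second."""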
        print("run program called")

        #prevent multiple starts
        if self.running:
            return
        
        self.running = True

        #open default camera
        cap = cv2.VideoCapture(0)

        if not cap.isOpened():
            print("Error opening camera")
            self.running = False
            return  # don't call exit() from a worker thread

        while self.running:
            #take image
            ret, matImage = cap.read()
            if not ret:
                print("failed capture")
                continue

            #process image using CNN
            processed_img = self.observer.preprocess(matImage)
            if processed_img is None:
                print("no face detected in preprocessing")
                continue

            confidence, emotion = self.observer.analyze(processed_img)

            # index order must match the label encoding used at training time
            numToEmotion = ["sad", "neutral", "happy"]

            print(f"{numToEmotion[emotion]} {confidence*100+20: .2f}%") 
            self.observer.emotions[self.runTime] = (emotion, confidence)

            #rest for 1 sec
            self.runTime += 1
            time.sleep(1)


    def killProgram(self):
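        """Stop the capture loop and replace the controls with a summary."""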
        print("kill program called")
        self.running = False

        sad, neutral, happy = self.observer.getStats()

        if happy < 10:  # happy in fewer than 10% of sampled frames
            suggestion = "smile more"
        else:
            suggestion = "you smiled enough"

        print(f"Happy: {happy}%\nNeutral: {neutral}%\nSad: {sad}%")

        self.start_button.destroy()
        self.stop_button.destroy()

        stats_display = f"Happy: {happy}%\nNeutral: {neutral}%\nSad: {sad}%\n\nSuggestion: {suggestion}"
        result_label = tk.Label(self.root, text=stats_display, font=("Helvetica", 14),
                                justify="center", bg="#74605c", fg="#FFFFFF")
        result_label.pack(pady=20)


        

class EmotionObserver:
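    """Wraps face detection (MTCNN) and emotion classification, and keeps
    a per-second log of predictions for the end-of-run statistics."""
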
    def __init__(self):
        # map: elapsed-second index -> (emotion class index, confidence)
        self.emotions = {}

        device = torch.device('cpu')

        # The checkpoint is assumed to be a fully pickled EmotionCNN object;
        # on PyTorch >= 2.6, torch.load needs weights_only=False to unpickle it.
        self.model = torch.load("emotionCNN.pth", map_location=device, weights_only=False)
        self.model.eval()

        self.faceDetector = MTCNN(keep_all=False, device=device) 

    def preprocess(self, cv2_image):
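        """Return a normalized (1, 3, 224, 224) tensor for the frame, or
        None when no face is found. Note that the whole frame, not the
        detected face crop, is what gets classified downstream."""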
        # MTCNN expects RGB input, but OpenCV captures frames in BGR order
        rgb_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)

        #check whether a face is present; skip the frame if not
        face = self.faceDetector(rgb_image)
        if face is None:
            return None

        transform = transforms.Compose([
            transforms.ToPILImage(), 
            transforms.Resize((224, 224)),  
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  
        ])
    
        tensor_image = transform(rgb_image)

        # Add a batch dimension (1, C, H, W)
        tensor_image = tensor_image.unsqueeze(0)

        return tensor_image

    def analyze(self, tensor_img):
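        """Run the CNN on a preprocessed frame and return
        (confidence, class index) for the most likely emotion."""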
        #inference only, so gradient tracking is unnecessary
        with torch.no_grad():
            output = self.model(tensor_img)

        #convert raw logits to probabilities
        probabilities = torch.nn.functional.softmax(output, dim=1)

        #take the highest-confidence class
        probability, prediction = torch.max(probabilities, dim=1)

        return probability.item(), prediction.item()


    def getStats(self):
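        """Return (sad%, neutral%, happy%) across all recorded frames."""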
        total = len(self.emotions)
        if total == 0:  # guard against Stop being pressed before any capture
            return 0.0, 0.0, 0.0

        happy_caps = len([value for value in self.emotions.values() if value[0] == 2])
        print(f"happy pics: {happy_caps}")
        percent_happy = round((happy_caps / total) * 100, 2)

        neutral_caps = len([value for value in self.emotions.values() if value[0] == 1])
        print(f"neutral pics: {neutral_caps}")
        percent_neutral = round((neutral_caps / total) * 100, 2)

        sad_caps = len([value for value in self.emotions.values() if value[0] == 0])
        print(f"sad pics: {sad_caps}")
        percent_sad = round((sad_caps / total) * 100, 2)

        return percent_sad, percent_neutral, percent_happy

class EmotionCNN(nn.Module): 
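    """Five conv blocks (each followed by 2x2 max pooling) and two fully
    connected layers, classifying 224x224 RGB input into three emotions."""
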
    def __init__(self):
        super(EmotionCNN, self).__init__()

        # First convolutional layer with dropout
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3), 
            nn.ReLU(),
            nn.Dropout(p=0.2)  # Dropout with 20% probability
        )
        
        # Second convolutional layer with dropout
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(p=0.2)  # Dropout with 20% probability
        )
        
        # Third convolutional layer with dropout
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(p=0.3)  # Dropout with 30% probability
        )

        # Fourth convolutional layer with dropout
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(p=0.3)  # Dropout with 30% probability
        )

        # Fifth convolutional layer with dropout
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(p=0.3)  # Dropout with 30% probability
        )

        flattenSize = self.calculate_flattened_size((3,224,224))

        self.fc1 = nn.Sequential(
            nn.Linear(in_features=flattenSize, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.5)  # Dropout with 50% probability           
        )

        self.fc2 = nn.Linear(512, 3)  # Three classes: Happy, Sad, Neutral

    def calculate_flattened_size(self, input_size):
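        """Push a dummy zero tensor through the conv stack to measure the
        flattened feature count, so fc1's input size needn't be hand-computed."""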
        x = torch.zeros(1, *input_size)
        x = self.conv1(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv3(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv4(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv5(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        return x.numel()
    
    def forward(self, x):  # called via nn.Module.__call__ when the model is applied to input
        x = self.conv1(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv3(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv4(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv5(x)
        x = torch.nn.functional.max_pool2d(x, 2)

        x = x.view(x.size(0), -1)  # Flatten

        x = self.fc1(x)
        x = self.fc2(x)            
        return x
    

if __name__ == "__main__":
    app = App()
    app.root.mainloop()