{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0e95d11-bf02-41bb-913d-a04dc4d32e25",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "# 📘 **NOAI 2026 - Programming Task 3**\n",
    "* **Subject:** Deep Learning (NLP / Transformers)\n",
    "* **Title:** The Attention Mechanism (Building the \"Brain\")\n",
    "* **Total** = 35 Marks\n",
    "  * **Part 1: The Core Math (10 Marks)**\n",
    "  * **Part 2: The \"Committee\" (10 Marks)**\n",
    "  * **Part 3: The Timekeeper (5 Marks)**\n",
    "  * **Part 4: The Decision (5 Marks)**\n",
    "  * **Part 5: Bonus Question (5 Marks)**\n",
    "\n",
    "Timestamp: 20 Feb 2026"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c316d70",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **1. The Scenario**\n",
    "You are the Lead Architect at *LinguaTech*. Your team is building a lightweight Language Model to analyze ancient text fragments.\n",
    "\n",
    "Standard Neural Networks read words one by one and forget the beginning by the time they reach the end. To fix this, you must implement the **Transformer Architecture**, specifically the mechanism that allows the model to look at *all words simultaneously* and understand context.\n",
    "\n",
    "This mechanism is called **Attention**.\n",
    "\n",
    "### **The Analogy: The \"Search Engine\"**\n",
    "Think of Attention as a database lookup:\n",
    "1.  **Query ($Q$):** What you are looking for (e.g., \"Bank\").\n",
    "2.  **Key ($K$):** The labels in the database (e.g., \"River\", \"Money\").\n",
    "3.  **Value ($V$):** The actual meaning/content.\n",
    "4.  **Attention Score:** How much the Query matches the Key. If \"Bank\" appears near \"Water\", the score for \"River\" goes up."
   ]
  },
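  {
   "cell_type": "markdown",
   "id": "ex-soft-lookup",
   "metadata": {},
   "source": [
    "The lookup analogy above can be sketched numerically. This is only an illustration, not part of the graded task: the 2-D query, key, and value vectors are made-up numbers, chosen to show how a query that matches one key more closely pulls the result toward that key's value.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical toy vectors (invented for illustration).\n",
    "query = torch.tensor([1.0, 0.0])\n",
    "keys = torch.tensor([[1.0, 0.0],    # key 0: points the same way as the query\n",
    "                     [0.0, 1.0]])   # key 1: orthogonal to the query\n",
    "values = torch.tensor([[10.0, 0.0],\n",
    "                       [0.0, 10.0]])\n",
    "\n",
    "scores = keys @ query               # one similarity score per key\n",
    "weights = F.softmax(scores, dim=0)  # non-negative weights that sum to 1\n",
    "result = weights @ values           # blend of the values, weighted by match\n",
    "\n",
    "print(weights)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "Unlike a hard database lookup, every value contributes to the result; the softmax weights only decide by how much."
   ]
  },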
  {
   "cell_type": "markdown",
   "id": "df4c390b",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "### **📚 Supplementary Reading**\n",
    "\n",
    "* **[The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ba09d9f",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **2. The Architecture**\n",
    "You will build the four pillars of the Transformer Encoder:\n",
    "1.  **Scaled Dot-Product Attention:** The math that calculates relevance.\n",
    "2.  **Multi-Head Attention:** Running multiple attention checks in parallel.\n",
    "3.  **Positional Encoding:** Adding \"timestamps\" so the model knows word order.\n",
    "4.  **Classification Head:** Making a final decision based on the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd1a7215-46e3-48af-89f6-97cf8bacbf5d",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# --- CELL 1: SETUP (DO NOT EDIT) ---\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "# Set seed for reproducibility\n",
    "torch.manual_seed(2026)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83d15008",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **Part 1: The Core Math (10 Marks)**\n",
    "---\n",
    "### **Question 3.1 Scaled Dot-Product Attention (8 Marks)**\n",
    "The heart of the Transformer is this formula:\n",
    "\n",
    "$\\displaystyle \\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$\n",
    "\n",
    "**Your Task:**\n",
    "Implement this function using PyTorch.\n",
    "1.  **MatMul:** Multiply Query ($Q$) and Key Transposed ($K^T$).\n",
    "2.  **Scale:** Divide by the square root of the dimension ($\\sqrt{d_k}$).\n",
    "3.  **Mask (Optional):** If a mask is provided, use `masked_fill` to set the masked positions to a large negative number like `-1e9` (so Softmax makes them zero).<br>\n",
    "    i.  Do not use `-inf` as it can cause numerical instability.\n",
    "4.  **Softmax:** Apply softmax to get probabilities (scores must sum to 1).\n",
    "5.  **Weighted Sum:** Multiply the scores by Value ($V$)."
   ]
  },
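  {
   "cell_type": "markdown",
   "id": "ex-mask-fill",
   "metadata": {},
   "source": [
    "To see why the mask step prefers a large finite value over `-inf`, the sketch below applies both fill values to a small score matrix. The scores and the mask are invented for illustration (1 = keep, 0 = mask out), and the second row is fully masked on purpose. It is not a solution to the question, only a demonstration of the two fill values.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical raw scores for two query rows over three key positions.\n",
    "scores = torch.tensor([[2.0, 1.0, 0.5],\n",
    "                       [0.3, 0.2, 0.1]])\n",
    "# Hypothetical mask: 1 = keep, 0 = mask out. The second row is fully masked.\n",
    "mask = torch.tensor([[1, 1, 0],\n",
    "                     [0, 0, 0]])\n",
    "\n",
    "with_large_negative = F.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)\n",
    "with_neg_inf = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)\n",
    "\n",
    "print(with_large_negative)  # every entry is finite\n",
    "print(with_neg_inf)         # the fully masked row is NaN\n",
    "```\n",
    "\n",
    "In the partially masked first row both variants give the masked position a weight of zero; they differ only when a row has nothing left to attend to."
   ]
  },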
  {
   "cell_type": "markdown",
   "id": "f05e1872-b610-4934-99c9-577c77d1ab52",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "### Question 3.1 (8 points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1583b2f0-9b92-4932-a04e-a4dbde770d64",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### ANSWER QUESTION 3.1 HERE\n",
    "\n",
    "\n",
    "def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):\n",
    "    \"\"\"\n",
    "    Compute scaled dot-product attention.\n",
    "    \n",
    "    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    query : torch.Tensor, shape (batch, num_heads, seq_len, d_k)\n",
    "        Query vectors\n",
    "    key : torch.Tensor, shape (batch, num_heads, seq_len, d_k)\n",
    "        Key vectors\n",
    "    value : torch.Tensor, shape (batch, num_heads, seq_len, d_v)\n",
    "        Value vectors\n",
    "    mask : torch.Tensor, optional, shape (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)\n",
    "        Attention mask (True/1 = keep, False/0 = mask out)\n",
    "    dropout : nn.Dropout, optional\n",
    "        Dropout layer for attention weights\n",
    "        \n",
    "    Returns:\n",
    "    --------\n",
    "    output : torch.Tensor, shape (batch, num_heads, seq_len, d_v)\n",
    "        Attention output\n",
    "    attention_weights : torch.Tensor, shape (batch, num_heads, seq_len, seq_len)\n",
    "        Attention weight matrix\n",
    "    \"\"\"\n",
    "    d_k = query.size(-1)\n",
    "    \n",
    "    ### QUESTION 3.1 START\n",
    "\n",
    "    # TODO Step 1: Compute attention scores: QK^T\n",
    "\n",
    "    # TODO Step 2: Scale by sqrt(d_k)\n",
    "\n",
    "    # TODO Step 3: Apply mask if provided\n",
    "\n",
    "    # TODO Step 4: Apply softmax to get attention weights\n",
    "\n",
    "    # TODO Step 5: Apply dropout if provided\n",
    "\n",
    "    # TODO Step 6: Compute output: attention_weights @ V\n",
    "\n",
    "\n",
    "    ### QUESTION 3.1 END\n",
    "\n",
    "    return output, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f576b094",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q3.1"
     ]
    }
   },
   "outputs": [],
   "source": [
    "### TEST FOR QUESTION 3.1\n",
    "\n",
    "\n",
    "# Define test variables\n",
    "batch = 2\n",
    "heads = 8\n",
    "seq_len = 10\n",
    "d_k = 64\n",
    "\n",
    "Q = torch.randn(batch, heads, seq_len, d_k)\n",
    "K = torch.randn(batch, heads, seq_len, d_k)\n",
    "V = torch.randn(batch, heads, seq_len, d_k)\n",
    "\n",
    "print(\"Test variables initialized:\")\n",
    "print(f\"  Q shape: {Q.shape}\")\n",
    "print(f\"  K shape: {K.shape}\")\n",
    "print(f\"  V shape: {V.shape}\")\n",
    "print()\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Testing Scaled Dot-Product Attention...\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "try:\n",
    "    out, attn = scaled_dot_product_attention(Q, K, V)\n",
    "    print(f\"  Output shape: {out.shape}\")\n",
    "    print(f\"  Attention shape: {attn.shape}\")\n",
    "    print(f\"  Attention sums to 1: {attn.sum(-1).mean().item():.4f}\")\n",
    "    assert out.shape == (batch, heads, seq_len, d_k)\n",
    "    assert attn.shape == (batch, heads, seq_len, seq_len)\n",
    "    assert torch.allclose(attn.sum(-1), torch.ones(batch, heads, seq_len), atol=1e-5)\n",
    "    print(\"  ✓ Attention working!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Attention failed: {e}\")\n",
    "\n",
    "\n",
    "# Test with mask\n",
    "print(\"\\nTesting Attention with Mask...\")\n",
    "mask = torch.ones(batch, 1, 1, seq_len)\n",
    "mask[:, :, :, -3:] = 0  # Mask last 3 positions\n",
    "try:\n",
    "    out_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)\n",
    "    # Check that masked positions have zero attention\n",
    "    assert (attn_masked[:, :, :, -3:] == 0).all(), \"Masked positions should have 0 attention\"\n",
    "    print(\"  ✓ Masked attention working!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Masked attention failed: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "307477a3-9d17-4adc-96b1-0c8617d09c91",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "<!-- BEGIN QUESTION -->\n",
    "\n",
    "### Question 3.2 (2 points)\n",
    "\n",
    "Why is scaling by `sqrt(d_k)` necessary?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d705ae4-076a-4b87-bb67-8982e672d144",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### ANSWER QUESTION 3.2 HERE\n",
    "\n",
    "\"\"\"\n",
    "Why is scaling by sqrt(d_k) necessary?\n",
    "\n",
    "Your answer should explain (2-4 sentences):\n",
    "1. What happens to dot product magnitudes as d_k increases\n",
    "2. How this affects softmax behavior (hint: saturation)\n",
    "3. Why this causes gradient flow problems\n",
    "\n",
    "Example structure:\n",
    "\"When d_k is large (e.g., 512), dot products between Q and K become very large in magnitude.\n",
    "This causes softmax to... which leads to... This is problematic for training because...\"\n",
    "\"\"\"\n",
    "\n",
    "answer_3_2 = \"\"\"\n",
    "TYPE YOUR ANSWER HERE:\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "print(answer_3_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92fc5cb9",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "<!-- END QUESTION -->\n",
    "\n",
    "## **Part 2: The \"Committee\" (10 Marks)**\n",
    "---\n",
    "### **Question 3.3 & 3.4 Multi-Head Attention**\n",
    "One attention head is not enough. It might focus on grammar, but miss the tone.\n",
    "**Multi-Head Attention** splits the input into $h$ different \"heads,\" allowing the model to look at different aspects of the sentence simultaneously.\n",
    "\n",
    "**The Logic:**\n",
    "1.  **Split:** Divide the input embeddings into `num_heads` smaller chunks.\n",
    "2.  **Calculate:** Run Scaled Dot-Product Attention on all chunks in parallel.\n",
    "3.  **Concatenate:** Glue the results back together.\n",
    "4.  **Project:** Pass through a final Linear layer.\n",
    "---"
   ]
  },
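  {
   "cell_type": "markdown",
   "id": "ex-head-reshape",
   "metadata": {},
   "source": [
    "The Split and Concatenate steps are pure reshapes. The snippet below is a sketch under assumed toy sizes: the variable names and dimensions are hypothetical, and `view` with `transpose` is just one common way to do it. It checks that splitting into heads and merging back recovers the original tensor.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical toy sizes, chosen only for illustration.\n",
    "n_batch, n_tokens, d_model, num_heads = 2, 5, 8, 2\n",
    "d_k = d_model // num_heads\n",
    "\n",
    "x = torch.randn(n_batch, n_tokens, d_model)\n",
    "\n",
    "# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)\n",
    "split = x.view(n_batch, n_tokens, num_heads, d_k).transpose(1, 2)\n",
    "\n",
    "# Concatenate: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)\n",
    "merged = split.transpose(1, 2).contiguous().view(n_batch, n_tokens, d_model)\n",
    "\n",
    "print(split.shape)\n",
    "print(merged.shape)\n",
    "print(torch.equal(merged, x))\n",
    "```\n",
    "\n",
    "The `contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout."
   ]
  },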
  {
   "cell_type": "markdown",
   "id": "170b4066-954f-45b0-9501-a5e3d720157c",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "### Question 3.3 (5 points) & 3.4 (5 points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255a6a71-a08c-45e5-9550-cb78760b9fea",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### ANSWER QUESTION 3.3 & 3.4 HERE\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-Head Attention mechanism.\n",
    "    \n",
    "    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O\n",
    "    where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, d_model, num_heads, dropout=0.1):\n",
    "        \"\"\"\n",
    "        Parameters:\n",
    "        -----------\n",
    "        d_model : int\n",
    "            Model dimension (must be divisible by num_heads)\n",
    "        num_heads : int\n",
    "            Number of attention heads\n",
    "        dropout : float\n",
    "            Dropout probability\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        \n",
    "        assert d_model % num_heads == 0, \"d_model must be divisible by num_heads\"\n",
    "        \n",
    "        self.d_model = d_model\n",
    "        self.num_heads = num_heads\n",
    "        self.d_k = d_model // num_heads\n",
    "        \n",
    "        ### QUESTION 3.3 START\n",
    "        # Define projection layers:\n",
    "        # - W_q: for queries\n",
    "        # - W_k: for keys\n",
    "        # - W_v: for values\n",
    "        # - W_o: for output projection\n",
    "        # - dropout layer\n",
    "        \n",
    "        \n",
    "        ### QUESTION 3.3 END\n",
    "    \n",
    "    def forward(self, query, key, value, mask=None):\n",
    "        \"\"\"\n",
    "        Parameters:\n",
    "        -----------\n",
    "        query : torch.Tensor, shape (batch, seq_len, d_model)\n",
    "        key : torch.Tensor, shape (batch, seq_len, d_model)\n",
    "        value : torch.Tensor, shape (batch, seq_len, d_model)\n",
    "        mask : torch.Tensor, optional\n",
    "        \n",
    "        Returns:\n",
    "        --------\n",
    "        output : torch.Tensor, shape (batch, seq_len, d_model)\n",
    "        attention_weights : torch.Tensor, shape (batch, num_heads, seq_len, seq_len)\n",
    "        \"\"\"\n",
    "        batch_size = query.size(0)\n",
    "        \n",
    "        ### QUESTION 3.4 START\n",
    "        # 1. Apply linear projections to get Q, K, V\n",
    "        # 2. Reshape for multi-head\n",
    "        # 3. Apply scaled_dot_product_attention\n",
    "        # 4. Reshape back\n",
    "        # 5. Apply output projection W_o\n",
    "        \n",
    "        ### QUESTION 3.4 END"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104b8f89",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q3.4"
     ]
    }
   },
   "outputs": [],
   "source": [
    "### TEST FOR QUESTION 3.3 & 3.4\n",
    "\n",
    "\n",
    "# Ensure test variables are defined\n",
    "if 'batch' not in locals() or 'seq_len' not in locals():\n",
    "    batch = 2\n",
    "    seq_len = 10\n",
    "    print(\"Note: Using default batch=2, seq_len=10 for testing\")\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Testing Multi-Head Attention...\")\n",
    "print(\"=\" * 60)\n",
    "mha = MultiHeadAttention(d_model=512, num_heads=8)\n",
    "x = torch.randn(batch, seq_len, 512)\n",
    "\n",
    "try:\n",
    "    out, attn = mha(x, x, x)\n",
    "    print(f\"  Input shape: {x.shape}\")\n",
    "    print(f\"  Output shape: {out.shape}\")\n",
    "    print(f\"  Attention shape: {attn.shape}\")\n",
    "    assert out.shape == x.shape\n",
    "    print(\"  ✓ Multi-Head Attention working!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Multi-Head Attention failed: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35c275f9",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **Part 3: The Timekeeper (5 Marks)**\n",
    "---\n",
    "### **Question 3.5 Debugging Positional Encoding (5 Marks)**\n",
    "Transformers process all words at the same time (parallel). This makes them fast, but it means **they don't know word order**. To the model, *\"The dog bit the man\"* is the same as *\"The man bit the dog\"*.\n",
    "\n",
    "To fix this, we add a **Positional Encoding** vector to the word embeddings.\n",
    "\n",
    "**The Bug:**\n",
    "A junior engineer wrote the class below, but the sine/cosine wavelengths are messed up.\n",
    "**Your Task:** Identify the indexing error in the `forward` or `__init__` logic and fix it so that even and odd indices get the correct trigonometric functions.\n",
    "\n",
    "---"
   ]
  },
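  {
   "cell_type": "markdown",
   "id": "ex-word-order",
   "metadata": {},
   "source": [
    "The word-order problem can be seen in a few lines. In this sketch the vocabulary and the embedding size are hypothetical, and the mean of the embeddings stands in for any order-insensitive summary of a sentence: without positional information, the two example sentences are indistinguishable.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "vocab = {'the': 0, 'dog': 1, 'bit': 2, 'man': 3}   # hypothetical toy vocabulary\n",
    "embed = nn.Embedding(len(vocab), 4)                # hypothetical embedding size\n",
    "\n",
    "a = torch.tensor([vocab[w] for w in 'the dog bit the man'.split()])\n",
    "b = torch.tensor([vocab[w] for w in 'the man bit the dog'.split()])\n",
    "\n",
    "# The token sequences differ ...\n",
    "print(torch.equal(a, b))\n",
    "\n",
    "# ... but an order-insensitive summary of the embeddings does not.\n",
    "same_summary = torch.allclose(embed(a).mean(dim=0), embed(b).mean(dim=0), atol=1e-6)\n",
    "print(same_summary)\n",
    "```\n",
    "\n",
    "Adding a position-dependent vector to each embedding before pooling breaks this tie, which is the role of the positional encoding."
   ]
  },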
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f31842-feb4-443a-a769-d596354a8739",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### QUESTION 3.5 START: Fix the Bug in Positional Encoding (10 pts) ###\n",
    "\n",
    "\n",
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, max_len=5000, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "        \n",
    "        # Create matrix of [SeqLen, HiddenDim] representing the positional encodings\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        \n",
    "        # --- THE BUG IS IN THE LINE BELOW ---\n",
    "        # \n",
    "        # The correct formula for positional encoding is:\n",
    "        #   PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))\n",
    "        #   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))\n",
    "        #\n",
    "        # The divisor term should be: 1 / 10000^(2i/d_model)\n",
    "        # Using logarithms: 1 / 10000^x = exp(-x * log(10000))\n",
    "        # \n",
    "        # The bug is in THIS line - the formula is wrong:\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))\n",
    "        \n",
    "        # TODO: Fix the div_term calculation to match the formula above\n",
    "        # Hint: What should be inside the parentheses after the * operator?\n",
    "        # -----------------------\n",
    "\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        \n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: [Batch, SeqLen, HiddenDim]\n",
    "        x = x + self.pe[:, :x.size(1)]\n",
    "        return self.dropout(x)\n",
    "\n",
    "### QUESTION 3.5 END ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4bb921",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "### TEST FOR QUESTION 3.5\n",
    "\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Testing Positional Encoding...\")\n",
    "print(\"=\" * 60)\n",
    "pe = PositionalEncoding(d_model=512)\n",
    "x = torch.randn(2, 100, 512)\n",
    "\n",
    "try:\n",
    "    out = pe(x)\n",
    "    print(f\"  Input shape: {x.shape}\")\n",
    "    print(f\"  Output shape: {out.shape}\")\n",
    "    assert out.shape == x.shape\n",
    "    \n",
    "    # Test that PE is added correctly\n",
    "    pe_no_dropout = PositionalEncoding(d_model=512, dropout=0.0)\n",
    "    pe_no_dropout.eval()\n",
    "    x_test = torch.zeros(1, 10, 512)\n",
    "    out_test = pe_no_dropout(x_test)\n",
    "    # Output should equal just the positional encoding\n",
    "    assert not torch.allclose(out_test, x_test), \"PE should add non-zero values\"\n",
    "    \n",
    "    # Test GPU transfer\n",
    "    pe_cuda = PositionalEncoding(d_model=512).to(device)\n",
    "    x_cuda = torch.randn(2, 100, 512).to(device)\n",
    "    out_cuda = pe_cuda(x_cuda)\n",
    "    print(f\"  GPU test passed: output on {out_cuda.device}\")\n",
    "    print(\"  ✓ Positional Encoding working!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Positional Encoding failed: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5bb3c0",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q3.5"
     ]
    }
   },
   "outputs": [],
   "source": [
    "# --- NUMERICAL CORRECTNESS TEST FOR ATTENTION ---\n",
    "\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"Testing Attention Numerical Correctness...\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "# Create simple test case\n",
    "batch_test = 1\n",
    "heads_test = 1\n",
    "seq_test = 3\n",
    "d_k_test = 4\n",
    "\n",
    "Q_test = torch.ones(batch_test, heads_test, seq_test, d_k_test)\n",
    "K_test = torch.ones(batch_test, heads_test, seq_test, d_k_test)  \n",
    "V_test = torch.arange(12).reshape(batch_test, heads_test, seq_test, d_k_test).float()\n",
    "\n",
    "try:\n",
    "    out_test, attn_test = scaled_dot_product_attention(Q_test, K_test, V_test)\n",
    "    \n",
    "    # Check attention weights sum to 1\n",
    "    attn_sums = attn_test.sum(dim=-1)\n",
    "    assert torch.allclose(attn_sums, torch.ones_like(attn_sums), atol=1e-5), \\\n",
    "        \"Attention weights must sum to 1 across seq_len dimension\"\n",
    "    \n",
    "    # Check output shape\n",
    "    assert out_test.shape == V_test.shape, \\\n",
    "        f\"Output shape {out_test.shape} doesn't match Value shape {V_test.shape}\"\n",
    "    \n",
    "    print(\"  ✓ Attention weights sum to 1\")\n",
    "    print(\"  ✓ Output shape correct\")\n",
    "    print(\"  ✓ Numerical correctness verified!\")\n",
    "    \n",
    "except AssertionError as e:\n",
    "    print(f\"  ❌ Assertion failed: {e}\")\n",
    "except Exception as e:\n",
    "    print(f\"  ❌ Test failed: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab4ce924",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **Part 4: The Decision (5 Marks)**\n",
    "---\n",
    "### **Question 3.6 & 3.7 The Classification Head (5 Marks)**\n",
    "Now that the Transformer Encoder has processed the text and understood the context, we need to make a prediction (e.g., Is this sentence Positive or Negative?).\n",
    "\n",
    "**Your Task:**\n",
    "Implement a simple classifier that:\n",
    "- Takes the **mean** (average) of the encoder outputs **over the sequence dimension** (dim=1).\n",
    "- Passes the result through a fully connected layer.\n",
    "---"
   ]
  },
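  {
   "cell_type": "markdown",
   "id": "ex-mean-pool",
   "metadata": {},
   "source": [
    "The shapes involved in the two steps above can be checked with a small sketch. The tensor sizes are assumptions chosen for illustration, and the random tensor merely stands in for real encoder outputs.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Hypothetical stand-in for encoder outputs: (batch, seq_len, d_model)\n",
    "encoder_out = torch.randn(2, 10, 16)\n",
    "\n",
    "pooled = encoder_out.mean(dim=1)     # average over the sequence -> (batch, d_model)\n",
    "classifier = nn.Linear(16, 2)        # fully connected layer: d_model -> num_classes\n",
    "example_logits = classifier(pooled)  # (batch, num_classes)\n",
    "\n",
    "print(pooled.shape)\n",
    "print(example_logits.shape)\n",
    "```\n",
    "\n",
    "Note that a plain mean also averages over padding positions; whether to exclude them is a design choice this sketch does not address."
   ]
  },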
  {
   "cell_type": "markdown",
   "id": "542a6ced-61b2-4c10-936d-44e33652996c",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "### Question 3.6 (2 points) & 3.7 (3 points)\n",
    "\n",
    "Transformer Encoder for Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c630be9-cafb-411f-8cf4-58fcf67a48ec",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "class TransformerEncoderLayer(nn.Module):\n",
    "    \"\"\"Single Transformer encoder layer.\"\"\"\n",
    "    \n",
    "    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)\n",
    "        self.feed_forward = nn.Sequential(\n",
    "            nn.Linear(d_model, d_ff),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_ff, d_model)\n",
    "        )\n",
    "        self.norm1 = nn.LayerNorm(d_model)\n",
    "        self.norm2 = nn.LayerNorm(d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "    \n",
    "    def forward(self, x, mask=None):\n",
    "        # Self-attention with residual connection and layer norm\n",
    "        attn_out, _ = self.self_attn(x, x, x, mask)\n",
    "        x = self.norm1(x + self.dropout(attn_out))\n",
    "        \n",
    "        # Feed-forward with residual connection and layer norm\n",
    "        ff_out = self.feed_forward(x)\n",
    "        x = self.norm2(x + self.dropout(ff_out))\n",
    "        \n",
    "        return x\n",
    "\n",
    "\n",
    "class TransformerClassifier(nn.Module):\n",
    "    \"\"\"Transformer encoder for text classification.\"\"\"\n",
    "    \n",
    "    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=4, \n",
    "                 d_ff=1024, num_classes=2, max_len=512, dropout=0.1):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.d_model = d_model\n",
    "        \n",
    "        # Embedding layers\n",
    "        self.token_embedding = nn.Embedding(vocab_size, d_model)\n",
    "        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)\n",
    "        \n",
    "        # Transformer encoder layers\n",
    "        self.encoder_layers = nn.ModuleList([\n",
    "            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)\n",
    "            for _ in range(num_layers)\n",
    "        ])\n",
    "        \n",
    "        ### QUESTION 3.6 START\n",
    "        # Define classification head:\n",
    "        # Option 1: Use [CLS] token representation\n",
    "        # Option 2: Use mean pooling over sequence\n",
    "        # Include: LayerNorm, Linear(d_model, d_model), ReLU, Dropout, Linear(d_model, num_classes)\n",
    "        \n",
    "        \n",
    "        ### QUESTION 3.6 END\n",
    "    \n",
    "    def forward(self, input_ids, attention_mask=None):\n",
    "        \"\"\"\n",
    "        Parameters:\n",
    "        -----------\n",
    "        input_ids : torch.Tensor, shape (batch, seq_len)\n",
    "            Token indices\n",
    "        attention_mask : torch.Tensor, shape (batch, seq_len)\n",
    "            Mask for padding tokens (1 = real token, 0 = padding)\n",
    "            \n",
    "        Returns:\n",
    "        --------\n",
    "        logits : torch.Tensor, shape (batch, num_classes)\n",
    "        \"\"\"\n",
    "        ### QUESTION 3.7 START\n",
    "        # 1. Get token embeddings and scale by sqrt(d_model)\n",
    "        # 2. Add positional encoding\n",
    "        # 3. Convert attention_mask to proper format for self-attention\n",
    "        # 4. Pass through encoder layers\n",
    "        # 5. Apply classification head (pool sequence, then classify)\n",
    "        \n",
    "        \n",
    "        ### QUESTION 3.7 END"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f35510",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q3.7"
     ]
    }
   },
   "outputs": [],
   "source": [
    "### TEST FOR QUESTION 3.6 & 3.7\n",
    "\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Testing Transformer Classifier...\")\n",
    "print(\"=\" * 60)\n",
    "vocab_size = 30000\n",
    "model = TransformerClassifier(\n",
    "    vocab_size=vocab_size,\n",
    "    d_model=256,\n",
    "    num_heads=heads,\n",
    "    num_layers=4,\n",
    "    num_classes=2\n",
    ").to(device)\n",
    "\n",
    "input_ids = torch.randint(0, vocab_size, (batch, seq_len)).to(device)\n",
    "attention_mask = torch.ones(batch, seq_len).to(device)\n",
    "attention_mask[:, -2:] = 0  # Last 2 tokens are padding\n",
    "\n",
    "try:\n",
    "    logits = model(input_ids, attention_mask)\n",
    "    print(f\"  Input shape: {input_ids.shape}\")\n",
    "    print(f\"  Logits shape: {logits.shape}\")\n",
    "    print(f\"  Predictions: {logits.argmax(-1).tolist()}\")\n",
    "    assert logits.shape == (2, 2)\n",
    "    print(\"  ✓ Transformer Classifier working!\")\n",
    "    \n",
    "    # Test without mask\n",
    "    logits_no_mask = model(input_ids)\n",
    "    print(f\"  Without mask - Logits shape: {logits_no_mask.shape}\")\n",
    "    print(\"  ✓ Works without mask too!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Transformer Classifier failed: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae81bdf-b627-406f-9030-d14ff12dba43",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### ANSWER BONUS QUESTION 3.8 START\n",
    "\n",
    "\n",
    "def create_causal_mask(seq_len, device='cpu'):\n",
    "    \"\"\"\n",
    "    Create a causal (lower triangular) mask for decoder self-attention.\n",
    "    \n",
    "    Why needed: In language models, when generating token at position i,\n",
    "    the model should only attend to tokens at positions 0 to i (not future tokens).\n",
    "    This prevents \"cheating\" during training.\n",
    "    \n",
    "    Example for seq_len=4:\n",
    "        [[1, 0, 0, 0],   <- Token 0 can only see itself\n",
    "         [1, 1, 0, 0],   <- Token 1 can see tokens 0,1\n",
    "         [1, 1, 1, 0],   <- Token 2 can see tokens 0,1,2\n",
    "         [1, 1, 1, 1]]   <- Token 3 can see all tokens 0,1,2,3\n",
    "    \n",
    "    Args:\n",
    "        seq_len: Sequence length\n",
    "        device: Device to create mask on ('cpu' or 'cuda')\n",
    "    \n",
    "    Returns:\n",
    "        Causal mask of shape (1, 1, seq_len, seq_len)\n",
    "    \"\"\"\n",
    "    ### BONUS QUESTION 3.8 START\n",
    "\n",
    "    \n",
    "    ### QUESTION 3.8 END"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed9f2dc-ada8-4813-a3e4-4c826c350b31",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q3.8"
     ]
    }
   },
   "outputs": [],
   "source": [
    "### TEST FOR BONUS QUESTION 3.8\n",
    "\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Testing Causal Mask...\")\n",
    "print(\"=\" * 60)\n",
    "try:\n",
    "    mask = create_causal_mask(5)\n",
    "    print(f\"  Mask shape: {mask.shape}\")\n",
    "    print(f\"  Mask (squeezed):\")\n",
    "    print(mask.squeeze().numpy())\n",
    "    \n",
    "    # Verify it's lower triangular\n",
    "    assert mask.shape == (1, 1, 5, 5)\n",
    "    expected = torch.tril(torch.ones(5, 5))\n",
    "    assert torch.allclose(mask.squeeze(), expected)\n",
    "    print(\"  ✓ Causal Mask working!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Causal Mask failed: {e}\")\n",
    "\n",
    "# Test with causal mask in attention\n",
    "print(\"\\nTesting Attention with Causal Mask...\")\n",
    "try:\n",
    "    causal_mask = create_causal_mask(seq_len)\n",
    "    out_causal, attn_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)\n",
    "    \n",
    "    # Verify upper triangle is zero (can't attend to future)\n",
    "    for i in range(seq_len):\n",
    "        for j in range(i + 1, seq_len):\n",
    "            assert (attn_causal[:, :, i, j] == 0).all(), \\\n",
    "                f\"Position {i} should not attend to position {j}\"\n",
    "    print(\"  ✓ Causal attention correctly masks future positions!\")\n",
    "except Exception as e:\n",
    "    print(f\"  ✗ Causal attention failed: {e}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96dcf3f4",
   "metadata": {},
   "source": [
    "## End of NOAI 2026 - Programming Question 3\n",
    "---"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "otter": {
   "OK_FORMAT": true,
   "tests": {
    "q3.1": {
     "name": "q3.1",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.3": {
     "name": "q3.3",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.4": {
     "name": "q3.4",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.5": {
     "name": "q3.5",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.6": {
     "name": "q3.6",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.7": {
     "name": "q3.7",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q3.8": {
     "name": "q3.8",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
