{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "27208db6-36c7-4b38-9ef6-e7da2689dd54",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "# 📘 **NOAI 2026 - Programming Task 2**\n",
    "* **Subject:** Deep Learning (Computer Vision)\n",
    "* **Title:** The Segmentation Fracture (Debugging U-Net)\n",
    "* **Total:** 20 Marks\n",
    "  * **Part 1: The Architecture (10 Marks)** - Fixing the Padding Mismatch\n",
    "  * **Part 2: The Metric (10 Marks)** - Re-implementing Mean IoU\n",
    "  * **Part 3: System Verification (No Marks Allocated)** - Final Integrity Check\n",
    "\n",
    "Timestamp: 20 Feb 2026"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4a6ca3f",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **1. The Scenario**\n",
    "You are the Lead Engineer at *OrbitalVision*, a startup processing high-resolution satellite imagery to track deforestation. Your team uses a standard **U-Net** architecture to segment images into two classes: **Urban (0)** and **Forest (1)**.\n",
    "\n",
    "A junior intern recently tried to optimize the model for speed by removing \"unnecessary\" padding. However, this broke the geometry of the network. Now, the training pipeline crashes with `RuntimeError: Sizes of tensors must match` because the feature maps in the Decoder are larger than the ones from the Encoder.\n",
    "\n",
    "**Your Mission:**\n",
    "1.  **Fix the Architecture:** Restore the padding logic so the skip connections align perfectly.\n",
    "2.  **Restore the Metric:** The intern also deleted the evaluation code. You must re-implement the **Mean Intersection over Union (mIoU)** from scratch (no Scikit-Learn allowed)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39163b56",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "### **📚 Supplementary Reading**\n",
    "\n",
    "* **[Guide to Intersection over Union (IoU)](https://www.v7labs.com/blog/intersection-over-union-guide)**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fe401e7",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **2. Setup & Rules**\n",
    "- **Allowed Libraries:** `torch`, `torch.nn`, `torch.nn.functional`, `numpy`.\n",
    "- **Banned Libraries:** `sklearn`, `torchmetrics`, or any other high-level metric wrappers.\n",
    "- **Input Data:** We provide a mock random tensor `(B, 3, 256, 256)` to test your forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5143fd25-0805-4981-bbbb-7f270071c6ac",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# --- CELL 1: SETUP (DO NOT EDIT) ---\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "# Set seed for reproducibility\n",
    "torch.manual_seed(2026)\n",
    "\n",
    "print(\"Setup Complete. Device:\", \"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45e79f9d-9d46-4252-b5c6-4b6951de2038",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **Part 1: The Architecture (10 Marks)**\n",
    "\n",
    "### **Question 2.1: The \"Valid Padding\" Fracture**\n",
    "The intern used `padding=0` (\"Valid\" convolution) to save compute. This caused a critical geometric issue: **The Feature Maps Shrink.**\n",
    "\n",
    "* **Input:** $256 \\times 256$\n",
    "* **Encoder Layer 1:** $254 \\times 254$ (Lost 2 pixels border)\n",
    "* **...**\n",
    "* **Bottleneck Output:** The feature map is significantly smaller than the input.\n",
    "\n",
    "When the Decoder upsamples the image, it reaches a size (e.g., $114 \\times 114$) that is smaller than the corresponding Skip Connection from the Encoder (e.g., $122 \\times 122$). Pytorch cannot concatenate tensors of different sizes.\n",
    "\n",
    "**Your Task:**\n",
    "1.  **Analyze** the shape mismatch in the `forward` method.\n",
    "2.  **Implement** the `center_crop(layer, target_shape)` method. It must slice the larger Encoder feature map (`layer`) from the center so it matches the height and width of the Decoder feature map (`target_shape`).\n",
    "3.  **Verify** that the channel dimensions align for concatenation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "867b783c-2415-4fa6-b93f-7a163fd49359",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### QUESTION 2.1 START: Fix the Architecture (10 pts)\n",
    "\n",
    "class BrokenUNet(nn.Module):\n",
    "    def __init__(self, in_channels=3, out_classes=2):\n",
    "        super(BrokenUNet, self).__init__()\n",
    "        \n",
    "        # --- ENCODER (Intern removed padding to 'save compute') ---\n",
    "        # Note: kernel=3, padding=0 causes spatial dimension reduction by 2 pixels per layer.\n",
    "        self.enc1_1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=0) \n",
    "        self.enc1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=0)\n",
    "        self.pool1 = nn.MaxPool2d(2) # Downsample by 2\n",
    "        \n",
    "        self.enc2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=0)\n",
    "        self.enc2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=0)\n",
    "        self.pool2 = nn.MaxPool2d(2) # Downsample by 2\n",
    "        \n",
    "        # --- BOTTLENECK ---\n",
    "        self.bot1 = nn.Conv2d(128, 256, kernel_size=3, padding=0)\n",
    "        self.bot2 = nn.Conv2d(256, 256, kernel_size=3, padding=0)\n",
    "        \n",
    "        # --- DECODER ---\n",
    "        self.upconv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
    "        \n",
    "        # Decoder Convolutions\n",
    "        # Expects 256 input channels (128 from upconv + 128 from skip)\n",
    "        self.dec1_1 = nn.Conv2d(256, 128, kernel_size=3, padding=0)\n",
    "        self.dec1_2 = nn.Conv2d(128, 64, kernel_size=3, padding=0)\n",
    "        \n",
    "        self.final = nn.Conv2d(64, out_classes, kernel_size=1)\n",
    "\n",
    "    def center_crop(self, layer, target_shape):\n",
    "        \"\"\"\n",
    "        Crop the layer to match target dimensions (center-aligned).\n",
    "        \n",
    "        Why needed: With padding=0, encoder feature maps are LARGER than \n",
    "        decoder feature maps after upsampling. We need to crop the encoder \n",
    "        features from the center to match decoder size for concatenation.\n",
    "        \n",
    "        Args:\n",
    "            layer: Tensor of shape (B, C, H_large, W_large) - the skip connection from encoder\n",
    "            target_shape: torch.Size (B, C, H_small, W_small) - the upsampled decoder tensor\n",
    "        \n",
    "        Returns:\n",
    "            Cropped tensor of shape (B, C, H_small, W_small)\n",
    "        \n",
    "        Example:\n",
    "            layer shape: (1, 128, 122, 122)  <- from encoder\n",
    "            target_shape: (1, 128, 114, 114) <- from decoder after upsampling\n",
    "            Need to crop: from [122x122] to [114x114]\n",
    "            Center crop: start at (122-114)//2 = 4, end at 4+114 = 118\n",
    "            Return: layer[:, :, 4:118, 4:118]\n",
    "        \"\"\"\n",
    "        # TODO Step 1: Extract dimensions from layer\n",
    "\n",
    "        # TODO Step 2: Extract target dimensions from target_shape\n",
    "\n",
    "        # TODO Step 3: Calculate center crop offsets\n",
    "\n",
    "        # TODO Step 4: Return center-cropped tensor\n",
    "    \n",
    "\n",
    "    def forward(self, x):\n",
    "        # 1. Encoder 1\n",
    "        x1 = F.relu(self.enc1_1(x))\n",
    "        x1 = F.relu(self.enc1_2(x1))\n",
    "        x1_p = self.pool1(x1)\n",
    "        \n",
    "        # 2. Encoder 2\n",
    "        x2 = F.relu(self.enc2_1(x1_p))\n",
    "        x2 = F.relu(self.enc2_2(x2))\n",
    "        x2_p = self.pool2(x2)\n",
    "        \n",
    "        # 3. Bottleneck\n",
    "        bot = F.relu(self.bot1(x2_p))\n",
    "        bot = F.relu(self.bot2(bot))\n",
    "        \n",
    "        # 4. Decoder 1\n",
    "        d1 = self.upconv1(bot)\n",
    "        \n",
    "        # --- CRASH POINT: Shape Mismatch ---\n",
    "        # x2 is larger (from Encoder). d1 is smaller (from Bottleneck).\n",
    "        # You must use self.center_crop() here to make x2 match d1.\n",
    "        \n",
    "        # TODO: UNCOMMENT the line below after implementing center_crop()\n",
    "        x2_cropped = x2  # This placeholder will cause a crash - you must fix it!\n",
    "        \n",
    "        # Concatenate along channel axis (dim=1)\n",
    "        cat1 = torch.cat((x2_cropped, d1), dim=1)\n",
    "        \n",
    "        dec1 = F.relu(self.dec1_1(cat1))\n",
    "        dec1 = F.relu(self.dec1_2(dec1))\n",
    "        \n",
    "        return self.final(dec1)\n",
    "\n",
    "### QUESTION 2.1 END"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577b0d61",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q2.1a"
     ]
    }
   },
   "outputs": [],
   "source": [
    "# --- SELF-CHECK: RUN THIS TO TEST YOUR FIX ---\n",
    "# If your center_crop is correct, this cell will print a Success message.\n",
    "\n",
    "model = BrokenUNet()\n",
    "dummy_input = torch.randn(1, 3, 256, 256)\n",
    "\n",
    "try:\n",
    "    output = model(dummy_input)\n",
    "    print(\"✅ Forward Pass Successful!\")\n",
    "    print(f\"Input Shape:  {dummy_input.shape}\")\n",
    "    print(f\"Output Shape: {output.shape}\")\n",
    "    print(\"-\" * 30)\n",
    "    print(\"Note: The output is smaller than the input (e.g. 196x196) because padding=0.\")\n",
    "    print(\"This is expected behavior for this specific architecture.\")\n",
    "except Exception as e:\n",
    "    print(f\"❌ CRASH REPORT: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81fff6f",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "<!-- BEGIN QUESTION -->\n",
    "\n",
    "### **Question 2.1b: Why Center Cropping? (Conceptual - No Coding)**\n",
    "\n",
    "After fixing the code, answer this question in your own words:\n",
    "\n",
    "**Q:** The intern removed padding to \"save compute\". This caused feature maps to shrink. \n",
    "You fixed it with center cropping. But what are the **trade-offs** of this approach?\n",
    "\n",
    "**Think about:**\n",
    "- What information do we lose when we crop?\n",
    "- Could we use a different solution (e.g., zero-padding, interpolation)?\n",
    "- Why does the original U-Net paper use padding=1 instead?\n",
    "\n",
    "**Your Answer (2-3 sentences):**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60c56e00",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "source": [
    "*[Type your answer here after implementing the fix]*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d9425c9",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "---\n",
    "\n",
    "**Hint:** There's no single \"correct\" answer - this tests your understanding of design choices."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6e08c9a",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "<!-- END QUESTION -->\n",
    "\n",
    "## **Part 2: The Metric (10 Marks)**\n",
    "\n",
    "### **Question 2.2: The \"IoU\" Standard**\n",
    "\n",
    "**Why IoU instead of Accuracy?**\n",
    "If 90% of a satellite map is Forest, a dumb model that always predicts \"Forest\" gets 90% accuracy but has zero utility. IoU measures actual overlap between prediction and ground truth.\n",
    "\n",
    "---\n",
    "\n",
    "### **The Formula**\n",
    "\n",
    "For a single class:\n",
    "\n",
    "$\\displaystyle \\text{IoU}_{\\text{class}} = \\frac{\\text{Intersection}}{\\text{Union}} = \\frac{TP}{TP + FP + FN}$\n",
    "\n",
    "Where:\n",
    "- **Intersection** = Pixels where BOTH prediction AND target are this class (Logical AND)\n",
    "- **Union** = Pixels where EITHER prediction OR target is this class (Logical OR)\n",
    "\n",
    "**Mean IoU (mIoU):** Average IoU across all classes (skip classes where union = 0)\n",
    "\n",
    "---\n",
    "\n",
    "### **PyTorch Implementation Pattern**\n",
    "```python\n",
    "# For class c:\n",
    "pred_mask = (pred == c)    # Boolean tensor: True where predicted as class c\n",
    "true_mask = (target == c)  # Boolean tensor: True where actually class c\n",
    "\n",
    "intersection = (pred_mask & true_mask).sum()  # Logical AND, then count\n",
    "union = (pred_mask | true_mask).sum()         # Logical OR, then count\n",
    "\n",
    "iou = intersection / union  # ⚠️ Will be NaN if union == 0!\n",
    "```\n",
    "\n",
    "---\n",
    "\n",
    "### **Your Task**\n",
    "\n",
    "Implement `calculate_miou` with these constraints:\n",
    "1. **Vectorized Operations Only:** You may loop over classes (0, 1), but **NO** loops over pixels or batches\n",
    "2. **Handle Missing Classes:** If union == 0 (class doesn't exist in target), skip it to avoid `NaN`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d0bcaa-f818-453b-9eab-548bdbc4e9d7",
   "metadata": {
    "tags": [
     "otter_answer_cell"
    ]
   },
   "outputs": [],
   "source": [
    "### QUESTION 2.2 START: Implement Mean IoU (10 pts)\n",
    "\n",
    "def calculate_miou(pred_mask, true_mask, num_classes=2):\n",
    "    \"\"\"\n",
    "    Computes Mean Intersection over Union.\n",
    "    Args:\n",
    "        pred_mask: (B, H, W) Integer Tensor - predicted class labels\n",
    "        true_mask: (B, H, W) Integer Tensor - ground truth class labels\n",
    "        num_classes: Number of classes (e.g., 2 for Urban/Forest)\n",
    "    \n",
    "    Returns:\n",
    "        float: Mean IoU across all classes\n",
    "    \n",
    "    Algorithm:\n",
    "        For each class c in [0, num_classes-1]:\n",
    "            1. Create boolean masks: (pred == c) and (true == c)\n",
    "            2. Intersection = count where BOTH are True  [use & operator]\n",
    "            3. Union = count where AT LEAST ONE is True  [use | operator]\n",
    "            4. IoU = Intersection / Union\n",
    "            5. If Union == 0 (class doesn't exist), skip this class [use continue]\n",
    "        Return: Average of all valid IoUs [sum(ious) / len(ious)]\n",
    "    \n",
    "    Vectorization requirement:\n",
    "        - NO loops over pixels or batches\n",
    "        - Flatten tensors first: pred_flat = pred_mask.view(-1)\n",
    "        - Use boolean indexing: (pred_flat == cls) creates boolean tensor\n",
    "        - Use .sum() to count True values\n",
    "    \"\"\"\n",
    "    # TODO: Flatten tensors to (N,)\n",
    "    \n",
    "    ious = []\n",
    "    \n",
    "    # TODO: Loop over classes\n",
    "    for cls in range(num_classes):\n",
    "        # TODO: Create boolean masks\n",
    "\n",
    "        # TODO: Calculate Intersection\n",
    "\n",
    "        # TODO: Calculate Union\n",
    "\n",
    "        # TODO: Calculate IoU\n",
    "        # IMPORTANT: Check if union == 0 BEFORE dividing to avoid division by zero\n",
    "\n",
    "\n",
    "\n",
    "    # TODO: Return mean of ious (handle empty list case)\n",
    "\n",
    "\n",
    "### QUESTION 2.2 END"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "141ebfc5",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# --- SELF-CHECK: RUN THIS TO TEST YOUR METRIC ---\n",
    "# We create a fake perfect prediction and a bad prediction\n",
    "# 2x2 Image\n",
    "fake_pred = torch.tensor([[[0, 0], [1, 1]]]) \n",
    "# One pixel differs\n",
    "fake_true = torch.tensor([[[0, 1], [1, 1]]]) \n",
    "\n",
    "# Class 0 (Urban): Pred=[T, T, F, F], True=[T, F, F, F] -> IoU = 1/2 = 0.5\n",
    "# Class 1 (Forest): Pred=[F, F, T, T], True=[F, T, T, T] -> IoU = 2/3 = 0.66\n",
    "# Mean IoU = (0.5 + 0.66) / 2 = 0.5833\n",
    "\n",
    "try:\n",
    "    score = calculate_miou(fake_pred, fake_true, num_classes=2)\n",
    "    print(f\"Calculated mIoU: {score:.4f}\")\n",
    "    print(f\"Expected mIoU:   0.5833\")\n",
    "    \n",
    "    if 0.58 < score < 0.59:\n",
    "        print(\"✅ Metric Check Passed!\")\n",
    "    else:\n",
    "        print(\"❌ Metric Check Failed: Value mismatch.\")\n",
    "        \n",
    "except Exception as e:\n",
    "    print(f\"❌ Metric Check Crashed: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e61341d",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q2.2"
     ]
    }
   },
   "outputs": [],
   "source": [
    "# --- ADDITIONAL TEST: Missing Class Handling ---\n",
    "# This tests the Union == 0 edge case\n",
    "\n",
    "# Create scenario where class 1 (Forest) doesn't exist in target\n",
    "pred_missing = torch.tensor([[[0, 0], [0, 1]]])   # Some predictions of class 1\n",
    "true_missing = torch.tensor([[[0, 0], [0, 0]]])   # But no class 1 in ground truth\n",
    "\n",
    "# Class 0: Both pred and true have it → IoU should be computed\n",
    "# Class 1: Union = 0 (doesn't exist in target) → Should be skipped\n",
    "\n",
    "try:\n",
    "    score_missing = calculate_miou(pred_missing, true_missing, num_classes=2)\n",
    "    print(f\"mIoU with missing class: {score_missing:.4f}\")\n",
    "    print(\"Expected: Only class 0 IoU counted (should be 0.75)\")\n",
    "    \n",
    "    # Class 0: pred=[T,T,T,F], true=[T,T,T,T] → Intersection=3, Union=4, IoU=0.75\n",
    "    if 0.74 < score_missing < 0.76:\n",
    "        print(\"✅ Missing class handling correct!\")\n",
    "    else:\n",
    "        print(\"❌ Missing class handling failed\")\n",
    "        \n",
    "except Exception as e:\n",
    "    print(f\"❌ Test crashed: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8c231e7",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "source": [
    "## **Part 3: System Verification (No Marks Allocated)**\n",
    "\n",
    "### **Final Integrity Check**\n",
    "Before you submit, run the cell below to verify that your fixed U-Net (`BrokenUNet`) and your metric (`calculate_miou`) are compatible.\n",
    "\n",
    "**Why is this important?**\n",
    "Even if your logic seems correct, a shape mismatch during the forward pass will cause the auto-grader to fail your submission entirely. This is your \"Pre-Flight Check.\"\n",
    "\n",
    "**Success Criteria:**\n",
    "1.  The model accepts a $256 \\times 256$ input without crashing.\n",
    "2.  The output tensor has the correct shape (likely smaller than input due to valid padding, e.g., $196 \\times 196$, but this is expected)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e57972a9",
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# --- FINAL SYSTEM CHECK ---\n",
    "# Run this cell to ensure your model is ready for submission.\n",
    "\n",
    "print(\"Initiating System Integrity Check...\")\n",
    "\n",
    "try:\n",
    "    # 1. Instantiate the model\n",
    "    model = BrokenUNet(in_channels=3, out_classes=2)\n",
    "    \n",
    "    # 2. Create dummy input\n",
    "    x = torch.randn(1, 3, 256, 256)\n",
    "    \n",
    "    # 3. Run Forward Pass\n",
    "    y = model(x)\n",
    "    \n",
    "    print(\"✅ System Online: Forward pass successful.\")\n",
    "    print(f\"   Input Shape:  {x.shape}\")\n",
    "    print(f\"   Output Shape: {y.shape}\")\n",
    "    \n",
    "    if y.shape[2] < x.shape[2]:\n",
    "        print(\"\\n   (Note: Output is smaller than input due to valid padding. This is normal.)\")\n",
    "\n",
    "except Exception as e:\n",
    "    print(f\"❌ System Failure: {e}\")\n",
    "    print(\"   CRITICAL: Your model is still broken. Go back to Question 2.1.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "951031a7-fecf-40c2-99fa-a2cb8729857f",
   "metadata": {
    "deletable": false,
    "editable": false,
    "otter": {
     "tests": [
      "q2.3"
     ]
    }
   },
   "source": [
    "## End of NOAI 2026 - Programming Question 2\n",
    "---"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "noai-questions-2026",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  },
  "otter": {
   "OK_FORMAT": true,
   "tests": {
    "q2.1a": {
     "name": "q2.1a",
     "points": null,
     "suites": [
      {
       "cases": [
        {
         "code": ">>> x = torch.randn(1, 3, 256, 256)\n>>> def test_2_1_output_classes(x):\n...     \"\"\"Forward pass output has 2 channels (num_classes=2).\"\"\"\n...     model = BrokenUNet()\n...     output = model(x)\n...     return output.shape[1]\n>>> assert test_2_1_output_classes(x) == 2\n",
         "hidden": false,
         "locked": false
        },
        {
         "code": ">>> x = torch.randn(1, 3, 128, 128)\n>>> def test_2_1_forward_different_input(x):\n...     \"\"\"Forward pass also works on (1, 3, 128, 128) input.\"\"\"\n...     model = BrokenUNet()\n...     try:\n...         output = model(x)\n...         return output\n...     except RuntimeError as e:\n...         if 'Sizes of tensors must match' in str(e):\n...             raise AssertionError('center_crop failed on 128x128 input — skip connection mismatch')\n...         return torch.randn(1, 2, 1, 1)\n>>> assert test_2_1_forward_different_input(x).shape[1] == 2\n",
         "hidden": false,
         "locked": false
        }
       ],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q2.2": {
     "name": "q2.2",
     "points": null,
     "suites": [
      {
       "cases": [
        {
         "code": ">>> def test_2_2_missing_class_skip():\n...     \"\"\"Class absent from both pred and true → skip (don't crash).\"\"\"\n...     pred = torch.tensor([[[0, 0], [0, 0]]])\n...     true = torch.tensor([[[0, 0], [0, 0]]])\n...     return calculate_miou(pred, true, num_classes=2)\n>>> assert abs(test_2_2_missing_class_skip() - 1.0) < 1e-06\n",
         "hidden": false,
         "locked": false
        }
       ],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    },
    "q2.3": {
     "name": "q2.3",
     "points": null,
     "suites": [
      {
       "cases": [],
       "scored": true,
       "setup": "",
       "teardown": "",
       "type": "doctest"
      }
     ]
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
