Paper Implementation

CamVid Implementation

A paper implementation of ENet on the CamVid dataset. This guide assumes basic familiarity with notebooks and includes a brief setup process for getting started with Google Colab.

Get the Notebook

Open the notebook in Google Colab and connect to a GPU runtime.

  • Go to the ENet Notebook link below & click on the Open in Colab button.
  • Connect to a GPU runtime: in the menu bar, go to Runtime → Change runtime type. In the pop-up window, set Runtime type to Python 3, select T4 GPU as the hardware accelerator, and click Save. (A quick check that the GPU is attached is sketched after this list.)
  • A Google Account is required. The Colab interface changes frequently and will auto-detect a recommended configuration for the notebook at launch. In either case, prefer a GPU runtime, as it improves training time dramatically.
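Before running anything, you can confirm the GPU runtime is active. This is an optional sanity check, not part of the original notebook:

!nvidia-smi  # lists the attached GPU, e.g. a Tesla T4

import torch
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU attached; training will fall back to the CPU")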

Initial Setup

Importing dependencies: Execute the first cell in the notebook to prepare the Python environment by importing the required dependencies.

⚠️
You can safely ignore the warning that the notebook was not authored by Google, or opt to review the source code before running it.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import cv2
import os
from tqdm import tqdm
from PIL import Image

root_path = "/content"  # use this for Google Colab

Initialize Datasets

Uncomment the next cell to download the CamVid dataset and extract it.

!wget "https://www.dropbox.com/s/pxcz2wdz04zxocq/CamVid.zip?dl=1" -O CamVid.zip
!unzip CamVid.zip
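As an optional sanity check (assuming the archive extracts the standard CamVid splits directly under root_path), you can verify the folder layout and file counts before continuing:

import os

# The loader below expects these six folders, with 367 train and 101 val pairs
for split in ["train", "trainannot", "val", "valannot", "test", "testannot"]:
    path = os.path.join(root_path, split)
    print(split, len(os.listdir(path)))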

ENet Architecture

⚠️
Define the ENet class after the other four class blocks.

Refer to ENet Architecture in the documentation for the architecture code.
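As a taste of what those class blocks look like, here is a minimal sketch of the paper's initial block, which concatenates 13 strided 3×3 convolution filters with a 2×2 max-pooled copy of the 3-channel input to give 16 channels at half resolution. The documentation linked above has the full, authoritative code; this sketch relies on the torch / torch.nn imports from the first cell:

class InitialBlock(nn.Module):
    """ENet initial block: conv and max-pool branches concatenated."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 13, kernel_size=3, stride=2, padding=1, bias=False)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.bn = nn.BatchNorm2d(16)
        self.prelu = nn.PReLU(16)

    def forward(self, x):
        # 13 conv channels + 3 pooled input channels = 16 output channels
        out = torch.cat([self.conv(x), self.pool(x)], dim=1)
        return self.prelu(self.bn(out))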

Model Instantiation

enet = ENet(12)  # instantiate a 12-class ENet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
enet = enet.to(device)
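Optionally, you can confirm the model's footprint; the ENet paper reports roughly 0.37M parameters:

n_params = sum(p.numel() for p in enet.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")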

Loading Dataset

def loader(training_path, segmented_path, batch_size, h=320, w=1000):
    filenames_t = os.listdir(training_path)
    total_files_t = len(filenames_t)

    filenames_s = os.listdir(segmented_path)
    total_files_s = len(filenames_s)

    assert total_files_t == total_files_s

    if str(batch_size).lower() == "all":
        batch_size = total_files_s

    while True:
        # Choose random indexes of images and labels
        batch_idxs = np.random.randint(0, total_files_s, batch_size)

        inputs = []
        labels = []

        for jj in batch_idxs:
            # Read the photo (plt.imread normalizes PNGs to [0, 1])
            img = plt.imread(training_path + filenames_t[jj])
            # Resize with nearest-neighbor interpolation; note that
            # cv2.resize expects dsize as (width, height)
            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
            inputs.append(img)

            # Read the semantic annotation
            img = Image.open(segmented_path + filenames_s[jj])
            img = np.array(img)
            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
            labels.append(img)

        inputs = np.stack(inputs, axis=2)
        # Change image format from H x W x N x C to N x C x H x W
        inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3)

        labels = torch.tensor(np.stack(labels, axis=0))

        yield inputs, labels
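A quick smoke test of the generator (assuming the dataset folders from the download step above):

demo_pipe = loader(f"{root_path}/train/", f"{root_path}/trainannot/", batch_size=4)
X, y = next(demo_pipe)
# Expect N x C x H x W inputs and N x H x W label maps,
# e.g. torch.Size([4, 3, 320, 1000]) and torch.Size([4, 320, 1000])
print(X.shape, y.shape)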

Defining Class Weights

def get_class_weights(num_classes, c=1.02):
    pipe = loader(f"{root_path}/train/", f"{root_path}/trainannot/", batch_size="all")
    _, labels = next(pipe)
    all_labels = labels.flatten()
    # Pixel count of each class across the sampled training labels
    each_class = np.bincount(all_labels, minlength=num_classes)
    propensity_score = each_class / len(all_labels)
    class_weights = 1 / (np.log(c + propensity_score))
    return class_weights


class_weights = get_class_weights(12)
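The weighting scheme follows the ENet paper: w_class = 1 / ln(c + p_class), where p_class is the pixel frequency of the class and c = 1.02 bounds the weights to roughly the interval [1, 50]. Rare classes therefore receive larger weights, which you can verify by inspection:

for i, w in enumerate(class_weights):
    print(f"class {i:2d}: weight {w:.3f}")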

Defining Hyperparameters

lr = 5e-4
batch_size = 10

criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
optimizer = torch.optim.Adam(enet.parameters(), lr=lr, weight_decay=2e-4)

print_every = 5
eval_every = 5
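Note that StepLR is imported at the top of the notebook but never wired into the training loop. If you want learning-rate decay, a minimal sketch looks like this (the step size and gamma are illustrative, not values from the original run):

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # illustrative values

# then call scheduler.step() once at the end of each epoch in the loop below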

Training the Model (optional)

train_losses = []
eval_losses = []

bc_train = 367 // batch_size  # mini-batches per training epoch (367 training images)
bc_eval = 101 // batch_size  # mini-batches per validation pass (101 validation images)

# Define pipeline objects
pipe = loader(f"{root_path}/train/", f"{root_path}/trainannot/", batch_size)
eval_pipe = loader(f"{root_path}/val/", f"{root_path}/valannot/", batch_size)

epochs = 100

# Train loop

for e in range(1, epochs + 1):
    train_loss = 0
    print("-" * 15, "Epoch %d" % e, "-" * 15)

    enet.train()

    for _ in tqdm(range(bc_train)):
        X_batch, mask_batch = next(pipe)

        # move the batch to the selected device (CPU or GPU)
        X_batch, mask_batch = X_batch.to(device), mask_batch.to(device)

        optimizer.zero_grad()

        out = enet(X_batch.float())

        # loss calculation
        loss = criterion(out, mask_batch.long())
        # update weights
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    print()
    train_losses.append(train_loss)

    if e % print_every == 0:
        print("Epoch {}/{}...".format(e, epochs), "Loss {:6f}".format(train_loss))

    if e % eval_every == 0:
        with torch.no_grad():
            enet.eval()

            eval_loss = 0

            # Validation loop
            for _ in tqdm(range(bc_eval)):
                inputs, labels = next(eval_pipe)

                inputs, labels = inputs.to(device), labels.to(device)

                out = enet(inputs.float())

                # Accumulate the same weighted cross-entropy used in training
                loss = criterion(out, labels.long())
                eval_loss += loss.item()

            print()
            print("Loss {:6f}".format(eval_loss))

            eval_losses.append(eval_loss)

    if e % print_every == 0:
        checkpoint = {"epochs": e, "state_dict": enet.state_dict()}
        torch.save(
            checkpoint, "{}/ckpt-enet-{}-{}.pth".format(root_path, e, train_loss)
        )
        print("Model saved!")

print(
    "Epoch {}/{}...".format(e, epochs),
    "Total Mean Loss: {:6f}".format(sum(train_losses) / epochs),
)
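Once training finishes, a quick plot of the recorded losses helps judge convergence. This is a small sketch using the lists filled above; note that each entry is a loss summed over one epoch's mini-batches:

plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="train")
plt.plot(
    [i * eval_every for i in range(1, len(eval_losses) + 1)],
    eval_losses,
    label="validation",
)
plt.xlabel("epoch")
plt.ylabel("summed loss")
plt.legend()
plt.show()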

Inference and Results

# Load a trained checkpoint; adjust the filename to match one saved during
# training above (e.g. one of the ckpt-enet-{epoch}-{loss}.pth files)
state_dict = torch.load(f"{root_path}/ckpt-enet.pth")["state_dict"]
enet.load_state_dict(state_dict)

fname = "Seq05VD_f05100.png"
tmg_ = plt.imread(f"{root_path}/test/" + fname)
tmg_ = cv2.resize(tmg_, (512, 512), interpolation=cv2.INTER_NEAREST)
tmg = torch.tensor(tmg_).unsqueeze(0).float()
# Reorder from N x H x W x C to N x C x H x W
tmg = tmg.transpose(2, 3).transpose(1, 2).to(device)

enet.to(device)
with torch.no_grad():
    out1 = enet(tmg.float()).squeeze(0)

# load the ground-truth annotation for comparison
smg_ = Image.open(f"{root_path}/testannot/" + fname)
smg_ = cv2.resize(np.array(smg_), (512, 512), interpolation=cv2.INTER_NEAREST)

# move the output to the CPU: NumPy cannot operate on GPU tensors
out2 = out1.cpu().detach().numpy()

mno = 8  # class index to visualize; must be between 0 and n-1, where n is the number of classes

figure = plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.title("Input Image")
plt.axis("off")
plt.imshow(tmg_)
plt.subplot(1, 2, 2)
plt.title("Output Image")
plt.axis("off")
plt.imshow(out2[mno, :, :])  # raw network response for class `mno`
plt.show()

# Per-pixel argmax over the class dimension gives the predicted label map
b_ = out1.data.max(0)[1].cpu().numpy()

# Define the function that maps a 2D image of class labels to a colored segmentation map

def decode_segmap(image):
    Sky = [128, 128, 128]
    Building = [128, 0, 0]
    Pole = [192, 192, 128]
    Road_marking = [255, 69, 0]
    Road = [128, 64, 128]
    Pavement = [60, 40, 222]
    Tree = [128, 128, 0]
    SignSymbol = [192, 128, 128]
    Fence = [64, 64, 128]
    Car = [64, 0, 128]
    Pedestrian = [64, 64, 0]
    Bicyclist = [0, 128, 192]

    label_colours = np.array(
        [
            Sky,
            Building,
            Pole,
            Road_marking,
            Road,
            Pavement,
            Tree,
            SignSymbol,
            Fence,
            Car,
            Pedestrian,
            Bicyclist,
        ]
    ).astype(np.uint8)
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    for l in range(0, 12):
        r[image == l] = label_colours[l, 0]
        g[image == l] = label_colours[l, 1]
        b[image == l] = label_colours[l, 2]

    # Assemble in RGB order for matplotlib's imshow
    rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8)
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb

# decode the images
true_seg = decode_segmap(smg_)
pred_seg = decode_segmap(b_)

# plot the decoded segments
figure = plt.figure(figsize=(20, 10))
plt.subplot(1, 3, 1)
plt.title('Input Image')
plt.axis('off')
plt.imshow(tmg_)
plt.subplot(1, 3, 2)
plt.title('Predicted Segmentation')
plt.axis('off')
plt.imshow(pred_seg)
plt.subplot(1, 3, 3)
plt.title('Ground Truth')
plt.axis('off')
plt.imshow(true_seg)
plt.show()
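To put a number on the qualitative comparison above, here is a simple single-image sketch of pixel accuracy and per-class IoU, using the b_ and smg_ arrays computed earlier:

# Pixel accuracy on this one test image
print(f"pixel accuracy: {(b_ == smg_).mean():.3f}")

# Per-class intersection-over-union
for cls in range(12):
    pred_mask = b_ == cls
    true_mask = smg_ == cls
    union = np.logical_or(pred_mask, true_mask).sum()
    if union > 0:
        inter = np.logical_and(pred_mask, true_mask).sum()
        print(f"class {cls:2d}: IoU {inter / union:.3f}")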

Important Links

ℹ️
The latest fixes and updates to the code can be obtained from the GitHub link above.
