Training a Semantic Segmentation Model

Training a model to perform semantic segmentation is not much different from training a classification model. The only extra step needed after the final logits layer is to flatten the spatial dimensions of the output, turning each sample's 2D grid of per-pixel class scores into a flat list of pixels. This takes a tensor of shape [n_samples, height, width, n_classes] and produces a tensor of shape [n_samples, n_pixels, n_classes], where n_pixels = height * width. The labels must be reshaped in exactly the same way (for one-hot labels, to [n_samples, n_pixels, n_classes]) so that each pixel's label lines up with that pixel's logits.
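As a minimal sketch of the reshaping step, the NumPy snippet below uses random arrays in place of a real model's output. The sizes and variable names are hypothetical, and the labels are assumed to start as integer class IDs per pixel that get one-hot encoded to match the logits layout.

```python
import numpy as np

# Hypothetical sizes, chosen for illustration only.
n_samples, height, width, n_classes = 2, 4, 5, 3
n_pixels = height * width

rng = np.random.default_rng(0)

# Stand-in for the final logits layer: [n_samples, height, width, n_classes]
logits = rng.normal(size=(n_samples, height, width, n_classes))

# Assumed label format: one integer class ID per pixel, [n_samples, height, width]
labels = rng.integers(0, n_classes, size=(n_samples, height, width))

# One-hot encode so the labels match the logits: [n_samples, height, width, n_classes]
labels_onehot = np.eye(n_classes)[labels]

# Flatten the two spatial dimensions into a single pixel axis.
logits_flat = logits.reshape(n_samples, n_pixels, n_classes)
labels_flat = labels_onehot.reshape(n_samples, n_pixels, n_classes)

print(logits_flat.shape)  # (2, 20, 3)
print(labels_flat.shape)  # (2, 20, 3)
```

Because NumPy reshapes in row-major order, the pixel at (row, col) ends up at index row * width + col of the pixel axis. The same mapping holds for logits and labels as long as both are reshaped identically, which is why the labels must go through the same transformation.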

Once both the logits and the labels are shaped [n_samples, n_pixels, n_classes], we can apply a softmax across the class dimension and compute the cross-entropy loss. The only difference from a classification task is that, instead of a single cross-entropy loss per sample in the batch, we get a cross-entropy loss for every pixel of every sample in the batch, i.e. a tensor of shape [n_samples, n_pixels]. To get a single loss value, we take the mean over all of these per-pixel losses.
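A minimal NumPy sketch of the per-pixel softmax and cross-entropy, assuming one-hot labels that have already been flattened to [n_samples, n_pixels, n_classes]. The function names are hypothetical. As a design choice, the softmax is computed in log space (log-softmax) for numerical stability instead of taking the log of softmax probabilities directly.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Subtract the per-pixel max before exponentiating for numerical stability.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def pixelwise_cross_entropy(logits_flat, labels_flat):
    """Cross-entropy for every pixel of every sample, plus its mean.

    logits_flat: [n_samples, n_pixels, n_classes] raw scores
    labels_flat: [n_samples, n_pixels, n_classes] one-hot labels (assumed)
    """
    # Softmax over the class axis, kept in log space.
    log_probs = log_softmax(logits_flat, axis=-1)
    # One loss value per pixel per sample: [n_samples, n_pixels]
    per_pixel = -np.sum(labels_flat * log_probs, axis=-1)
    # Mean over every pixel of every sample gives a single scalar loss.
    return per_pixel, per_pixel.mean()


# Usage with hypothetical random data.
n_samples, n_pixels, n_classes = 2, 20, 3
rng = np.random.default_rng(0)
logits_flat = rng.normal(size=(n_samples, n_pixels, n_classes))
class_ids = rng.integers(0, n_classes, size=(n_samples, n_pixels))
labels_flat = np.eye(n_classes)[class_ids]

per_pixel, loss = pixelwise_cross_entropy(logits_flat, labels_flat)
print(per_pixel.shape)  # (2, 20)
print(loss)
```

As a quick sanity check, all-zero logits give a uniform softmax, so every pixel's loss, and therefore the mean, equals log(n_classes).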