Training a model to perform semantic segmentation is not much different from training a classification model. The only extra step needed after the final logits layer is to flatten the output from a 2D image into a flat array of pixels. This process takes a tensor of shape `[n_samples, height, width, n_classes]` and reshapes it to `[n_samples, n_pixels, n_classes]`, where `n_pixels = height * width`. The same transformation should also be applied to the labels.
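The flattening step can be sketched with NumPy reshapes; the array sizes here are arbitrary and just for illustration:

```python
import numpy as np

# Hypothetical sizes for illustration
n_samples, height, width, n_classes = 4, 8, 8, 3

logits = np.random.randn(n_samples, height, width, n_classes)
labels = np.random.randint(0, n_classes, size=(n_samples, height, width))
one_hot = np.eye(n_classes)[labels]  # [n_samples, height, width, n_classes]

# Flatten the two spatial dimensions into a single pixel axis
flat_logits = logits.reshape(n_samples, height * width, n_classes)
flat_labels = one_hot.reshape(n_samples, height * width, n_classes)
```

Note that `reshape` here is safe because it only merges the `height` and `width` axes; the ordering of pixels within each sample is preserved, so logits and labels stay aligned.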

Once the data has shape `[n_samples, n_pixels, n_classes]`, we can apply a softmax non-linearity and compute the cross-entropy loss. The only difference from a classification task is that instead of a single cross-entropy loss per sample in the batch, we get a cross-entropy loss for every pixel of every sample. To get a single loss value, we take the mean over all pixels and samples.
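The whole loss computation can be sketched as follows, assuming one-hot labels. The function name `segmentation_loss` is just illustrative; in practice a framework helper such as a softmax cross-entropy op would do the softmax and cross-entropy in one numerically stable step, which is what the log-sum-exp trick below mimics:

```python
import numpy as np

def segmentation_loss(logits, labels):
    """Mean per-pixel softmax cross-entropy.

    logits: [n_samples, height, width, n_classes], raw scores
    labels: [n_samples, height, width, n_classes], one-hot
    """
    n, h, w, c = logits.shape
    # Flatten spatial dimensions: [n_samples, n_pixels, n_classes]
    logits = logits.reshape(n, h * w, c)
    labels = labels.reshape(n, h * w, c)
    # Numerically stable log-softmax over the class axis
    z = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # Cross-entropy for every pixel of every sample...
    per_pixel = -(labels * log_softmax).sum(axis=-1)
    # ...then the mean gives a single scalar loss
    return per_pixel.mean()
```

As a sanity check, all-zero logits give a uniform softmax, so the loss should equal `log(n_classes)` regardless of the labels.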