If you've explored image generation, segmentation, or diffusion models, you've probably heard of U-Net. But what exactly is it, and why is it so widely used? In this post, I'll break down U-Net step by step with concrete examples and visual diagrams.
U-Net is a neural network architecture designed for tasks where you need an image in and an image out of the same size. It was originally created for medical image segmentation in 2015, but has since become the backbone of many modern AI systems, including Stable Diffusion.
The name comes from its shape: when you draw the architecture, it looks like the letter "U":
Input Image
     │
     ▼
┌───────────────────────────────────────────┐
│          ENCODER (Downsampling)           │
│   ┌─────┐     ┌─────┐     ┌─────┐         │
│   │64ch │  →  │128ch│  →  │256ch│  → ...  │
│   │128² │     │64²  │     │32²  │         │
│   └──┬──┘     └──┬──┘     └──┬──┘         │
│      │ skip      │ skip      │ skip       │
│      ▼           ▼           ▼            │
│   ┌──┴──┐     ┌──┴──┐     ┌──┴──┐         │
│   │64ch │  ←  │128ch│  ←  │256ch│  ← ...  │
│   │128² │     │64²  │     │32²  │         │
│   └─────┘     └─────┘     └─────┘         │
│          DECODER (Upsampling)             │
└───────────────────────────────────────────┘
     │
     ▼
Output Image
The encoder compresses the image, making it spatially smaller but with more channels:
128×128×3 → 64×64×64 → 32×32×128 → 16×16×256 → 8×8×512
    │           │           │           │          │
    └───────────┴───────────┴───────────┴──────────┘
                  Shrinking spatially
                  Growing in channels
At each step, the spatial resolution is halved (by pooling) and the channel count is doubled, trading "where" information for richer "what" information.
This is like summarizing a book: you lose details but capture the main ideas.
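The shrink-and-grow progression above can be checked with a few lines of PyTorch. This is a sketch, not the full encoder block: just one conv plus one max-pool per step, with the channel counts from the trace:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 128, 128)  # one RGB image, PyTorch (N, C, H, W) layout
channels = [3, 64, 128, 256, 512]

# Each encoder step: convolve (padding=1 keeps the size), then pool (halves it)
for c_in, c_out in zip(channels, channels[1:]):
    x = nn.MaxPool2d(2)(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)(x))
    print(tuple(x.shape))
# shapes: (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8)
```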
The bottleneck is the smallest point in the network:
┌─────────────────────────────────┐
│             8×8×512             │
│                                 │
│   Only 64 spatial positions     │
│   but 512 features each         │
│                                 │
│   "Compressed understanding"    │
└─────────────────────────────────┘
At this point, the network has maximum semantic understanding but minimum spatial detail. It knows "what" is in the image but has lost precisely "where" things are.
The decoder expands the image back to full resolution:
8×8×512 → 16×16×256 → 32×32×128 → 64×64×64 → 128×128×3
But here's the problem: how do you recover the spatial details that were lost?
This is what makes U-Net special. Skip connections pass information directly from the encoder to the decoder, bypassing the bottleneck:
ENCODER                               DECODER
───────                               ───────
128×128 ────── skip1 ──────────────→ 128×128
   │                                    ▲
 64×64 ─────── skip2 ────────────→   64×64
   │                                    ▲
 32×32 ─────── skip3 ──────────→     32×32
   │                                    ▲
 16×16 ─────── skip4 ────────→       16×16
   │                                    ▲
   └──→ 8×8 BOTTLENECK ─────────────────┘
Think of it this way:
| Source | Knows | Problem |
|---|---|---|
| Bottleneck | "What" is in the image | Lost "where" exactly |
| Skip | "Where" things are | Doesn't know context |
| Combined | Both! | Sharp + accurate output |
WITHOUT skip connections:         WITH skip connections:
┌────────────────────┐            ┌────────────────────┐
│                    │            │                    │
│       ▒▒▒▒         │            │         ╱          │
│     (blurry,       │            │        ╱           │
│     wrong spot)    │            │       ╱  (sharp,   │
│                    │            │      ╱   correct!) │
│                    │            │                    │
└────────────────────┘            └────────────────────┘
The bottleneck knows "there's a line somewhere" but lost the exact position. The skip connection says "the line edge is at these exact pixels." Combined, you get a sharp, accurate output.
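In code, combining the two streams is a single channel-wise concatenation. A minimal sketch, with shapes taken from the trace used throughout this post:

```python
import torch

decoded = torch.randn(1, 512, 16, 16)  # upsampled from the bottleneck: "what"
skip    = torch.randn(1, 512, 16, 16)  # saved from the encoder: "where"

# dim=1 is the channel axis in (batch, channels, height, width)
combined = torch.cat([decoded, skip], dim=1)
print(tuple(combined.shape))  # (1, 1024, 16, 16)
```

A later convolution then learns how to fuse the two sources of information.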
Every level of the U-Net uses convolutional blocks:
Input
  │
Conv 3×3 → BatchNorm → ReLU
  │
Conv 3×3 → BatchNorm → ReLU
  │
Output
A 3×3 convolution looks at a pixel and its 8 neighbors to compute each output pixel.
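The block diagrammed above translates directly into an nn.Sequential. A sketch (the name ConvBlock is mine; channel counts are constructor arguments):

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convs, each followed by BatchNorm and ReLU; padding=1 keeps H and W."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

out = ConvBlock(3, 64)(torch.randn(1, 3, 128, 128))
print(tuple(out.shape))  # (1, 64, 128, 128)
```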
Let's make this concrete with Conv2d(2, 3, 3): 2 input channels, 3 output channels, 3×3 kernel.
Key insight: Each output channel has its own filter, and each filter looks at ALL input channels.
INPUT (2 channels)                   OUTPUT (3 channels)
┌─────────┐                          ┌─────────┐
│  Ch 0   │──┬─ Filter 0 ──────────→ │  Ch 0   │
│         │  │                       └─────────┘
└─────────┘  │
             ├─ Filter 1 ──────────→ ┌─────────┐
┌─────────┐  │                       │  Ch 1   │
│  Ch 1   │──┤                       └─────────┘
│         │  │
└─────────┘  └─ Filter 2 ──────────→ ┌─────────┐
                                     │  Ch 2   │
                                     └─────────┘
Each filter reads ALL input channels to produce ONE output channel.
Input (2 channels, 4×4 each):
Channel 0:               Channel 1:
┌────┬────┬────┬────┐    ┌────┬────┬────┬────┐
│ 10 │ 10 │  0 │  0 │    │  5 │  5 │  5 │  5 │
├────┼────┼────┼────┤    ├────┼────┼────┼────┤
│ 10 │ 10 │  0 │  0 │    │  5 │  5 │  5 │  5 │
├────┼────┼────┼────┤    ├────┼────┼────┼────┤
│ 10 │ 10 │  0 │  0 │    │  5 │  5 │  5 │  5 │
├────┼────┼────┼────┤    ├────┼────┼────┼────┤
│ 10 │ 10 │  0 │  0 │    │  5 │  5 │  5 │  5 │
└────┴────┴────┴────┘    └────┴────┴────┴────┘
Filter 0 (one 3×3 kernel per input channel):
For input ch0:        For input ch1:
┌────┬────┬────┐      ┌────┬────┬────┐
│  1 │  0 │ -1 │      │  0 │  0 │  0 │
├────┼────┼────┤      ├────┼────┼────┤
│  1 │  0 │ -1 │      │  0 │  1 │  0 │
├────┼────┼────┤      ├────┼────┼────┤
│  1 │  0 │ -1 │      │  0 │  0 │  0 │
└────┴────┴────┘      └────┴────┴────┘
To compute output pixel at (row=1, col=1):
From ch0: 10×1 + 10×0 + 0×(-1) + 10×1 + 10×0 + 0×(-1) + 10×1 + 10×0 + 0×(-1) = 30
From ch1: 5×0 + 5×0 + 5×0 + 5×0 + 5×1 + 5×0 + 5×0 + 5×0 + 5×0 = 5
Total: 30 + 5 = 35 (plus the filter's bias, if any)
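You can reproduce this arithmetic in PyTorch by writing the kernel weights in by hand. Bias is zeroed so the numbers match, and padding=1 makes output position (1,1) line up with input position (1,1):

```python
import torch
import torch.nn as nn

# The 4x4 two-channel input from above
ch0 = torch.tensor([[10., 10., 0., 0.]] * 4)  # vertical-edge pattern
ch1 = torch.full((4, 4), 5.)                  # constant 5s
x = torch.stack([ch0, ch1]).unsqueeze(0)      # shape (1, 2, 4, 4)

conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
with torch.no_grad():
    conv.weight[0, 0] = torch.tensor([[1., 0., -1.]] * 3)  # kernel for input ch0
    conv.weight[0, 1] = torch.tensor([[0., 0., 0.],
                                      [0., 1., 0.],
                                      [0., 0., 0.]])       # kernel for input ch1
    conv.bias.zero_()
    out = conv(x)

print(out[0, 0, 1, 1].item())  # 35.0
```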
def forward(self, x):
    features = self.conv(x)       # Process with ConvBlock
    pooled = self.pool(features)  # Shrink by half
    return pooled, features       # Return BOTH!
Input: (1, 64, 64, 64)
     │
  ConvBlock
     │
(1, 128, 64, 64) ←── SAVED as skip connection
     │
MaxPool2d (shrink)
     │
Output: (1, 128, 32, 32)
The key is that it returns TWO things: the pooled result for the next layer AND the features for the skip connection.
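For context, here is a minimal complete EncoderBlock around that forward method. This is a sketch; the inlined double conv is the ConvBlock pattern from earlier, and all names are illustrative:

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Double conv (the ConvBlock pattern), then a 2x2 max-pool
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)       # full-resolution features
        pooled = self.pool(features)  # halved, for the next level down
        return pooled, features       # features become the skip connection

pooled, skip = EncoderBlock(64, 128)(torch.randn(1, 64, 64, 64))
print(tuple(pooled.shape), tuple(skip.shape))  # (1, 128, 32, 32) (1, 128, 64, 64)
```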
def forward(self, x, skip):
    x = self.up(x)                   # Grow spatially (ConvTranspose2d)
    x = torch.cat([x, skip], dim=1)  # Concatenate with skip
    x = self.conv(x)                 # Process combined features
    return x
Input: (1, 512, 8, 8)    Skip: (1, 512, 16, 16)
     │
ConvTranspose2d (grow 2×)
     │
(1, 512, 16, 16)
     │
Concat with skip (channels add)
     │
(1, 1024, 16, 16)
     │
ConvBlock (reduce channels)
     │
Output: (1, 256, 16, 16)
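A matching DecoderBlock, again as a sketch with the double conv written inline (names are illustrative; shapes follow the trace above):

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Grow spatially by 2x; keep the channel count
        self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=2, stride=2)
        # Concatenation doubles the channels, so the conv reduces them back down
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch * 2, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                   # (1, 512, 8, 8)  -> (1, 512, 16, 16)
        x = torch.cat([x, skip], dim=1)  # -> (1, 1024, 16, 16)
        return self.conv(x)              # -> (1, 256, 16, 16)

out = DecoderBlock(512, 256)(torch.randn(1, 512, 8, 8), torch.randn(1, 512, 16, 16))
print(tuple(out.shape))  # (1, 256, 16, 16)
```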
ConvTranspose2d is the opposite of Conv2d; it makes images bigger:
Conv2d (stride=2):        ConvTranspose2d (stride=2):
4×4 → 2×2                 2×2 → 4×4
(shrink)                  (grow)
Each input pixel becomes a 2×2 region:
Input (2×2):       Output (4×4):
┌───┬───┐          ┌───┬───┬───┬───┐
│ 1 │ 2 │          │ 1 │ 1 │ 2 │ 2 │
├───┼───┤    →     ├───┼───┼───┼───┤
│ 3 │ 4 │          │ 1 │ 1 │ 2 │ 2 │
└───┴───┘          ├───┼───┼───┼───┤
                   │ 3 │ 3 │ 4 │ 4 │
                   ├───┼───┼───┼───┤
                   │ 3 │ 3 │ 4 │ 4 │
                   └───┴───┴───┴───┘
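A ConvTranspose2d with a 2×2 kernel of ones, stride 2, and zero bias reproduces exactly this pixel replication. A learned kernel does something smarter, but this shows the mechanics:

```python
import torch
import torch.nn as nn

x = torch.tensor([[1., 2.], [3., 4.]]).reshape(1, 1, 2, 2)

up = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
with torch.no_grad():
    up.weight.fill_(1.0)  # each input pixel scatters into its own 2x2 block
    up.bias.zero_()
    out = up(x)

print(out[0, 0])
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```

Because stride equals kernel size, the scattered 2×2 blocks never overlap, which is why each pixel is copied cleanly.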
Let's trace through an entire U-Net forward pass:
INPUT: (1, 3, 128, 128)    "RGB image"
ENCODER:
  enc1: (1, 64, 64, 64)    ← skip1 saved
  enc2: (1, 128, 32, 32)   ← skip2 saved
  enc3: (1, 256, 16, 16)   ← skip3 saved
  enc4: (1, 512, 8, 8)     ← skip4 saved
BOTTLENECK:
  (1, 512, 8, 8)           "Compressed understanding"
DECODER:
  dec4: (1, 256, 16, 16)   ← uses skip4
  dec3: (1, 128, 32, 32)   ← uses skip3
  dec2: (1, 64, 64, 64)    ← uses skip2
  dec1: (1, 64, 128, 128)  ← uses skip1
OUTPUT: (1, 3, 128, 128)   "Processed image"
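The whole trace fits in a compact model. This is a minimal sketch of the architecture described above, not a production implementation; the class and helper names are mine:

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch):
    # The ConvBlock pattern: two 3x3 convs, each with BatchNorm and ReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
    )

class MiniUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.ModuleList(
            double_conv(i, o) for i, o in [(3, 64), (64, 128), (128, 256), (256, 512)])
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = double_conv(512, 512)
        self.up = nn.ModuleList(
            nn.ConvTranspose2d(c, c, 2, stride=2) for c in [512, 256, 128, 64])
        self.dec = nn.ModuleList(
            double_conv(i, o) for i, o in [(1024, 256), (512, 128), (256, 64), (128, 64)])
        self.head = nn.Conv2d(64, 3, kernel_size=1)  # map back to 3 output channels

    def forward(self, x):
        skips = []
        for enc in self.enc:
            x = enc(x)
            skips.append(x)   # save full-resolution features for the decoder
            x = self.pool(x)  # then halve the spatial size
        x = self.bottleneck(x)
        for up, dec, skip in zip(self.up, self.dec, reversed(skips)):
            x = dec(torch.cat([up(x), skip], dim=1))  # grow, merge, process
        return self.head(x)

out = MiniUNet()(torch.randn(1, 3, 128, 128))
print(tuple(out.shape))  # (1, 3, 128, 128)
```

Output and input have the same spatial size, which is the whole point of the architecture.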
U-Net is used for any task requiring pixel-level output:
| Task | Input | Output |
|---|---|---|
| Medical segmentation | CT scan | Tumor mask |
| Semantic segmentation | Photo | Labels per pixel |
| Image denoising | Noisy image | Clean image |
| Inpainting | Image with hole | Filled image |
| Super resolution | Low-res | High-res |
| Style transfer | Photo | Stylized image |
| Diffusion models | Noisy latent | Denoised latent |
Not all tasks need a decoder:
Classification (no decoder):
Image → [shrink, shrink, shrink] → "This is a cat"
U-Net (full decoder):
Image → [shrink] → [expand] → Processed image
If you only need a label, not a pixel-by-pixel output, skip the decoder.
U-Net's power comes from three key ideas:
1. An encoder that compresses the image to build up semantic context.
2. A decoder that expands back to full resolution for pixel-level output.
3. Skip connections that carry fine spatial detail past the bottleneck.
This combination allows U-Net to understand both the big picture (global context from bottleneck) and fine details (local information from skips), producing sharp, accurate outputs.
Whether you're segmenting medical images, generating art with Stable Diffusion, or building your own image editing model, U-Net's elegant architecture is likely at the core.
This post was created while building a text-conditioned image editing model. The examples and diagrams come from hands-on experimentation with PyTorch.