Building Controllable Latent Representations with Typed Autoencoders

Feature image: Translated digits

Earlier we discussed latent spaces (in the context of AEs and VAEs). Let's explore a bit more deeply how to actually make the latent space useful by explicitly defining the semantics of named coordinates.

So far our latent spaces have just been vector spaces with no clear inherent semantic structure. You can locate a subspace that contains descriptions of a digit, and the further you drift from that location, the "less digit" you get. But none of the directions really carry any clear semantics.

Let's try to add a rotational element to the latent space. Namely, assume we have some latent representation \(z\) of an image \(x\). We add a coordinate in the latent space with which we can rotate the decoded image, i.e. decoder\((z,\theta)\) produces the image \(x\) rotated by \(\theta\).

In the literature this subject is called "disentangled representations".

Again, you can find the code on my GitHub under DeepEnd, in deep_learning/typed_ae.

Model Architecture

Now that we have seen and understood the beautiful theory, we are licensed to break it. And we will use that license without hesitation.

Let's define a simple autoencoder architecture with a small variation: a separate rotation head.

Figure 1: Model architecture

The rotation head encodes the rotation as a vector \((\cos(\theta),\sin(\theta))\), and we use a special class for it. The motivation will become clear later.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CyclicBottleneck(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.fc = nn.Linear(in_features, 2)
    
    def forward(self, x):
        raw = self.fc(x)
        normalized = F.normalize(raw, p=2, dim=-1)
        return normalized
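Because the head's output is normalized onto the unit circle, the angle can always be recovered with `atan2`, and the representation has no discontinuity at the \(0/2\pi\) wrap. A quick sketch:

```python
import torch

# Unit vectors (cos θ, sin θ), as produced by the cyclic bottleneck
z_rot = torch.tensor([[1.0, 0.0],   # θ = 0
                      [0.0, 1.0]])  # θ = π/2
theta = torch.atan2(z_rot[:, 1], z_rot[:, 0])  # angles in radians
```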

The actual model is, again, straightforward. We define the encoder and decoder separately, and in the bottleneck layer we use the CyclicBottleneck.

class TypedAE(nn.Module):
    def __init__(self, latent_dim=5):
        super().__init__()
        self.latent_dim = latent_dim
        
        self.encoder_backbone = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        self.head_rot = CyclicBottleneck(128)  # Rotation (2D)
        self.head_latent = nn.Linear(128, latent_dim)  # Rest latent

        # Decoder takes: 2 (rot) + latent_dim
        self.decoder = nn.Sequential(
            nn.Linear(2 + latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        h = self.encoder_backbone(x.view(x.size(0), -1))
        
        z_rot = self.head_rot(h)
        z_latent = self.head_latent(h)
        
        z_combined = torch.cat([z_rot, z_latent], dim=-1)
        recon = self.decoder(z_combined).view(x.size(0), 1, 28, 28)
        
        return recon, z_rot, z_latent

Notice the naming. "Typed" means that we are trying to build semantics into the latent space, i.e. we are trying to define coordinates in the latent space that give us meaningful control over the generated image.

Our training now consists of three separate phases. First we train the autoencoder to encode the images into the latent space. We keep the rotation head constant in this phase. This ensures the image content is encoded into the other dimensions.

    model = TypedAE().to(device)
    
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, 
                                               download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            images = images.to(device)
            optimizer.zero_grad()
            
            # Forward pass - NO rotation
            h = model.encoder_backbone(images.view(images.size(0), -1))
            z_latent = model.head_latent(h)
            
            # z_rot = zero (no rotation)
            z_rot = torch.zeros(images.size(0), 2).to(device)
            z_rot[:, 0] = 1.0  # cos(0) = 1
            
            z_combined = torch.cat([z_rot, z_latent], dim=-1)
            recon = model.decoder(z_combined).view(images.size(0), 1, 28, 28)
            
            # Loss
            loss = nn.functional.mse_loss(recon, images)
            loss.backward()
            optimizer.step()

The training is illustrated below.

Figure 2: Training setup

This taught the model to capture the essentials in the latent space.

In the next step we do something very explicit: we freeze the encoder altogether and train the decoder. We take input images, rotate them by \(\theta\) (which we encode into the rotation vector), and teach the decoder to perform the actual rotation.

Here "freeze" means the encoder parameters are not trained. We still run the forward pass in full, but optimization is done only on the decoder; the rotation vector is constructed directly from the sampled angle.

Figure 3: The second training phase.

In Python:

    model = TypedAE().to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # FREEZE ENCODER
    for param in model.encoder_backbone.parameters():
        param.requires_grad = False
    for param in model.head_latent.parameters():
        param.requires_grad = False
    for param in model.head_rot.parameters():
        param.requires_grad = False
    
    # Only decoder learns
    optimizer = optim.Adam(model.decoder.parameters(), lr=1e-3)
    
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, 
                                               download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            images = images.to(device)
            
            optimizer.zero_grad()
            
            with torch.no_grad():
                # Encode original image
                h = model.encoder_backbone(images.view(images.size(0), -1))
                z_latent = model.head_latent(h)
            
            # Random rotation for each image
            angles_rad = (torch.rand(images.size(0)) * 2 * np.pi).to(device)  # 0-360°
            z_rot = torch.stack([torch.cos(angles_rad), torch.sin(angles_rad)], dim=1)
            
            # Decode with given theta
            z_combined = torch.cat([z_rot, z_latent], dim=-1)
            recon = model.decoder(z_combined).view(images.size(0), 1, 28, 28)
            
            # Target: rotated image
            target = rotate_image_batch(images, angles_rad)
            
            # Loss: decoder must produce the rotated image
            loss = nn.functional.mse_loss(recon, target)
            
            loss.backward()
            optimizer.step()
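The training loop above calls `rotate_image_batch`, which isn't shown in the snippet. One possible implementation (an assumption on my part; the repo may do it differently) builds a per-image rotation matrix and uses `affine_grid`/`grid_sample`:

```python
import torch
import torch.nn.functional as F

def rotate_image_batch(images, angles_rad):
    """Rotate each image in a (B, 1, H, W) batch by its own angle (radians)."""
    B = images.size(0)
    cos, sin = torch.cos(angles_rad), torch.sin(angles_rad)
    # One 2x3 affine matrix per image: rotation about the image center.
    theta = torch.zeros(B, 2, 3, device=images.device)
    theta[:, 0, 0] = cos
    theta[:, 0, 1] = -sin
    theta[:, 1, 0] = sin
    theta[:, 1, 1] = cos
    grid = F.affine_grid(theta, list(images.size()), align_corners=False)
    return F.grid_sample(images, grid, align_corners=False)
```

Rotating with bilinear sampling introduces a little interpolation blur, which is one reason the reconstructions of some digits look soft.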

Next we need to ensure that the encoder and decoder "speak the same language". Now, for the first time, the encoder sees rotated images. Two things must happen: the content latent should not absorb the angle, since the angle has its own dedicated coordinates, and the rotation head must learn to predict the angle the image is rotated by.

The third phase aligns encoder and decoder.

Figure 4: The final training phase

And the training.

    # Freeze decoder
    for param in model.decoder.parameters():
        param.requires_grad = False
    
    # Release encoder (esp. rotation head)
    for param in model.encoder_backbone.parameters():
        param.requires_grad = True # Try false to see the effect
    for param in model.head_latent.parameters():
        param.requires_grad = True # Try false to see the effect
    for param in model.head_rot.parameters():
        param.requires_grad = True
    
    optimizer = optim.Adam([
        {'params': model.encoder_backbone.parameters()},
        {'params': model.head_latent.parameters()},
        {'params': model.head_rot.parameters()},
    ], lr=1e-3)
    
    transform = transforms.Compose([transforms.ToTensor()])
...    
    for epoch in range(epochs):
...        
        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
...            
            # Rotate with random angle
            angles_rad = (torch.rand(images.size(0)) * 2 * np.pi).to(device)
            images_rotated = rotate_image_batch(images, angles_rad)
            
            optimizer.zero_grad()
            
            # Encode rotated image
            h = model.encoder_backbone(images_rotated.view(images_rotated.size(0), -1))
            z_latent = model.head_latent(h)
            z_rot = model.head_rot(h)  # We train the angle!
            
            # Decode (frozen decoder)
            z_combined = torch.cat([z_rot, z_latent], dim=-1)
            recon = model.decoder(z_combined).view(images_rotated.size(0), 1, 28, 28)
            
            # Losses
            loss_recon = nn.functional.mse_loss(recon, images_rotated)
            
            # Rotation loss: z_rot must match the angle
            target_rot = torch.stack([torch.cos(angles_rad), torch.sin(angles_rad)], dim=1)
            loss_rot = nn.functional.mse_loss(z_rot, target_rot)
            
            loss = loss_recon + 1.0 * loss_rot
            
            loss.backward()
            optimizer.step()
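To check that the rotation head actually tracks the angle, a small wrap-aware error metric is handy (the helper name is my own, not from the repo):

```python
import torch

def angular_error(z_rot, angles_rad):
    """Mean absolute angular error in radians between predicted
    (cos, sin) unit vectors and ground-truth angles, wrap-aware."""
    pred = torch.atan2(z_rot[:, 1], z_rot[:, 0])
    diff = pred - angles_rad
    diff = torch.atan2(torch.sin(diff), torch.cos(diff))  # wrap to [-pi, pi]
    return diff.abs().mean()
```

A plain MSE on angles would wrongly penalize a prediction of \(2\pi - \epsilon\) against a target of \(0\); wrapping the difference avoids that.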
            

The Result

This rather simple model works surprisingly well. When we encode an image, we can control its angle via the rotation head and decode it at any angle we want. Run the code to see the results.
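At inference time, control amounts to overriding the rotation coordinates: keep the content code fixed and sweep \((\cos\theta,\sin\theta)\) around the circle. A sketch of building such a sweep (feeding each row to the decoder together with a fixed \(z\) works exactly as in the training code):

```python
import math
import torch

# 16 evenly spaced angles around the circle
angles = torch.linspace(0.0, 2 * math.pi, steps=16)
z_rot_sweep = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)  # (16, 2)
# Decoding torch.cat([z_rot_sweep[i], z_latent]) for each i yields the same
# digit at 16 different orientations.
```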

Some digits suffer from morphing or blurring.

Figure 5: Animation of rotation of digit 7
Figure 6: Animation of rotation of digit 3
Figure 7: Animation of rotation of digit 9

Conclusions

Our example is very simple; to really stress test the approach, the Fashion-MNIST dataset would be the way to go. Short tryouts with it didn't work out too well. Increasing the model size and the number of training epochs is most likely needed.

Still, we now have solid evidence that the latent space can be made controllable.