Exploring the Effective Rank of Projection Weights in Attention

Author

Jonathan Chang

Published

May 13, 2024

Introduction

  1. DeepSeek-V2 introduced Multi-Head Latent Attention (MLA), which uses a low-rank projection to compress the KV cache.
  2. The purpose of this notebook is to explore the effective rank of the projection weights in a pretrained model, using Llama-3-8B as an example.
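For context, the low-rank trick is just a matrix factorization: a d_out × d_in projection W gets approximated by two thin matrices, so the activation passes through a narrow rank-r bottleneck. Below is a minimal toy sketch of that idea (a random matrix stands in for a trained weight and the sizes are made up; this is not DeepSeek’s actual MLA implementation):

import torch

d, r = 512, 64                                       # toy sizes, not Llama's
W = torch.randn(d, d) / d ** 0.5                     # stand-in for a pretrained projection weight
U, S, Vh = torch.linalg.svd(W, full_matrices=False)  # singular values come back sorted descending
W_down = torch.diag(S[:r]) @ Vh[:r]                  # d -> r: the "compression" half
W_up = U[:, :r]                                      # r -> d: the "decompression" half
x = torch.randn(1, d)
rel_err = ((x @ W.T - (x @ W_down.T) @ W_up.T).norm() / (x @ W.T).norm()).item()
print(f"relative error at rank {r}: {rel_err:.2f}")  # large for a random matrix, whose spectrum is flat

The question the rest of the notebook asks is how much better such a truncation can do on real, trained projection weights.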

Disclaimer: Most of the code was written by GPT/Copilot and is not optimized for presentation.

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
# you need to accept the LLAMA license to use this model
# feel free to try this notebook with another model
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
layer = model.model.layers[0].self_attn
layer
LlamaSdpaAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)
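The 1024-dimensional k_proj/v_proj outputs come from grouped-query attention (GQA): Llama-3-8B has 4x fewer key/value heads than query heads. As a quick sanity check, the shapes above can be recovered from the standard Llama config fields:

head_dim = model.config.hidden_size // model.config.num_attention_heads
(model.config.num_attention_heads * head_dim,   # q_proj / o_proj out_features: 4096
 model.config.num_key_value_heads * head_dim)   # k_proj / v_proj out_features: 1024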
def get_qkv_weights(layer):
    # this can be different depending on the model
    if model_name == 'microsoft/Phi-3-mini-128k-instruct':
        query_pos = layer.num_heads * layer.head_dim
        q_range = slice(0, query_pos)
        k_range = slice(query_pos, query_pos + layer.num_key_value_heads * layer.head_dim)
        v_range = slice(query_pos + layer.num_key_value_heads * layer.head_dim, None)
        q_proj_weight = layer.qkv_proj.weight[q_range]
        k_proj_weight = layer.qkv_proj.weight[k_range]
        v_proj_weight = layer.qkv_proj.weight[v_range]
        o_proj_weight = layer.o_proj.weight
    elif model_name == 'microsoft/phi-1_5':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.dense.weight
    elif model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.o_proj.weight
    else:
        raise ValueError(f"add a weight-extraction branch for {model_name}")
    return q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight

# let's test the function
q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)
q_proj_weight.shape, k_proj_weight.shape, v_proj_weight.shape, o_proj_weight.shape
(torch.Size([4096, 4096]),
 torch.Size([1024, 4096]),
 torch.Size([1024, 4096]),
 torch.Size([4096, 4096]))
model.config.num_hidden_layers
32
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd

# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")

# Initialize a list to hold all data frames for different layers
all_data = []

from tqdm import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i].self_attn
    q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)

    q_proj_weight = q_proj_weight.detach().float().to(device)  # Move to GPU
    k_proj_weight = k_proj_weight.detach().float().to(device)  # Move to GPU
    v_proj_weight = v_proj_weight.detach().float().to(device)  # Move to GPU
    o_proj_weight = o_proj_weight.detach().float().to(device)  # Move to GPU

    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_q, _ = torch.linalg.svd(q_proj_weight, full_matrices=False)
        _, S_k, _ = torch.linalg.svd(k_proj_weight, full_matrices=False)
        _, S_v, _ = torch.linalg.svd(v_proj_weight, full_matrices=False)
        _, S_o, _ = torch.linalg.svd(o_proj_weight, full_matrices=False)

    # Move singular values back to CPU and convert to NumPy for plotting
    S_q = S_q.cpu().numpy()
    S_k = S_k.cpu().numpy()
    S_v = S_v.cpu().numpy()
    S_o = S_o.cpu().numpy()
    # Create a DataFrame for the singular values including 'Output' type
    data = pd.DataFrame({
        'Index': list(range(len(S_q))) + list(range(len(S_k))) + list(range(len(S_v))) + list(range(len(S_o))),
        'Singular Value': list(S_q) + list(S_k) + list(S_v) + list(S_o),
        'Type': ['Query'] * len(S_q) + ['Key'] * len(S_k) + ['Value'] * len(S_v) + ['Output'] * len(S_o),
        'Layer': [i] * (len(S_q) + len(S_k) + len(S_v) + len(S_o))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:51<00:00,  5.35s/it]
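A side note on runtime: only the singular values are used, so torch.linalg.svdvals, which skips materializing U and V, would be a lighter drop-in replacement here, e.g.:

# same singular values as torch.linalg.svd, without the unused U and V factors
with torch.no_grad():
    S_q = torch.linalg.svdvals(q_proj_weight)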

Let’s visualize the singular values of the Q, K, V, and O projection weights across layers.

# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)

# Filter data to include only the first 4 and last 4 layers
num_layers = full_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]

# Filter data to reduce the number of points
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Singular Value:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Singular Values of Q, K, V, O Projection Weights for Selected Layers'
)

chart.display()

Now let’s visualize the cumulative energy of the singular values across layers, i.e. the fraction of the total squared singular-value mass captured by the top-k components.

# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy

# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Update the Altair chart to include cumulative energy for selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Cumulative Energy:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Cumulative Energy of Singular Values for Selected Layers'
)

chart.display()

Aside from the first few layers, the rest look pretty similar. Let’s see the number of components required to reach different energy thresholds across different layers and types.

# Calculate the number of components needed to reach 60%, 80%, and 90% of the energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []

for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Query', 'Key', 'Value', 'Output']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1
            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })

threshold_df = pd.DataFrame(threshold_data)

# Plotting the number of components needed to reach energy thresholds
chart = alt.Chart(threshold_df).mark_line().encode(
    x='Layer:O',
    y='Components Needed:Q',
    color='Type:N',
    column='Threshold:N',
    tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
    width=200,
    height=200,
    title='Number of Components Needed to Reach Energy Thresholds'
)

chart.display()

Observations:

  • The raw output dimension of the K and V projection weights is 4x smaller than that of Q and O (1024 vs. 4096), due to GQA, but the difference in effective rank is not as large. Try other models that don’t use GQA! (e.g. Phi)

  • The effective rank of the projection weights in Llama-3-8B roughly matches the ranks chosen in DeepSeek-V2 (values from the DeepSeek-V2 config file; a rough numerical comparison is sketched after this list):

    "kv_lora_rank": 512,
    "q_lora_rank": 1536,

Next steps:

  1. I think the results suggest the first few layers could be made much smaller.

  2. Let’s try visualizing the singular values of the FFN weights and see if we can find a similar pattern.

def get_ffn_weights(layer):
    # this can be different depending on the model
    if model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        up_proj_weight = layer.mlp.up_proj.weight
        down_proj_weight = layer.mlp.down_proj.weight
        return up_proj_weight, down_proj_weight
    raise ValueError(f"add a weight-extraction branch for {model_name}")

# let's test the function
layer = model.model.layers[0]
up_proj_weight, down_proj_weight = get_ffn_weights(layer)
up_proj_weight.shape, down_proj_weight.shape
(torch.Size([14336, 4096]), torch.Size([4096, 14336]))
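These shapes follow directly from the config (hidden_size and intermediate_size): up_proj is intermediate_size × hidden_size and down_proj is the transpose shape.

model.config.hidden_size, model.config.intermediate_size  # expect (4096, 14336), matching the shapes above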
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd

# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")

# Initialize a list to hold all data frames for different layers
all_data = []

from tqdm import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i]
    up_proj_weight, down_proj_weight = get_ffn_weights(layer)

    up_proj_weight = up_proj_weight.detach().float().to(device)  # Move to GPU
    down_proj_weight = down_proj_weight.detach().float().to(device)  # Move to GPU

    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_up, _ = torch.linalg.svd(up_proj_weight, full_matrices=False)
        _, S_down, _ = torch.linalg.svd(down_proj_weight, full_matrices=False)

    # Move singular values back to CPU and convert to NumPy for plotting
    S_up = S_up.cpu().numpy()
    S_down = S_down.cpu().numpy()

    # Create a DataFrame for the up and down projection singular values
    data = pd.DataFrame({
        'Index': list(range(len(S_up))) + list(range(len(S_down))),
        'Singular Value': list(S_up) + list(S_down),
        'Type': ['Up'] * len(S_up) + ['Down'] * len(S_down),
        'Layer': [i] * (len(S_up) + len(S_down))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:48<00:00,  5.26s/it]
# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)

# Filter data to include only the first 4 and last 4 layers
num_layers = model.config.num_hidden_layers
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]
# Filter data to reduce the number of points
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Singular Value:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Singular Values of Up and Down Projection Weights for Selected Layers'
)

chart.display()
# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy

# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Update the Altair chart to include cumulative energy for selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Cumulative Energy:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Cumulative Energy of Singular Values for Selected Layers'
)

chart.display()
# Calculate the number of components needed to reach 60%, 80%, and 90% of the energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []

for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Up', 'Down']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1
            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })

threshold_df = pd.DataFrame(threshold_data)

# Plotting the number of components needed to reach energy thresholds
chart = alt.Chart(threshold_df).mark_line().encode(
    x='Layer:O',
    y='Components Needed:Q',
    color='Type:N',
    column='Threshold:N',
    tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
    width=200,
    height=200,
    title='Number of Components Needed to Reach Energy Thresholds'
)

chart.display()

Observations & Discussion

  • The maximum possible rank here is set by the model dimension, which is 4096 for Llama-3-8B; the intermediate dimension (14336) is much larger.
    • Llama-3-8B is heavily over-trained on 15T tokens. Would other models trained on fewer tokens show a smaller effective rank in some layers?
  • The up_proj weights of the last few layers have a noticeably larger first singular value; what’s going on there? (I don’t know; a quick check is sketched below.)
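A quick way to eyeball that last point from the data already collected (full_data here is the FFN one built above, so Index 0 is the leading singular value):

# leading singular value of up_proj in each layer
top_up = (full_data[(full_data['Type'] == 'Up') & (full_data['Index'] == 0)]
          .set_index('Layer')['Singular Value'])
print(top_up.tail(8))  # the last few layers should stand out if the observation holds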