import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# you need to accept the LLAMA license to use this model
# feel free to try this notebook with another model
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
Introduction
- DeepSeek-V2 introduced Multi-Head Latent Attention (MLA), which uses a low-rank projection to compress the KV cache.
- The purpose of this notebook is to explore the effective rank of the projection weights in a pretrained model, using Llama-3-8B as an example.
Disclaimer: Most of the code was written by GPT/Copilot and is not optimized for presentation.
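Throughout the notebook, "effective rank" means the number of singular values needed to capture a given fraction of a weight matrix's energy (the sum of squared singular values). The cells below compute this with pandas; as a self-contained reference, here is a minimal sketch of the same idea (the helper name is mine and is not used later):

import torch

def effective_rank(weight: torch.Tensor, energy: float = 0.9) -> int:
    # Smallest k such that the top-k singular values carry an `energy`
    # fraction of the total squared singular value mass.
    s = torch.linalg.svdvals(weight.detach().float())
    cum = torch.cumsum(s**2, dim=0) / torch.sum(s**2)
    return int((cum >= energy).nonzero()[0].item()) + 1

# e.g. effective_rank(model.model.layers[0].self_attn.k_proj.weight, 0.9)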
layer = model.model.layers[0].self_attn
layer
LlamaSdpaAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)
def get_qkv_weights(layer):
    # this can be different depending on the model
    if model_name == 'microsoft/Phi-3-mini-128k-instruct':
        query_pos = layer.num_heads * layer.head_dim
        q_range = slice(0, query_pos)
        k_range = slice(query_pos, query_pos + layer.num_key_value_heads * layer.head_dim)
        v_range = slice(query_pos + layer.num_key_value_heads * layer.head_dim, None)
        q_proj_weight = layer.qkv_proj.weight[q_range]
        k_proj_weight = layer.qkv_proj.weight[k_range]
        v_proj_weight = layer.qkv_proj.weight[v_range]
        o_proj_weight = layer.o_proj.weight
    elif model_name == 'microsoft/phi-1_5':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.dense.weight
    elif model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.o_proj.weight
    return q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight
# let's test the function
q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)
q_proj_weight.shape, k_proj_weight.shape, v_proj_weight.shape, o_proj_weight.shape
(torch.Size([4096, 4096]),
torch.Size([1024, 4096]),
torch.Size([1024, 4096]),
torch.Size([4096, 4096]))
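The 4096 vs. 1024 output dimensions come straight from GQA: Llama-3-8B has 32 query heads but only 8 key/value heads, each of head dimension 128. A quick sanity check (assuming the standard LlamaConfig attribute names):

cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads   # 4096 // 32 = 128
print(cfg.num_attention_heads * head_dim)                # q_proj / o_proj rows: 4096
print(cfg.num_key_value_heads * head_dim)                # k_proj / v_proj rows: 1024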
model.config.num_hidden_layers
32
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd
import torch
from tqdm import tqdm

# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")

# Initialize a list to hold all data frames for different layers
all_data = []

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i].self_attn
    q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)

    q_proj_weight = q_proj_weight.detach().float().to(device)  # Move to the selected device
    k_proj_weight = k_proj_weight.detach().float().to(device)  # Move to the selected device
    v_proj_weight = v_proj_weight.detach().float().to(device)  # Move to the selected device
    o_proj_weight = o_proj_weight.detach().float().to(device)  # Move to the selected device

    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_q, _ = torch.linalg.svd(q_proj_weight, full_matrices=False)
        _, S_k, _ = torch.linalg.svd(k_proj_weight, full_matrices=False)
        _, S_v, _ = torch.linalg.svd(v_proj_weight, full_matrices=False)
        _, S_o, _ = torch.linalg.svd(o_proj_weight, full_matrices=False)

    # Move singular values back to CPU and convert to NumPy for plotting
    S_q = S_q.cpu().numpy()
    S_k = S_k.cpu().numpy()
    S_v = S_v.cpu().numpy()
    S_o = S_o.cpu().numpy()

    # Create a DataFrame for the singular values, including the 'Output' type
    data = pd.DataFrame({
        'Index': list(range(len(S_q))) + list(range(len(S_k))) + list(range(len(S_v))) + list(range(len(S_o))),
        'Singular Value': list(S_q) + list(S_k) + list(S_v) + list(S_o),
        'Type': ['Query'] * len(S_q) + ['Key'] * len(S_k) + ['Value'] * len(S_v) + ['Output'] * len(S_o),
        'Layer': [i] * (len(S_q) + len(S_k) + len(S_v) + len(S_o))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:51<00:00, 5.35s/it]
Let's visualize the singular values of the Q, K, V, and O projection weights across layers.
# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)

# Filter data to include only the first 4 and last 4 layers
num_layers = full_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]

# Filter data to reduce the number of points
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Singular Value:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Singular Values of Q, K, V, O Projection Weights for Selected Layers'
)

chart.display()
Now let's visualize the cumulative energy of the singular values across layers.
# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy

# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Update the Altair chart to show cumulative energy for the selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Cumulative Energy:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Cumulative Energy of Singular Values for Selected Layers'
)

chart.display()
Aside from the first few layers, the rest look pretty similar. Let's look at the number of components required to reach different energy thresholds, per layer and per projection type.
# Calculate the number of components needed to reach 60%, 80%, and 90% energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []

for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Query', 'Key', 'Value', 'Output']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1

            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })

threshold_df = pd.DataFrame(threshold_data)

# Plot the number of components needed to reach each energy threshold
chart = alt.Chart(threshold_df).mark_line().encode(
    x='Layer:O',
    y='Components Needed:Q',
    color='Type:N',
    column='Threshold:N',
    tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
    width=200,
    height=200,
    title='Number of Components Needed to Reach Energy Thresholds'
)

chart.display()
Observations:
- The raw output dimension of the K and V projection weights is 4x smaller than that of Q and O (1024 vs. 4096), due to GQA. But the difference in effective rank is not as large. Try other models that don't use GQA (e.g. Phi)!
- The effective rank of the projection weights in Llama-3-8B roughly matches the ranks chosen in DeepSeek-V2 (see the DeepSeek-V2 config file):
  "kv_lora_rank": 512, "q_lora_rank": 1536,
Next steps:
- I think the result suggests we can make the first few layers much smaller (see the low-rank factorization sketch below).
- Let's visualize the singular values of the FFN weights and see if we can find a similar pattern.
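As a sketch of what "making a layer smaller" could mean in practice, here is one hypothetical way to replace a projection with a rank-r factorization built from its truncated SVD (untested here; whether model quality survives such a replacement is exactly the open question):

import torch
import torch.nn as nn

def low_rank_linear(linear: nn.Linear, r: int) -> nn.Sequential:
    # Approximate y = W x with y = U_r (S_r V_r x), using the top-r
    # singular triplets of W.
    W = linear.weight.detach().float()                # (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    down = nn.Linear(W.shape[1], r, bias=False)       # maps the input into the rank-r space
    up = nn.Linear(r, W.shape[0], bias=linear.bias is not None)
    down.weight.data = torch.diag(S[:r]) @ Vh[:r]     # fold singular values into the down map
    up.weight.data = U[:, :r].contiguous()
    if linear.bias is not None:
        up.bias.data = linear.bias.detach().float()
    return nn.Sequential(down, up)

The factorized pair stores r * (in_features + out_features) weights instead of in_features * out_features, so for a 4096x4096 projection it only saves parameters when r < 2048.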
def get_ffn_weights(layer):
    # this can be different depending on the model
    if model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        up_proj_weight = layer.mlp.up_proj.weight
        down_proj_weight = layer.mlp.down_proj.weight
    return up_proj_weight, down_proj_weight

# let's test the function
layer = model.model.layers[0]
up_proj_weight, down_proj_weight = get_ffn_weights(layer)
up_proj_weight.shape, down_proj_weight.shape
(torch.Size([14336, 4096]), torch.Size([4096, 14336]))
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd
import torch
from tqdm import tqdm

# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")

# Initialize a list to hold all data frames for different layers
all_data = []

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i]
    up_proj_weight, down_proj_weight = get_ffn_weights(layer)

    up_proj_weight = up_proj_weight.detach().float().to(device)      # Move to the selected device
    down_proj_weight = down_proj_weight.detach().float().to(device)  # Move to the selected device

    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_up, _ = torch.linalg.svd(up_proj_weight, full_matrices=False)
        _, S_down, _ = torch.linalg.svd(down_proj_weight, full_matrices=False)

    # Move singular values back to CPU and convert to NumPy for plotting
    S_up = S_up.cpu().numpy()
    S_down = S_down.cpu().numpy()

    # Create a DataFrame for the Up and Down singular values
    data = pd.DataFrame({
        'Index': list(range(len(S_up))) + list(range(len(S_down))),
        'Singular Value': list(S_up) + list(S_down),
        'Type': ['Up'] * len(S_up) + ['Down'] * len(S_down),
        'Layer': [i] * (len(S_up) + len(S_down))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:48<00:00, 5.26s/it]
# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)

# Filter data to include only the first 4 and last 4 layers
num_layers = model.config.num_hidden_layers
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]

# Filter data to reduce the number of points
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Singular Value:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Singular Values of Up and Down Projection Weights for Selected Layers'
)

chart.display()
# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy

# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
    (filtered_data['Index'] <= 100) |
    ((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
    ((filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0))
]

# Update the Altair chart to show cumulative energy for the selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
    x='Index:Q',
    y='Cumulative Energy:Q',
    color='Type:N',
    facet=alt.Facet('Layer:N', columns=4),
    tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
    width=180,
    height=180,
    title='Cumulative Energy of Singular Values for Selected Layers'
)

chart.display()
# Calculate the number of components needed to reach 60%, 80%, and 90% energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []

for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Up', 'Down']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1

            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })

threshold_df = pd.DataFrame(threshold_data)

# Plot the number of components needed to reach each energy threshold
chart = alt.Chart(threshold_df).mark_line().encode(
    x='Layer:O',
    y='Components Needed:Q',
    color='Type:N',
    column='Threshold:N',
    tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
    width=200,
    height=200,
    title='Number of Components Needed to Reach Energy Thresholds'
)

chart.display()
Observations & Discussions
- The rank is primarily determined by the model's hidden dimension, which is 4096 for Llama-3-8B. The intermediate dimension of 14336 is much larger, but a 14336x4096 matrix can have rank at most 4096 anyway.
- Llama-3-8B is heavily over-trained on 15T tokens. Would other models trained on fewer tokens have a smaller rank in some layers?
- The up_proj weights of the last few layers have a noticeably larger first singular value. What's going on there? (I don't know.)
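One quick way to double-check that last point, assuming all_data still holds the FFN results from the cells above (within each layer and type the singular values are stored in descending order, so the per-layer max is the first one):

# Largest (i.e. first) singular value of the Up projection in each layer
ffn_full = pd.concat(all_data, ignore_index=True)
top_up = ffn_full[ffn_full['Type'] == 'Up'].groupby('Layer')['Singular Value'].max()
print(top_up.tail(8))   # compare against top_up.head(8)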