import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# you need to accept the LLAMA license to use this model
# feel free to try this notebook with another model
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
Introduction
- DeepSeek-V2 introduced Multi-Head Latent Attention (MLA), which uses low-rank compression to shrink the KV cache.
- The purpose of this notebook is to explore the effective rank of the projection weights in a pretrained model, using Llama-3-8B as an example.
Disclaimer: Most of the code was written by GPT/Copilot and is not optimized for presentation.
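Throughout this notebook, "effective rank" is estimated from the singular values of a weight matrix: square them, accumulate, normalize, and count how many components are needed to reach an energy threshold. Here is a minimal sketch of that idea (the helper name and the 0.9 default are mine; the cells below do the same thing with pandas):
import torch

def components_for_energy(weight: torch.Tensor, threshold: float = 0.9) -> int:
    # Singular values of the weight matrix (no gradients needed)
    s = torch.linalg.svdvals(weight.detach().float())
    # Cumulative "energy": normalized cumulative sum of squared singular values
    energy = torch.cumsum(s**2, dim=0) / torch.sum(s**2)
    # Number of leading components needed to reach the threshold
    return int(torch.searchsorted(energy, threshold).item()) + 1

# e.g. components_for_energy(model.model.layers[0].self_attn.k_proj.weight, 0.9)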
layer = model.model.layers[0].self_attn
layer
LlamaSdpaAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)
def get_qkv_weights(layer):
    # this can be different depending on the model
    if model_name == 'microsoft/Phi-3-mini-128k-instruct':
        query_pos = layer.num_heads * layer.head_dim
        q_range = slice(0, query_pos)
        k_range = slice(query_pos, query_pos + layer.num_key_value_heads * layer.head_dim)
        v_range = slice(query_pos + layer.num_key_value_heads * layer.head_dim, None)
        q_proj_weight = layer.qkv_proj.weight[q_range]
        k_proj_weight = layer.qkv_proj.weight[k_range]
        v_proj_weight = layer.qkv_proj.weight[v_range]
        o_proj_weight = layer.o_proj.weight
    elif model_name == 'microsoft/phi-1_5':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.dense.weight
    elif model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        q_proj_weight = layer.q_proj.weight
        k_proj_weight = layer.k_proj.weight
        v_proj_weight = layer.v_proj.weight
        o_proj_weight = layer.o_proj.weight
    return q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight
# let's test the function
q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)
q_proj_weight.shape, k_proj_weight.shape, v_proj_weight.shape, o_proj_weight.shape
(torch.Size([4096, 4096]),
 torch.Size([1024, 4096]),
 torch.Size([1024, 4096]),
 torch.Size([4096, 4096]))
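The K and V projections are narrower than Q and O because Llama-3 uses grouped-query attention (GQA): 8 key/value heads versus 32 query heads, each of head dimension 128. A quick sanity check against the config (an optional aside, not used later):
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads   # 4096 // 32 = 128
print(cfg.num_attention_heads, cfg.num_key_value_heads)  # 32, 8
print(cfg.num_key_value_heads * head_dim)                # 8 * 128 = 1024 = K/V out_features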
model.config.num_hidden_layers
32
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd
# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")
# Initialize a list to hold all data frames for different layers
all_data = []
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import torch
from tqdm import tqdm
# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i].self_attn
    q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = get_qkv_weights(layer)
    q_proj_weight = q_proj_weight.detach().float().to(device)  # Move to GPU
    k_proj_weight = k_proj_weight.detach().float().to(device)  # Move to GPU
    v_proj_weight = v_proj_weight.detach().float().to(device)  # Move to GPU
    o_proj_weight = o_proj_weight.detach().float().to(device)  # Move to GPU
    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_q, _ = torch.linalg.svd(q_proj_weight, full_matrices=False)
        _, S_k, _ = torch.linalg.svd(k_proj_weight, full_matrices=False)
        _, S_v, _ = torch.linalg.svd(v_proj_weight, full_matrices=False)
        _, S_o, _ = torch.linalg.svd(o_proj_weight, full_matrices=False)
    # Move singular values back to CPU and convert to NumPy for plotting
    S_q = S_q.cpu().numpy()
    S_k = S_k.cpu().numpy()
    S_v = S_v.cpu().numpy()
    S_o = S_o.cpu().numpy()
    # Create a DataFrame for the singular values including 'Output' type
    data = pd.DataFrame({
        'Index': list(range(len(S_q))) + list(range(len(S_k))) + list(range(len(S_v))) + list(range(len(S_o))),
        'Singular Value': list(S_q) + list(S_k) + list(S_v) + list(S_o),
        'Type': ['Query'] * len(S_q) + ['Key'] * len(S_k) + ['Value'] * len(S_v) + ['Output'] * len(S_o),
        'Layer': [i] * (len(S_q) + len(S_k) + len(S_v) + len(S_o))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:51<00:00, 5.35s/it]
Let’s visualize the singular values of Q, K, V, and O projection weights across different layers
# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)
# Filter data to include only the first 4 and last 4 layers
num_layers = full_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]
# Filter data to reduce the number of points
filtered_data = filtered_data[
(filtered_data['Index'] <= 100) |
((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
(filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0)
]
# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
x='Index:Q',
y='Singular Value:Q',
color='Type:N',
facet=alt.Facet('Layer:N', columns=4),
tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
width=180,
height=180,
title='Singular Values of Q, K, V, O Projection Weights for Selected Layers'
)
chart.display()
Now visualize the cumulative energy of singular values across layers.
# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy
# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
(filtered_data['Index'] <= 100) |
((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
(filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0)
]
# Update the Altair chart to include cumulative energy for selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
x='Index:Q',
y='Cumulative Energy:Q',
color='Type:N',
facet=alt.Facet('Layer:N', columns=4),
tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
width=180,
height=180,
title='Cumulative Energy of Singular Values for Selected Layers'
)
chart.display()
Aside from the first few layers, the rest look pretty similar. Let’s see how many components are required to reach different energy thresholds across layers and weight types.
# Calculate the number of components needed to reach 60%, 80%, and 90% energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []
for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Query', 'Key', 'Value', 'Output']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1
            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })
threshold_df = pd.DataFrame(threshold_data)
# Plotting the number of components needed to reach energy thresholds
chart = alt.Chart(threshold_df).mark_line().encode(
x='Layer:O',
y='Components Needed:Q',
color='Type:N',
column='Threshold:N',
tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
width=200,
height=200,
title='Number of Components Needed to Reach Energy Thresholds'
)
chart.display()
Observations:
- The raw output dimension of the K and V projection weights is 4x smaller than that of Q and O (1024 vs. 4096) due to GQA, but the difference in effective rank is not as large. Try other models that don’t use GQA (e.g. Phi)!
- The effective rank of the projection weights in Llama-3-8B roughly matches the ranks chosen in DeepSeek-V2. From the DeepSeek-V2 config file:
"kv_lora_rank": 512, "q_lora_rank": 1536,
Next steps:
- I think the result suggests we can make the first few layers much smaller (a rough parameter count follows below).
- Let’s try visualizing the singular values of the FFN weights and see if we can find a similar pattern.
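Before moving on to the FFN weights, a back-of-the-envelope count for the first item above (a sketch only; r = 512 is an illustrative rank, not a value read off the charts):
# Replacing a full 4096x4096 q_proj with a rank-r factorization:
#   full:     out * in       parameters
#   low-rank: r * (out + in) parameters
out_f, in_f, r = 4096, 4096, 512
full_params = out_f * in_f            # 16,777,216
low_rank_params = r * (out_f + in_f)  # 4,194,304
print(low_rank_params / full_params)  # 0.25 -> ~4x fewer parameters for that projection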
def get_ffn_weights(layer):
    # this can be different depending on the model
    if model_name == 'meta-llama/Meta-Llama-3-8B-Instruct':
        up_proj_weight = layer.mlp.up_proj.weight
        down_proj_weight = layer.mlp.down_proj.weight
    return up_proj_weight, down_proj_weight
# let's test the function
layer = model.model.layers[0]
up_proj_weight, down_proj_weight = get_ffn_weights(layer)
up_proj_weight.shape, down_proj_weight.shape
(torch.Size([14336, 4096]), torch.Size([4096, 14336]))
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
import pandas as pd
# Enable Altair data transformer
alt.data_transformers.enable("vegafusion")
# Initialize a list to hold all data frames for different layers
all_data = []
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import torch
from tqdm import tqdm
# Loop over the desired layers
for i in tqdm(range(0, model.config.num_hidden_layers)):
    layer = model.model.layers[i]
    up_proj_weight, down_proj_weight = get_ffn_weights(layer)
    up_proj_weight = up_proj_weight.detach().float().to(device)  # Move to GPU
    down_proj_weight = down_proj_weight.detach().float().to(device)  # Move to GPU
    # Compute SVD without computing gradients
    with torch.no_grad():
        _, S_up, _ = torch.linalg.svd(up_proj_weight, full_matrices=False)
        _, S_down, _ = torch.linalg.svd(down_proj_weight, full_matrices=False)
    # Move singular values back to CPU and convert to NumPy for plotting
    S_up = S_up.cpu().numpy()
    S_down = S_down.cpu().numpy()
    # Create a DataFrame for the singular values of the Up and Down projections
    data = pd.DataFrame({
        'Index': list(range(len(S_up))) + list(range(len(S_down))),
        'Singular Value': list(S_up) + list(S_down),
        'Type': ['Up'] * len(S_up) + ['Down'] * len(S_down),
        'Layer': [i] * (len(S_up) + len(S_down))
    })
    all_data.append(data)
100%|██████████| 32/32 [02:48<00:00, 5.26s/it]
# Concatenate all data frames
full_data = pd.concat(all_data, ignore_index=True)
# Filter data to include only the first 4 and last 4 layers
num_layers = model.config.num_hidden_layers
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = full_data[full_data['Layer'].isin(selected_layers)]
# Filter data to reduce the number of points
filtered_data = filtered_data[
(filtered_data['Index'] <= 100) |
((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
(filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0)
]
# Create the Altair chart
chart = alt.Chart(filtered_data).mark_line().encode(
x='Index:Q',
y='Singular Value:Q',
color='Type:N',
facet=alt.Facet('Layer:N', columns=4),
tooltip=['Index', 'Singular Value', 'Type', 'Layer']
).properties(
width=180,
height=180,
title='Singular Values of Up and Down Projection Weights for Selected Layers'
)
chart.display()
# Calculate cumulative energy for each layer and type
for data in all_data:
    data['Squared Singular Value'] = data['Singular Value'].map(lambda x: x**2)
    total_energy = data.groupby('Type')['Squared Singular Value'].transform('sum')
    cumulative_energy = data.groupby('Type')['Squared Singular Value'].cumsum() / total_energy
    data['Cumulative Energy'] = cumulative_energy
# Concatenate all data frames and filter for the first 4 and last 4 layers
concatenated_data = pd.concat(all_data, ignore_index=True)
num_layers = concatenated_data['Layer'].max() + 1
selected_layers = list(range(4)) + list(range(num_layers - 4, num_layers))
filtered_data = concatenated_data[concatenated_data['Layer'].isin(selected_layers)]
filtered_data = filtered_data[
(filtered_data['Index'] <= 100) |
((filtered_data['Index'] > 100) & (filtered_data['Index'] <= 1000) & (filtered_data['Index'] % 5 == 0)) |
(filtered_data['Index'] > 1000) & (filtered_data['Index'] % 10 == 0)
]
# Update the Altair chart to include cumulative energy for selected layers
chart = alt.Chart(filtered_data).mark_line().encode(
x='Index:Q',
y='Cumulative Energy:Q',
color='Type:N',
facet=alt.Facet('Layer:N', columns=4),
tooltip=['Index', 'Cumulative Energy', 'Type', 'Layer']
).properties(
width=180,
height=180,
title='Cumulative Energy of Singular Values for Selected Layers'
)
chart.display()
# Calculate the number of components needed to reach 60%, 80%, and 90% energy
energy_thresholds = [0.6, 0.8, 0.9]
threshold_data = []
for data in all_data:
    for threshold in energy_thresholds:
        for type_ in ['Up', 'Down']:
            type_data = data[data['Type'] == type_].reset_index(drop=True)
            cumulative_energy = type_data['Cumulative Energy']
            # Find the minimum index where cumulative energy exceeds the threshold
            num_components = (cumulative_energy >= threshold).idxmax() + 1
            threshold_data.append({
                'Layer': data['Layer'].iloc[0],
                'Type': type_,
                'Threshold': threshold,
                'Components Needed': num_components
            })
threshold_df = pd.DataFrame(threshold_data)
# Plotting the number of components needed to reach energy thresholds
chart = alt.Chart(threshold_df).mark_line().encode(
x='Layer:O',
y='Components Needed:Q',
color='Type:N',
column='Threshold:N',
tooltip=['Layer', 'Type', 'Components Needed', 'Threshold']
).properties(
width=200,
height=200,
title='Number of Components Needed to Reach Energy Thresholds'
)
chart.display()
Observations & Discussions
- The rank of both FFN weights is capped by the model’s hidden dimension, which is 4096 for Llama-3-8B; the intermediate dimension (14336) is much larger.
- Llama-3-8B is heavily over-trained on 15T tokens. Would other models trained on fewer tokens have a smaller effective rank in some layers?
- The last few layers’ up_proj weights have a noticeably larger first singular value; what’s going on there? (I don’t know.)
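One quick way to poke at that last question, reusing the FFN all_data frames built above (an extra check, not explored further here):
# Leading (Index == 0) singular value of the up projection, per layer
up_first = (
    pd.concat(all_data, ignore_index=True)
      .query("Type == 'Up' and Index == 0")
      .sort_values('Layer')[['Layer', 'Singular Value']]
)
print(up_first.to_string(index=False))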