13 (Appendix). Parameter-Efficient Fine-Tuning beyond LoRA#
⚡Compute Note: I recommend running this notebook on a node with 1x H200 GPU.
This appendix surveys several extensions to Low-Rank Adaptation (LoRA) that have recently appeared in the parameter-efficient fine-tuning (PEFT) literature. We focus on methods that are implemented in 🤗 PEFT so that they are easy to experiment with. The goal is to understand why each variant was introduced and to provide a lightweight, reproducible comparison of their trainable parameter counts on a small language model.
Recap: What LoRA optimises#
LoRA proposes to freeze the pretrained weights \(W_0\) of a linear layer and to learn a low-rank update \(\Delta W = BA\) with \(A \in \mathbb{R}^{r \times d_{\text{in}}}\) and \(B \in \mathbb{R}^{d_{\text{out}} \times r}\). The forward pass becomes

\[ h = W_0 x + \frac{\alpha}{r} B A x, \]
where \(r\) is the rank hyper-parameter and \(\alpha\) rescales the update. Optimising only \(A\) and \(B\) drastically reduces the number of trainable parameters compared to the dense matrix \(W_0\).
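To make the bookkeeping concrete, here is a minimal sketch of a LoRA-augmented linear layer written directly from the formula above (an illustration, not the PEFT implementation; the zero initialisation of \(B\) makes the adapter a no-op at the start of training):

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Minimal sketch: a frozen dense layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)                           # freeze W_0
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.A = nn.Parameter(0.01 * torch.randn(r, base.in_features))   # A: r x d_in
        self.B = nn.Parameter(torch.zeros(base.out_features, r))         # B: d_out x r, zero init => Delta W = 0
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W_0 x + (alpha / r) * B A x
        return self.base(x) + self.scaling * (x @ self.A.T @ self.B.T)


layer = LoRALinear(nn.Linear(512, 512), r=8, alpha=16.0)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))     # 2 * 512 * 8 = 8192
```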
LoRA extensions covered here#
Below is a concise overview of the LoRA-inspired variants explored in this appendix. Each approach introduces an additional inductive bias on top of the low-rank structure.
| Method | Main idea | Extra knobs | Intuition |
|---|---|---|---|
| AdaLoRA (Adaptive LoRA) | Dynamically reallocates the rank budget during training | \(r_t\) schedule (\(t_{\text{init}}, t_{\text{final}}\)), importance metrics | Concentrate capacity on the most useful layers. |
| LoHa (Low-Rank Hadamard Adaptation) | Builds the update as a Hadamard product of two low-rank factor pairs | Rank and scale per factor pair | Reaches an effective rank up to \(r^2\) for roughly twice the parameters of LoRA. |
| DoRA (Weight-Decomposed LoRA) | Separates weight direction and magnitude | use_dora=True on a LoRA config | Learns a per-column magnitude rescaler in addition to low-rank direction updates. |
| HRA (Householder Reflection Adaptation) | Uses orthogonal Householder reflections | # reflections \(m\) | Enforces orthogonal updates that preserve norms. |
| LoRA+ / other tweaks | Learning-rate ratios, rank-stabilised scaling, adapter dropout | Configuration flags (e.g. use_rslora) | Often complementary to the above and compatible with the PEFT APIs. |
1. AdaLoRA#
AdaLoRA augments LoRA with adaptive rank allocation. Each adapter is kept in an SVD-like parameterisation, and during training the least important singular values are pruned so that a global rank budget is concentrated on the layers and matrices that matter most. Importance is estimated from smoothed gradient-weight sensitivity statistics. The update for a weight matrix \(W_0\) becomes

\[ W = W_0 + P \Lambda Q, \qquad \Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_{r_t}), \]

where \(P\) and \(Q\) are encouraged to stay orthogonal by a regulariser and the effective rank \(r_t\) evolves during training. The PEFT implementation automates this via a per-layer scheduler that redistributes the global rank budget between user-defined milestones \(t_{\text{init}}\) and \(t_{\text{final}}\).
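The toy sketch below illustrates the idea on a single layer: the update is kept in SVD-like form and a sensitivity score decides which singular directions survive the budget. The variable names and the score are illustrative only; PEFT's AdaLoRA tuner performs the smoothing and scheduling internally.

```python
import torch

# Toy sketch of AdaLoRA's SVD-style update Delta W = P @ diag(E) @ Q for one layer.
d_out, d_in, r_init = 512, 512, 12
P = torch.randn(d_out, r_init, requires_grad=True)
E = torch.randn(r_init, requires_grad=True)      # one "singular value" per rank
Q = torch.randn(r_init, d_in, requires_grad=True)

delta_w = P @ torch.diag(E) @ Q

# A toy sensitivity score |theta * grad| per singular value; AdaLoRA smooths a
# score like this over steps and prunes the lowest-scoring ranks on a schedule.
loss = delta_w.pow(2).mean()
loss.backward()
importance = (E.detach() * E.grad).abs()
survivors = importance.topk(k=6).indices         # keep a budget of 6 ranks
print(sorted(survivors.tolist()))
```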
2. LoHa#
LoHa (Low-Rank Hadamard Adaptation) replaces the single low-rank product \(BA\) with the element-wise (Hadamard) product of two low-rank factorisations:

\[ \Delta W = \big(B^{(1)} A^{(1)}\big) \odot \big(B^{(2)} A^{(2)}\big), \]

which can reach an effective rank of up to \(r^2\) while training roughly twice as many parameters as LoRA at the same \(r\). In practice LoHa drops into the same API as LoRA and exposes its own rank and scaling knobs for the two factor pairs.
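A short sketch of the update with illustrative shapes (the actual LoHa layers additionally carry their own scaling and dropout):

```python
import torch

# LoHa-style update: Hadamard product of two independent low-rank products.
d_out, d_in, r = 512, 512, 8
B1, A1 = torch.randn(d_out, r), torch.randn(r, d_in)
B2, A2 = torch.randn(d_out, r), torch.randn(r, d_in)

delta_w = (B1 @ A1) * (B2 @ A2)           # element-wise product of two rank-r matrices
print(delta_w.shape)                      # torch.Size([512, 512])
print(torch.linalg.matrix_rank(delta_w))  # typically r**2 = 64, i.e. well above r
```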
3. DoRA#
DoRA (Weight-Decomposed Low-Rank Adaptation) splits each pretrained weight into a direction and a magnitude, where the magnitude \(m = \lVert W_0 \rVert_c\) is the vector of column-wise norms. The direction is adapted with low-rank matrices \((A, B)\) exactly as in LoRA, while \(m\) is trained directly as a small extra parameter vector:

\[ W' = m \odot \frac{W_0 + BA}{\lVert W_0 + BA \rVert_c}. \]
This decoupling improves stability when the base model contains layers with vastly different scales.
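A minimal sketch of the reparameterisation for a single frozen matrix, assuming the column-wise decomposition above (the comparison code later in this appendix enables the same behaviour through use_dora=True on a LoraConfig):

```python
import torch

# DoRA-style reparameterisation for one frozen weight matrix W0 (d_out x d_in).
d_out, d_in, r = 512, 512, 8
W0 = torch.randn(d_out, d_in)

m = W0.norm(dim=0, keepdim=True).clone().requires_grad_(True)  # trainable magnitude, one entry per column
A = (0.01 * torch.randn(r, d_in)).requires_grad_(True)
B = torch.zeros(d_out, r, requires_grad=True)                  # zero init => direction unchanged at start

V = W0 + B @ A                                    # low-rank update adapts the direction only
W_adapted = m * (V / V.norm(dim=0, keepdim=True)) # initially equal to W0, since m matches W0's column norms
print(W_adapted.shape)                            # torch.Size([512, 512])
```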
4. Householder Reflection Adaptation (HRA)#
HRA constrains the update to a product of Householder reflections, each of which is an orthogonal matrix \(H_i = I - 2 \frac{v_i v_i^\top}{\lVert v_i \rVert^2}\). Composing \(m\) reflections yields an orthogonal transform \(Q = \prod_{i=1}^{m} H_i\) that preserves norms. HRA parameterises \(Q\) through the trainable vectors \(v_i\) and applies it to the frozen weight, leading to

\[ W' = W_0 \prod_{i=1}^{m} H_i . \]
By construction, HRA focuses on rotation-like adaptations, which can be beneficial when invariance to norm changes is desirable.
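The sketch below composes the reflections explicitly; the \(v_i\) vectors are the only trainable tensors, so the parameter cost is \(m \cdot d_{\text{in}}\) per adapted matrix (dimensions here are illustrative, and the explicit matrix products are only for clarity):

```python
import torch

# HRA-style update: multiply the frozen weight by a chain of Householder reflections.
d_out, d_in, m = 512, 512, 8
W0 = torch.randn(d_out, d_in)
V = torch.randn(d_in, m, requires_grad=True)     # m reflection vectors -> m * d_in trainables

Q = torch.eye(d_in)
for i in range(m):
    v = V[:, i:i + 1]                            # (d_in, 1)
    H = torch.eye(d_in) - 2.0 * (v @ v.T) / (v.T @ v)
    Q = Q @ H                                    # each H_i is orthogonal, so Q stays orthogonal

W_adapted = W0 @ Q                               # norm-preserving, rotation-like adaptation
print(torch.allclose(Q @ Q.T, torch.eye(d_in), atol=1e-4))  # True
```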
Other practical tweaks#
LoRA+ keeps the usual LoRA parameterisation but trains the \(A\) and \(B\) factors with different learning rates (a larger step for \(B\)), which speeds up and stabilises adaptation.
rsLoRA (rank-stabilised LoRA) changes the scaling factor from \(\alpha / r\) to \(\alpha / \sqrt{r}\) so that the update magnitude stays well-behaved as the rank grows (see the config sketch after this list).
AdapterDrop randomly drops LoRA adapters during training as a regulariser.
These refinements are orthogonal to the structural extensions above and are usually exposed as configuration flags (for example use_rslora=True).
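For instance, with a reasonably recent PEFT release the scaling and decomposition tweaks are one-line changes on a LoraConfig (a sketch; the LoRA+ learning-rate ratio is set in the optimiser rather than in the config):

```python
from peft import LoraConfig, TaskType

# Vanilla LoRA: the update is scaled by alpha / r.
plain = LoraConfig(r=8, lora_alpha=16, task_type=TaskType.CAUSAL_LM)

# rsLoRA: same factors, but scaled by alpha / sqrt(r) for rank-stable updates.
rank_stabilised = LoraConfig(r=8, lora_alpha=16, use_rslora=True, task_type=TaskType.CAUSAL_LM)

# DoRA (Section 3) is toggled the same way.
weight_decomposed = LoraConfig(r=8, lora_alpha=16, use_dora=True, task_type=TaskType.CAUSAL_LM)
```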
Lightweight empirical comparison#
The next cells instantiate adapters for a tiny OPT language model to compare how many parameters each method trains. The point is not to run a full fine-tuning loop but to sanity-check that the adapters are easy to attach and that the parameter budgets differ as expected.
from __future__ import annotations
import inspect
from dataclasses import dataclass
from typing import Any, Dict
import pandas as pd
import torch
from transformers import AutoModelForCausalLM
from peft import (
AdaLoraConfig,
HRAConfig,
LoHaConfig,
LoraConfig,
TaskType,
get_peft_model,
)
BASE_MODEL = "EleutherAI/pythia-70m-deduped"
TARGET_MODULES = ["query_key_value", "dense"]
@dataclass
class AdapterResult:
name: str
trainable_params: int
total_params: int
percent_trainable: float
def prepare_config(config_cls, *, extra: Dict[str, Any] | None = None, **common_kwargs: Any):
extra = extra or {}
sig = inspect.signature(config_cls.__init__)
kwargs = {}
for name in sig.parameters:
if name == "self":
continue
if name in extra:
kwargs[name] = extra[name]
elif name in common_kwargs:
kwargs[name] = common_kwargs[name]
return config_cls(**kwargs)
def count_trainable_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_total_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def evaluate_adapter(config_cls, name: str, *, extra: Dict[str, Any] | None = None) -> AdapterResult:
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.requires_grad_(False)
config = prepare_config(
config_cls,
extra=extra,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=TARGET_MODULES,
task_type=TaskType.CAUSAL_LM,
)
peft_model = get_peft_model(model, config)
trainable = count_trainable_parameters(peft_model)
total = count_total_parameters(peft_model)
return AdapterResult(name, trainable, total, 100.0 * trainable / total)
results = [
evaluate_adapter(LoraConfig, "LoRA"),
evaluate_adapter(
AdaLoraConfig,
"AdaLoRA",
extra=dict(init_r=6, target_r=12, beta1=0.85, beta2=0.85, tinit=10, tfinal=50, delta_t=10, total_step=1000),
),
evaluate_adapter(LoHaConfig, "LoHa"),
evaluate_adapter(LoraConfig, "DoRA", extra=dict(use_dora=True)),
    # `prepare_config` drops keys that are not in the config signature, so an
    # unknown knob like this silently falls back to the HRAConfig defaults.
    evaluate_adapter(HRAConfig, "HRA", extra=dict(num_householder_blocks=2)),
]
df = pd.DataFrame(
{
"Adapter": [r.name for r in results],
"Trainable params": [r.trainable_params for r in results],
"Total params": [r.total_params for r in results],
"% trainable": [f"{r.percent_trainable:.4f}%" for r in results],
}
)
df
| | Adapter | Trainable params | Total params | % trainable |
|---|---|---|---|---|
| 0 | LoRA | 147456 | 70574080 | 0.2089% |
| 1 | AdaLoRA | 110664 | 70537300 | 0.1569% |
| 2 | LoHa | 294912 | 70721536 | 0.4170% |
| 3 | DoRA | 159744 | 70586368 | 0.2263% |
| 4 | HRA | 49152 | 70475776 | 0.0697% |
The table above shows that every variant updates well under one percent of the model. The differences line up with each method's structure: AdaLoRA trains slightly fewer parameters here because its initial rank (init_r=6) is below LoRA's r=8 and each retained rank only adds a single extra singular value; LoHa roughly doubles LoRA's count because it learns two low-rank factor pairs; DoRA adds one magnitude entry per output column on top of the standard LoRA factors; and HRA is the smallest because each reflection needs only a single input-dimension vector per adapted matrix.
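These counts can be reproduced by hand. Assuming Pythia-70m's shape (6 transformer layers with hidden size 512, so query_key_value maps 512→1536 and dense maps 512→512), the arithmetic below matches the table exactly:

```python
layers = 6
shapes = [(512, 1536), (512, 512)]  # (d_in, d_out) for query_key_value and dense
r = 8

lora = layers * sum(r * (d_in + d_out) for d_in, d_out in shapes)
adalora = layers * sum(6 * (d_in + d_out) + 6 for d_in, d_out in shapes)  # init_r=6, plus one singular value per rank
loha = 2 * lora                                                           # two low-rank factor pairs
dora = lora + layers * sum(d_out for _, d_out in shapes)                  # plus one magnitude entry per output column
hra = layers * sum(r * d_in for d_in, _ in shapes)                        # one d_in-dimensional vector per reflection

print(lora, adalora, loha, dora, hra)  # 147456 110664 294912 159744 49152
```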
Comparing adapter parameter counts#
To make the comparison more explicit, the following bar plot visualises the trainable parameter budgets. The differences remain tiny relative to the full model, but they highlight which methods introduce additional learnable tensors beyond the vanilla LoRA factors.
import plotly.express as px
fig = px.bar(df, x="Adapter", y="Trainable params", text="% trainable", title="Trainable parameters per adapter type")
fig.update_traces(textposition="outside")
fig.update_layout(yaxis_title="Parameters", xaxis_title="Adapter")
fig
# -*- coding: utf-8 -*-
from __future__ import annotations
import os, math, inspect, traceback
from dataclasses import dataclass
from typing import Any, Dict, List
import torch
import pandas as pd
from datasets import load_dataset
from packaging import version
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
set_seed,
__version__ as hf_ver,
)
from peft import (
LoraConfig,
AdaLoraConfig,
LoHaConfig,
HRAConfig,
TaskType,
get_peft_model,
)
# ------------------------------------------------------------------
# Env tweaks
# ------------------------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ------------------------------------------------------------------
# Base / target modules
# ------------------------------------------------------------------
BASE_MODEL = "EleutherAI/pythia-70m-deduped"
TARGET_MODULES = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
# ------------------------------------------------------------------
# Training budget (identical across adapters)
# ------------------------------------------------------------------
MAX_STEPS = 300
LR = 1e-4
BATCH_SIZE = 4
GRAD_ACCUM = 2
WARMUP_STEPS = 30
MAX_LEN = 512
EVAL_EVERY = 100
SEED = 42
set_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
@dataclass
class AdapterSize:
name: str
trainable_params: int
total_params: int
percent_trainable: float
def prepare_config(config_cls, *, extra: Dict[str, Any] | None = None, **common_kwargs: Any):
extra = extra or {}
sig = inspect.signature(config_cls.__init__)
kwargs = {}
for name in sig.parameters:
if name == "self":
continue
if name in extra:
kwargs[name] = extra[name]
elif name in common_kwargs:
kwargs[name] = common_kwargs[name]
return config_cls(**kwargs)
def count_trainable_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_total_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def get_tokenizer():
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
return tok
tokenizer = get_tokenizer()
# ------------------------------------------------------------------
# Collator
# ------------------------------------------------------------------
from dataclasses import dataclass as _dataclass
@_dataclass
class SFTDataCollator:
tokenizer: Any
label_pad_token_id: int = -100
def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
labels = [f["labels"] for f in features]
feats_wo_labels = [{k: v for k, v in f.items() if k != "labels"} for f in features]
batch = self.tokenizer.pad(feats_wo_labels, padding=True, return_tensors="pt")
max_len = batch["input_ids"].size(1)
padded_labels = []
for l in labels:
if len(l) < max_len:
l = l + [self.label_pad_token_id] * (max_len - len(l))
else:
l = l[:max_len]
padded_labels.append(l)
batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
return batch
# ------------------------------------------------------------------
# Dataset (DROP 'prompt'/'output' BEFORE collator sees them)
# ------------------------------------------------------------------
def build_sft_splits(tokenizer, max_len=MAX_LEN, train_n=1000, eval_n=200):
ds = load_dataset("yahma/alpaca-cleaned")
pool = ds["train"].shuffle(seed=SEED).select(range(train_n + eval_n))
eval_raw = pool.select(range(eval_n))
train_raw = pool.select(range(eval_n, eval_n + train_n))
def format_example(ex):
instr = (ex.get("instruction") or "").strip()
inp = (ex.get("input") or "").strip()
out = (ex.get("output") or "").strip()
if inp:
prompt = f"### Instruction:\n{instr}\n\n### Input:\n{inp}\n\n### Response:\n"
else:
prompt = f"### Instruction:\n{instr}\n\n### Response:\n"
return {"prompt": prompt, "output": out}
train_raw = train_raw.map(format_example, remove_columns=train_raw.column_names)
eval_raw = eval_raw.map(format_example, remove_columns=eval_raw.column_names)
def tok_id(s: str):
return tokenizer(s, add_special_tokens=False)["input_ids"]
def build_io(example):
p_ids = tok_id(example["prompt"])
o_ids = tok_id(example["output"])
if len(o_ids) == 0:
return {"discard": True}
# ensure at least 1 supervised token
if len(p_ids) > max_len - 1:
p_ids = p_ids[: max_len - 1]
room = max_len - len(p_ids)
if room <= 0:
return {"discard": True}
o_ids = o_ids[:room]
if len(o_ids) == 0:
return {"discard": True}
input_ids = p_ids + o_ids
labels = [-100] * len(p_ids) + o_ids
attn = [1] * len(input_ids)
return {"input_ids": input_ids, "attention_mask": attn, "labels": labels, "discard": False}
train_built = train_raw.map(build_io, remove_columns=train_raw.column_names)
eval_built = eval_raw.map(build_io, remove_columns=eval_raw.column_names)
# drop discarded rows and the helper flag
train_built = train_built.filter(lambda ex: not ex["discard"]).remove_columns(["discard"])
eval_built = eval_built.filter(lambda ex: not ex["discard"]).remove_columns(["discard"])
return train_built, eval_built
# ------------------------------------------------------------------
# PEFT-wrapped model
# ------------------------------------------------------------------
def make_peft_model(config_cls, *, extra: Dict[str, Any] | None = None):
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
base.config.pad_token_id = base.config.pad_token_id or 0
base.requires_grad_(False)
cfg = prepare_config(
config_cls,
extra=extra,
r=4,
lora_alpha=8,
lora_dropout=0.05,
target_modules=TARGET_MODULES,
task_type=TaskType.CAUSAL_LM,
)
return get_peft_model(base, cfg)
# ------------------------------------------------------------------
# Trainer
# ------------------------------------------------------------------
def make_trainer(peft_model, tokenizer, train_ds, eval_ds, run_name: str):
collator = SFTDataCollator(tokenizer=tokenizer, label_pad_token_id=-100)
args = TrainingArguments(
output_dir=f"runs/{run_name}",
overwrite_output_dir=True,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LR,
max_steps=MAX_STEPS,
warmup_steps=WARMUP_STEPS,
weight_decay=0.0,
lr_scheduler_type="cosine",
eval_steps=EVAL_EVERY,
logging_steps=50,
save_steps=MAX_STEPS,
save_total_limit=1,
report_to=[],
remove_unused_columns=False, # we pass exact fields
fp16=False, bf16=False, # simple numerics
dataloader_num_workers=0, # safest in notebooks
dataloader_pin_memory=False,
eval_strategy="steps",
)
peft_model.config.pad_token_id = tokenizer.pad_token_id
return Trainer(
model=peft_model,
args=args,
train_dataset=train_ds,
eval_dataset=eval_ds,
tokenizer=tokenizer, # ok despite FutureWarning
data_collator=collator,
)
# ------------------------------------------------------------------
# Size reporter
# ------------------------------------------------------------------
def adapter_size_row(config_cls, name: str, *, extra: Dict[str, Any] | None = None) -> AdapterSize:
model = make_peft_model(config_cls, extra=extra)
trainable = count_trainable_parameters(model)
total = count_total_parameters(model)
out = AdapterSize(name, trainable, total, 100.0 * trainable / total)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return out
# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------
def main():
train_ds, eval_ds = build_sft_splits(tokenizer, max_len=MAX_LEN)
adapters = [
("LoRA", LoraConfig, dict()),
("AdaLoRA", AdaLoraConfig, dict(
init_r=6, target_r=12, beta1=0.85, beta2=0.85,
tinit=int(0.1*MAX_STEPS), tfinal=int(0.5*MAX_STEPS), delta_t=int(0.1*MAX_STEPS),
total_step=MAX_STEPS,
)),
("DoRA", LoraConfig, dict(use_dora=True)),
        # Optional (comment out if your local `peft` lacks support).
        # `prepare_config` drops keys that are not in the config signature, so
        # these knobs are silently ignored unless your PEFT version exposes them.
        ("LoHa", LoHaConfig, dict(hypercomplex_multiplier=2)),
        ("HRA", HRAConfig, dict(num_householder_blocks=2)),
]
size_rows, metrics_rows = [], []
for name, cfg_cls, extra in adapters:
print(f"\n==== {name} ====")
try:
sz = adapter_size_row(cfg_cls, name, extra=extra)
size_rows.append(sz)
model = make_peft_model(cfg_cls, extra=extra)
model.to(DEVICE)
trainer = make_trainer(model, tokenizer, train_ds, eval_ds, run_name=f"{name.lower()}_sft")
trainer.train()
eval_metrics = trainer.evaluate()
loss = float(eval_metrics["eval_loss"])
if math.isnan(loss) or math.isinf(loss):
raise ValueError(f"eval_loss is {loss}")
ppl = float(math.exp(min(20.0, loss)))
metrics_rows.append({"Adapter": name, "eval_loss": loss, "perplexity": ppl})
del trainer, model
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
print(f"[SKIP] {name} failed:\n{e}")
traceback.print_exc()
metrics_rows.append({"Adapter": name, "eval_loss": float('inf'), "perplexity": float('inf')})
print("\n== Adapter Size ==")
print(pd.DataFrame([{**r.__dict__} for r in size_rows]).to_string(index=False))
print("\n== Eval Metrics (lower is better) ==")
print(pd.DataFrame(metrics_rows).sort_values('perplexity').to_string(index=False))
main()
==== LoRA ====
/tmp/ipykernel_2696/259473458.py:213: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
return Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 2.704800 | 2.698495 |
| 200 | 2.613100 | 2.678968 |
| 300 | 2.620000 | 2.678347 |
==== AdaLoRA ====
/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/peft/tuners/adalora/config.py:96: UserWarning: Note that `r` is not used in AdaLora and will be ignored.If you intended to set the initial rank, use `init_r` instead.
warnings.warn(
/tmp/ipykernel_2696/259473458.py:213: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
return Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 3.847500 | 3.278960 |
| 200 | 3.397800 | 3.112560 |
| 300 | 3.292000 | 3.078031 |
==== DoRA ====
/tmp/ipykernel_2696/259473458.py:213: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
return Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 2.702900 | 2.702206 |
| 200 | 2.612600 | 2.683782 |
| 300 | 2.617300 | 2.681008 |
==== LoHa ====
/tmp/ipykernel_2696/259473458.py:213: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
return Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 2.775400 | 2.784306 |
| 200 | 2.686600 | 2.759445 |
| 300 | 2.702200 | 2.755796 |
==== HRA ====
/tmp/ipykernel_2696/259473458.py:213: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
return Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 2.773100 | 2.783116 |
| 200 | 2.695600 | 2.769071 |
| 300 | 2.714900 | 2.766985 |
== Adapter Size ==
name trainable_params total_params percent_trainable
LoRA 196608 70623232 0.278390
AdaLoRA 295056 70721704 0.417207
DoRA 224256 70650880 0.317414
LoHa 393216 70819840 0.555234
HRA 86016 70512640 0.121987
== Eval Metrics (lower is better) ==
Adapter eval_loss perplexity
LoRA 2.678347 14.561006
DoRA 2.681008 14.599797
LoHa 2.755796 15.733567
HRA 2.766985 15.910594
AdaLoRA 3.078031 21.715593
Takeaways#
All of the covered methods are exposed through the same get_peft_model workflow, so swapping adapters is mostly a matter of changing the config class.
Adaptive or structured variants incur a modest increase in trainable parameters but still stay within the PEFT regime.
Choosing between them depends on the task: AdaLoRA can shine when different layers need different capacities (though its budget schedule needs tuning, as the untuned run above shows), LoHa and HRA introduce richer geometric biases, and DoRA stabilises training when weight scales vary a lot.
These adapters are complementary to the optimisation strategies introduced earlier in the book, and they can be combined with quantisation or pruning techniques from the other appendices.