def support_quantized_model_reload_from_hp_weights(original_load_weights):
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights
"""
    # Online quantization; right now only enabled for torchao.
    # Steps R1-R5 below are described in the Notes in this file.
def patched_model_load_weights(
auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
) -> set[str]:
model = auto_weight_loader.module
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
        # If `model.weight_metadata_and_attr_saved` is not defined or not set
        # to True, this is either the offline quantization case or the first
        # run of weight loading with online quantization.
        # See the Notes in this file for more details.
if offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
return original_load_weights(auto_weight_loader, weights, mapper=mapper)
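
        # From this point on we are reloading high precision weights into a
        # model that has already been online quantized (Steps R1-R5 below).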
model_config = model._model_config
# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
        # Step R1: first restore the quantized weights to empty high precision
        # (bfloat16) weights with the original metadata (shape, dtype, device)
        # and attributes, so that the bfloat16 weights can be loaded properly.
# TODO: maybe set remove_duplicate to True?
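        # Keep references to the currently installed quantized parameter
        # objects so that Step R5 can swap them back in after re-quantization.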
original_quantized_weight_dict = dict(
model.named_parameters(remove_duplicate=False)
)
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
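        # `original_weights_rebuild_keys` maps each parameter name to the
        # shape/dtype/device recorded before the first online quantization.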
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
                assert model_device == _device, (
                    "Expecting all weights to be on the same device for now, "
                    f"got both: {model_device} and {_device}"
                )
else:
model_device = _device
if name in original_quantized_weight_dict:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(
torch.empty(_shape, dtype=_dtype, device=_device),
requires_grad=False,
),
)
        # Step R2: recover the weight attributes to the state before the first
        # load. `recorded_weight_attr` is a nested dict of
        # {"weight_name": {"weight_attr_key": attr}}, e.g.
        # {
        #     "layer.0.weight": {
        #         "weight_loader": weight_loader_function_object,
        #         "input_dim": 0, ...
        #     },
        #     "layer.1.weight": ...,
        # }
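        # Attributes such as `weight_loader` are typically methods;
        # `_bond_method_to_cls` re-binds them to the freshly created Parameter
        # so they behave like bound methods during loading.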
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
# Step R3: reload bfloat16 / high precision weights
updated_params = original_load_weights(
auto_weight_loader, weights, mapper=mapper
)
# Step R4: online quantize the weights
# manually process weights after loading
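        # Resetting this flag forces `process_weights_after_loading` to run
        # again and re-quantize the freshly loaded high precision weights.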
model.process_weights_after_loading_already_called = False
if model_device is not None:
process_weights_after_loading(model, model_config, model_device)
else:
            logger.warning_once(
                "model_device is None, skipping process_weights_after_loading"
            )
        # Step R5 (workaround for cudagraph): restore the original quantized
        # weights and copy_ the current (re-quantized) weights into them.
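        # cudagraph captures tensor addresses, so the parameter objects that
        # were live at capture time must stay in place; copying the new values
        # into them keeps the captured addresses valid.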
updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
for name in model.original_weights_rebuild_keys:
if name in original_quantized_weight_dict:
original_quantized_weight = original_quantized_weight_dict[name]
updated_quantized_weight = updated_quantized_weights[name]
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(module, weight_name, original_quantized_weight)
with torch.no_grad():
original_quantized_weight.copy_(updated_quantized_weight)
        # Drop temporary references so the transient tensors can be freed.
        del original_quantized_weight_dict
        del named_modules
        del updated_quantized_weights
model.process_weights_after_loading_already_called = True
return updated_params
return patched_model_load_weights
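

# A minimal usage sketch (assumed wiring, not the definitive integration):
# the decorator is meant to wrap `AutoWeightsLoader.load_weights` once, e.g.
#
#     AutoWeightsLoader.load_weights = support_quantized_model_reload_from_hp_weights(
#         AutoWeightsLoader.load_weights
#     )
#
# After that, the first call behaves as before (offline quantization or the
# first run of online quantization), while subsequent calls on an already
# quantized torchao model go through Steps R1-R5 above.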