Transformer 模型信息统计
统计 LLM 模型每一层参数的形状（shape）、精度（dtype）及占用大小
def print_model_info(model_path, *, model=None, trust_remote_code=True, **load_kwargs):
    """Print the name, shape, dtype and size of every parameter of a causal LM.

    By default the model is loaded with ``AutoModelForCausalLM.from_pretrained``.
    That name is not imported in this snippet. The module must import it, e.g.
    ``from transformers import AutoModelForCausalLM``. Passing an already-built
    ``model`` skips loading entirely.

    Args:
        model_path: Path or hub id forwarded to ``from_pretrained``. Ignored
            when ``model`` is given.
        model: Optional already-loaded module exposing ``named_parameters()``
            (e.g. a ``torch.nn.Module``). When given, nothing is loaded.
        trust_remote_code: Forwarded to ``from_pretrained``. ``True`` allows
            the checkpoint repository to run its own Python code. Only use
            it with repositories you trust.
        **load_kwargs: Extra keyword arguments for ``from_pretrained``.
            NOTE(review): depending on the installed transformers version,
            weights may be loaded as float32 regardless of the checkpoint
            dtype. In that case the "精度" column shows the load dtype. Pass
            e.g. ``torch_dtype="auto"`` to keep the checkpoint dtype. Confirm
            against your installed version.

    Returns:
        Total size of all parameters in MB (1024 * 1024 bytes).
    """
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=trust_remote_code, **load_kwargs
        )

    print("=" * 100)
    # Right-align the size header so it lines up with the numeric column below.
    print(f"{'参数名称':<60} {'形状':<25} {'精度':<10} {'大小(MB)':>10}")
    print("=" * 100)

    # Accumulate exact integer byte/element counts; convert to MB only for display.
    total_bytes = 0
    total_params = 0
    for name, param in model.named_parameters():
        numel = param.numel()
        n_bytes = numel * param.element_size()
        total_bytes += n_bytes
        total_params += numel
        shape_str = str(tuple(param.shape))
        dtype_str = str(param.dtype).replace('torch.', '')
        print(f"{name:<60} {shape_str:<25} {dtype_str:<10} {n_bytes / 1024 / 1024:>10.2f}")

    total_mb = total_bytes / 1024 / 1024
    count_str = f"{total_params:,} 个参数"
    print("=" * 100)
    print(f"{'总计':<60} {count_str:<25} {'':<10} {total_mb:>10.2f}")
    return total_mb
评论