Transformer 模型信息统计

统计一些 LLM 模型每一层参数的形状(shape)、数据类型(dtype)以及占用大小(MB)。

def print_model_info(model_path):
    """Print a table with one row per model parameter.

    The model is loaded via ``AutoModelForCausalLM.from_pretrained`` and every
    entry yielded by ``model.named_parameters()`` is printed with its name,
    shape, dtype and size in MB.

    Args:
        model_path: Forwarded unchanged to ``from_pretrained``. Presumably a
            local checkpoint directory or a hub model id -- TODO confirm
            against callers.

    NOTE(review): the load call passes ``trust_remote_code=True``. Assuming
    this is the Hugging Face ``transformers`` loader, that allows code shipped
    with the checkpoint to be executed -- confirm that only trusted model
    sources are used.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

    # Table header; the column widths (60 / 25 / 10 / 10) mirror the ones used
    # for the data rows below. The labels are: parameter name, shape,
    # precision (dtype) and size in MB.
    print("=" * 100)
    print(f"{'参数名称':<60} {'形状':<25} {'精度':<10} {'大小(MB)':<10}")
    print("=" * 100)

    # Running sum of the per-parameter sizes, in MB.
    # NOTE(review): accumulated below but not read in this part of the function.
    total_size = 0
    for name, param in model.named_parameters():
        shape_str = str(tuple(param.shape))
        # Drop the 'torch.' prefix so the dtype column shows only the short name.
        dtype_str = str(param.dtype).replace('torch.', '')
        # Element count times bytes per element, converted from bytes to MB
        # using 1024 * 1024 bytes per MB.
        size_mb = param.numel() * param.element_size() / 1024 / 1024
        total_size += size_mb

        # Size is right-aligned in a width-8 field with two decimal places.
        print(f"{name:<60} {shape_str:<25} {dtype_str:<10} {size_mb: >8.2f}")