差别
这里会显示出您选择的修订版和当前版本之间的差别。
| 人工智能:paddle:paddleclas迁移timm:2.timm的efficientvit迁移到paddleclas:step1-转换timm的权重 [2026/01/28 08:31] – 创建 ctbots | 人工智能:paddle:paddleclas迁移timm:2.timm的efficientvit迁移到paddleclas:step1-转换timm的权重 [2026/01/28 08:31] (当前版本) – 移除 ctbots | ||
|---|---|---|---|
| 行 1: | 行 1: | ||
| - | ====== step1: | ||
| - | <code python> | ||
| - | # | ||
| - | import os | ||
| - | import argparse | ||
| - | import paddle | ||
| - | import traceback | ||
| - | def load_pytorch_weights(model_name): | ||
| - | try: | ||
| - | import torch | ||
| - | import timm | ||
| - | except ImportError: | ||
| - | raise ImportError(" | ||
| - | |||
| - | print(f" | ||
| - | pt_model = timm.create_model(model_name, | ||
| - | pt_state_dict = pt_model.state_dict() | ||
| - | |||
| - | print(f" | ||
| - | return pt_state_dict | ||
| - | |||
| - | |||
| - | def convert_name(pt_name): | ||
| - | if ' | ||
| - | return None | ||
| - | |||
| - | name = pt_name.replace(' | ||
| - | name = name.replace(' | ||
| - | | ||
| - | return name | ||
| - | |||
| - | |||
| - | def debug_weight_mapping(pt_state_dict, | ||
| - | print(f" | ||
| - | | ||
| - | aggreg_keys = [k for k in pt_state_dict.keys() if ' | ||
| - | if aggreg_keys: | ||
| - | print(f" | ||
| - | for key in aggreg_keys[: | ||
| - | print(f" | ||
| - | if len(aggreg_keys) > 10: | ||
| - | print(f" | ||
| - | | ||
| - | attention_keys = [k for k in pt_state_dict.keys() if any(x in k for x in [' | ||
| - | if attention_keys: | ||
| - | print(f" | ||
| - | for key in attention_keys[: | ||
| - | print(f" | ||
| - | if len(attention_keys) > 10: | ||
| - | print(f" | ||
| - | | ||
| - | print(f" | ||
| - | print(" | ||
| - | |||
| - | |||
| - | def is_large_model(model_name): | ||
| - | return ' | ||
| - | |||
| - | |||
| - | def convert_param(pt_name, | ||
| - | param = pt_param.cpu().numpy() | ||
| - | |||
| - | if ' | ||
| - | if ' | ||
| - | param = param.T | ||
| - | return param | ||
| - | |||
| - | |||
| - | def convert_state_dict(pt_state_dict, | ||
| - | pd_state_dict = {} | ||
| - | skipped = [] | ||
| - | | ||
| - | is_large = is_large_model(model_name) if model_name else False | ||
| - | | ||
| - | if debug or is_large: | ||
| - | debug_weight_mapping(pt_state_dict, | ||
| - | |||
| - | for pt_name, pt_param in pt_state_dict.items(): | ||
| - | pd_name = convert_name(pt_name) | ||
| - | if pd_name is None: | ||
| - | skipped.append(pt_name) | ||
| - | continue | ||
| - | |||
| - | pd_param = convert_param(pt_name, | ||
| - | pd_state_dict[pd_name] = pd_param | ||
| - | |||
| - | print(f" | ||
| - | print(f" | ||
| - | if skipped: | ||
| - | print(f" | ||
| - | |||
| - | return pd_state_dict | ||
| - | |||
| - | |||
| - | def save_weights(pd_state_dict, | ||
| - | os.makedirs(os.path.dirname(output_path), | ||
| - | paddle.save(pd_state_dict, | ||
| - | print(f" | ||
| - | print(f" | ||
| - | |||
| - | |||
| - | model_keys = [ | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ' | ||
| - | ] | ||
| - | |||
| - | def main(): | ||
| - | parser = argparse.ArgumentParser(description=' | ||
| - | parser.add_argument(' | ||
| - | help=' | ||
| - | parser.add_argument(' | ||
| - | help=' | ||
| - | parser.add_argument(' | ||
| - | help=' | ||
| - | |||
| - | args = parser.parse_args() | ||
| - | |||
| - | # 如果没有指定模型,则遍历所有支持的模型 | ||
| - | if args.model is None or args.model == "": | ||
| - | for model_name in model_keys: | ||
| - | print(f" | ||
| - | try: | ||
| - | pt_state_dict = load_pytorch_weights(model_name) | ||
| - | pd_state_dict = convert_state_dict(pt_state_dict, | ||
| - | output_path = os.path.join(args.output, | ||
| - | save_weights(pd_state_dict, | ||
| - | except Exception as e: | ||
| - | print(f" | ||
| - | print(f" | ||
| - | continue | ||
| - | else: | ||
| - | # 单个模型转换 | ||
| - | if args.model not in model_keys: | ||
| - | print(f" | ||
| - | print(f" | ||
| - | return | ||
| - | |||
| - | pt_state_dict = load_pytorch_weights(args.model) | ||
| - | pd_state_dict = convert_state_dict(pt_state_dict, | ||
| - | output_path = os.path.join(args.output, | ||
| - | save_weights(pd_state_dict, | ||
| - | |||
| - | print(" | ||
| - | print(" | ||
| - | print(" | ||
| - | print(f" | ||
| - | print(f" | ||
| - | print(f" | ||
| - | print(" | ||
| - | |||
| - | |||
| - | if __name__ == ' | ||
| - | main() | ||
| - | |||
| - | </ | ||