2-paddle模型缓存相关细节-aistudio-模型仓库

paddle在使用 模型时,使用了aistudio的sdk,具体模型缓存在:

/home/ctbots/.cache/aistudio/hub/models/[项目id坐标] 下

如何避免paddle重复下载大文件,如何不改代码骗过aistudio的sdk

我们在本地处理预训练模型的时候,如果处理完毕,可以上传到paddle的 aistudio上,并获取项目坐标,例如: ctbots/first-paddle-model

但是在代码中,直接使用 ctbots/first-paddle-model 模型的时候,会发现 paddle框架会自动下载模型到本地cache文件中 /home/ctbots/.cache/aistudio/hub/models/[项目id坐标] 下

几百兆的文件还好,对于几十G的文件,实在不想重新下载;

而且每次执行,paddle都会对比 远程仓库文件的sha值和本地的是否一致,如果不一致直接重新下载。 我们调整预训练模型的时候,model一直在变化,但是又不想改代码,专门修改为本地加载,

我们灵机一动,既然模型在本地push上去的,那么我直接复制到cache目录就行了,实际不行,还有sha值对比的环节。

我们这里给出一段代码,模拟sha值的cache逻辑在本地做一份欺骗的缓存,在不修改paddle代码的情况下,避免前期预训练调整模型阶段,重复下载。

import hashlib
import os
import pickle
from urllib.parse import quote

import requests

def compute_hash(file_path):
    buffer_size = 1024 * 64  # 64k buffer size
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as f:
        while True:
            data = f.read(buffer_size)
            if not data:
                break
            sha256_hash.update(data)
    return sha256_hash.hexdigest()

def mk_aistudio_cache_file(project_id: str, local_dir: str):
    if not os.path.exists(local_dir):
        return False, "模型目录不存在"
    model_meta = {
        "id": project_id
    }
    meta_file_path = os.path.join(local_dir, ".mdl")
    with open(meta_file_path, 'wb') as f:
        pickle.dump(model_meta, f)

    files = [f for f in os.listdir(local_dir) if os.path.isfile(os.path.join(local_dir, f))]

    cached_files = []
    base_url = "https://git.aistudio.baidu.com/api/v1/repos"
    for f in files:
        encoded_path = quote(f, safe='')
        url = f"{base_url}/{project_id}/contents/{encoded_path}"

        try:
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            commit_id = data.get("sha", "N/A")
            cached_files.append({
                "Path": f,
                "Revision": commit_id
            })

        except Exception as e:
            cached_files.append({
                "Path": f,
                "Revision": "N/A"
            })

    cache_keys_file_path = os.path.join(local_dir, ".msc")
    with open(cache_keys_file_path, 'wb') as f:
        pickle.dump(cached_files, f)
    return True, "成功写入cache"

def test_mk_cache():
    rs = mk_aistudio_cache_file("ctbots/first-paddle-model", "/home/ctbots/.cache/aistudio/hub/models/ctbots/first-paddle-model")
    print(rs)