openvla驱动mujoco的franka机械臂运动

  1. openvla-7b 需要自己下载好模型,大约15GB;不会科学上网的,可以参考这里:从魔塔(ModelScope)自行下载 openvla-7b
  2. mujoco 的 fr3.xml 文件,是从Fr3py项目抄过来的。可以参考加速git仓库:整个mujoco目录都复制过去
  3. mujoco 第一次渲染,一定会有不少报错,参考内部文档,自行解决问题 mujoco安装问题处理

如下代码,就可以用openvla直接驱动基于mujoco的Franka运动起来。摄像头仿真后续加上。

#!/usr/bin/env python3

import os
import sys
import time
import numpy as np
from PIL import Image

# Select the EGL backend so MuJoCo can render off-screen (headless);
# set here, before `import mujoco` below.
os.environ['MUJOCO_GL'] = 'egl'

import mujoco
import mujoco.viewer
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

# MJCF scene for the Franka FR3 arm (copied from the Fr3py project).
XML_PATH = os.path.join(os.path.dirname(__file__), "mujoco", "fr3.xml")
# Local directory holding the downloaded openvla-7b checkpoint (~15 GB).
MODEL_PATH = "/home/ctbots/llm/openvla-7b"


class OpenVLAMuJoCoController:
    """Drive a MuJoCo-simulated Franka FR3 arm with OpenVLA-predicted actions.

    Loads the MJCF scene and the OpenVLA vision-language-action model, then
    runs a stepping loop (GUI or headless) in which a camera image is
    periodically fed to the model and the predicted 7-dim action is applied
    as a clipped delta on the joint position targets.
    """

    def __init__(self, use_gui: bool = True):
        """Load the MuJoCo scene and the OpenVLA checkpoint.

        Args:
            use_gui: request the interactive viewer; only honoured when a
                DISPLAY is available, otherwise off-screen rendering is used.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[INFO] 使用设备: {self.device}")

        if torch.cuda.is_available():
            print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")

        print(f"[INFO] 加载 MuJoCo 模型: {XML_PATH}")
        self.model = mujoco.MjModel.from_xml_path(XML_PATH)
        self.data = mujoco.MjData(self.model)
        print("[SUCCESS] MuJoCo 模型加载成功")

        print("[INFO] 加载 OpenVLA 模型...")
        self.processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
        self.vla = AutoModelForVision2Seq.from_pretrained(
            MODEL_PATH,
            # bfloat16 halves GPU memory; CPU inference stays in float32.
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).to(self.device)
        print("[SUCCESS] OpenVLA 模型加载成功")

        # FIX: coerce to bool — `use_gui and os.environ.get('DISPLAY')` would
        # store the DISPLAY string (or None) rather than a boolean flag.
        self.use_gui = bool(use_gui and os.environ.get('DISPLAY'))
        self.renderer = None

        if not self.use_gui:
            print("[INFO] 初始化离屏渲染器...")
            try:
                # 224x224 matches the image size fed to OpenVLA below.
                self.renderer = mujoco.Renderer(self.model, height=224, width=224)
                print("[SUCCESS] 渲染器初始化成功")
            except Exception as e:
                # Best effort: without a renderer we fall back to noise images.
                print(f"[WARNING] 渲染器初始化失败: {e}")

        self.action_scale = 0.05  # scales the normalized [-1, 1] action before use
        self.step_count = 0

        # Initial pose for the 7 arm joints (radians).
        self.initial_qpos = np.array([0, 0, 0, -1.57079, 0, 1.57079, -0.7853])
        self.data.qpos[:7] = self.initial_qpos
        mujoco.mj_forward(self.model, self.data)

    def get_camera_image(self):
        """Return an RGB frame from the 'track' camera.

        Falls back to random noise (224x224x3 uint8) when no off-screen
        renderer is available, so the control loop keeps running.
        """
        if self.renderer:
            self.renderer.update_scene(self.data, camera="track")
            return self.renderer.render()
        return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

    def predict_action(self, image: np.ndarray, instruction: str) -> np.ndarray:
        """Run OpenVLA on one image + instruction; return a scaled 7-dim action.

        The raw model output is flattened, padded/truncated to 7 values,
        clipped to [-1, 1] and multiplied by ``action_scale``.
        """
        pil_image = Image.fromarray(image)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Prompt format expected by the openvla-7b checkpoint.
        prompt = f"In: What action should the robot take to {instruction}?\nOut:"

        inputs = self.processor(prompt, pil_image).to(
            self.device,
            dtype=torch.bfloat16 if self.device == "cuda" else torch.float32
        )

        with torch.no_grad():
            # "bridge_orig" selects the Bridge dataset un-normalization stats
            # — NOTE(review): confirm these suit the simulated FR3 scene.
            action = self.vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

        if isinstance(action, torch.Tensor):
            action_np = action.cpu().numpy().flatten()
        else:
            action_np = np.array(action).flatten()

        # Force exactly 7 values regardless of the model's output length.
        if len(action_np) > 7:
            action_np = action_np[:7]
        elif len(action_np) < 7:
            action_np = np.pad(action_np, (0, 7 - len(action_np)), 'constant')

        return np.clip(action_np, -1.0, 1.0) * self.action_scale

    # --- private helpers shared by the GUI and headless loops --------------

    def _apply_policy(self, instruction: str):
        """Query the model once and write clipped joint targets to ctrl."""
        rgb_image = self.get_camera_image()
        action = self.predict_action(rgb_image, instruction)
        self.data.ctrl[:7] = np.clip(
            self.data.qpos[:7] + action,
            self.model.jnt_range[:7, 0],
            self.model.jnt_range[:7, 1]
        )

    def _log_end_effector(self, step: int):
        """Print the Cartesian position of the 'hand' body."""
        ee_pos = self.data.xpos[self.model.body('hand').id]
        print(f"[Step {step:04d}] 末端位置: [{ee_pos[0]:.3f}, {ee_pos[1]:.3f}, {ee_pos[2]:.3f}]")

    def _reset_pose(self):
        """Reset joints to the initial pose plus small Gaussian noise, clipped
        to the joint limits, and recompute derived quantities."""
        print("[INFO] 重置到初始位置 (周期性运动)")
        self.data.qpos[:7] = np.clip(
            self.initial_qpos + np.random.randn(7) * 0.1,
            self.model.jnt_range[:7, 0],
            self.model.jnt_range[:7, 1]
        )
        mujoco.mj_forward(self.model, self.data)

    def run_with_gui(self, instruction: str, num_steps: int = 500):
        """Run interactively in the MuJoCo viewer until it closes or
        ``num_steps`` physics steps have elapsed."""
        print(f"\n[INFO] 任务: {instruction}")
        print(f"[INFO] 总步数: {num_steps}")
        print("[INFO] 启动交互式查看器...")
        print("[INFO] 按 Ctrl+C 停止")

        with mujoco.viewer.launch_passive(self.model, self.data) as viewer:
            step = 0
            reset_interval = 100

            while viewer.is_running() and step < num_steps:
                self.step_count += 1
                step += 1

                # Re-query the policy every 10 physics steps (inference is slow).
                if step % 10 == 0:
                    self._apply_policy(instruction)

                mujoco.mj_step(self.model, self.data)
                viewer.sync()

                if step % 50 == 0:
                    self._log_end_effector(step)

                if step % reset_interval == 0:
                    self._reset_pose()

        print(f"\n[SUCCESS] 完成 {step} 步仿真")

    def run_headless(self, instruction: str, num_steps: int = 500):
        """Run without a GUI, using the off-screen renderer for camera frames."""
        print(f"\n[INFO] 任务: {instruction}")
        print(f"[INFO] 总步数: {num_steps}")
        print("[INFO] 离屏模式运行...")

        reset_interval = 100

        for step in range(num_steps):
            self.step_count += 1

            # Re-query the policy every 10 physics steps (inference is slow).
            if step % 10 == 0:
                self._apply_policy(instruction)

            mujoco.mj_step(self.model, self.data)

            if step % 50 == 0:
                self._log_end_effector(step)

            if step % reset_interval == 0 and step > 0:
                self._reset_pose()

        print(f"\n[SUCCESS] 完成 {num_steps} 步仿真")

    def close(self):
        """Release the off-screen renderer, if one was created."""
        if self.renderer:
            self.renderer.close()


def main():
    """Command-line entry point: parse options, build the controller, run."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="OpenVLA + MuJoCo Franka 仿真")
    arg_parser.add_argument("--instruction", type=str, default="pick up the object",
                            help="任务指令")
    arg_parser.add_argument("--steps", type=int, default=500,
                            help="仿真步数")
    arg_parser.add_argument("--no_gui", action="store_true",
                            help="禁用 GUI,使用离屏渲染")
    opts = arg_parser.parse_args()

    controller = OpenVLAMuJoCoController(use_gui=not opts.no_gui)

    try:
        # Pick the GUI or headless loop depending on what the controller
        # actually managed to set up.
        runner = controller.run_with_gui if controller.use_gui else controller.run_headless
        runner(opts.instruction, opts.steps)
    except KeyboardInterrupt:
        print("\n[INFO] 用户中断")
    finally:
        controller.close()


if __name__ == "__main__":
    main()

评论

请输入您的评论. 可以使用维基语法:
 
机器人/openvla/openvla实验/openvla驱动mujoco的franka机械臂运动.txt · 最后更改: 2025/11/06 06:13