面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现

循环状态空间模型(Recurrent State Space Models, RSSM)最初由 Danijar Hafer 等人在论文《Learning Latent Dynamics for Planning from Pixels》中提出。该模型在现代基于模型的强化学习(Model-Based Reinforcement Learning, MBRL)中发挥着关键作用,其主要目标是构建可靠的环境动态预测模型。通过这些学习得到的模型,智能体能够模拟未来轨迹并进行前瞻性的行为规划。

图片

下面我们就来用一个实际案例来介绍RSSM。

环境配置

环境配置是实现过程中的首要步骤。我们这里用易于使用的 Gym API。为了提高实现效率,设计了多个模块化的包装器(wrapper),用于初始化参数并将观察结果调整为指定格式。

InitialWrapper 的设计允许在不执行任何动作的情况下进行特定数量的观察,同时支持在返回观察结果之前多次重复同一动作。这种设计对于响应具有显著延迟特性的环境特别有效。

PreprocessFrame 包装器负责将观察结果转换为正确的数据类型(本文中使用 numpy 数组),并支持灰度转换功能。

 class InitialWrapper(gym.Wrapper):  
     def __init__(self, env: gym.Env, no_ops: int = 0, repeat: int = 1):  
         super(InitialWrapper, self).__init__(env)  
         self.repeat = repeat  
         self.no_ops = no_ops  
 
         self.op_counter = 0  
   
     def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:  
         if self.op_counter < self.no_ops:  
             obs, reward, done, info = self.env.step(0)  
             self.op_counter += 1  
   
         total_reward = 0.0  
         done = False  
         for _ in range(self.repeat):  
             obs, reward, done, info = self.env.step(action)  
             total_reward += reward  
             if done:  
                 break  
   
         return obs, total_reward, done, info  
 
 
 class PreprocessFrame(gym.ObservationWrapper):  
     def __init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = False):  
         super(PreprocessFrame, self).__init__(env)  
         self.shape = new_shape  
         self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32)  
         self.grayscale = grayscale  
   
         if self.grayscale:  
             self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32)  
   
     def observation(self, obs: torch.Tensor) -> torch.Tensor:  
         obs = obs.astype(np.uint8)  
         new_frame = cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA)  
         if self.grayscale:  
             new_frame = cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY)  
             new_frame = np.expand_dims(new_frame, -1)  
   
         torch_frame = torch.from_numpy(new_frame).float()  
         torch_frame = torch_frame / 255.0  
   
         return torch_frame  
   
 def make_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = True, **kwargs):  
     env = gym.make(env_name, **kwargs)  
     env = PreprocessFrame(env, new_shape, grayscale=grayscale)  
     return env

make_env 函数用于创建一个具有指定配置参数的环境实例。

模型架构

RSSM 的实现依赖于多个关键模型组件。具体来说,需要实现以下四个核心模块:

  • 原始观察编码器(Encoder)

  • 动态模型(Dynamics Model):通过确定性状态 h 和随机状态 s 对编码观察的时间依赖性进行建模

  • 解码器(Decoder):将随机状态和确定性状态映射回原始观察空间

  • 奖励模型(Reward Model):将随机状态和确定性状态映射到奖励值

图片

RSSM 模型组件结构图。模型包含随机状态 s 和确定性状态 h。

编码器实现

编码器采用简单的卷积神经网络(CNN)结构,将输入图像降维到一维嵌入表示。实现中使用了 BatchNorm 来提升训练稳定性。

 class EncoderCNN(nn.Module):  
     def __init__(self, in_channels: int, embedding_dim: int = 2048, input_shape: Tuple[int, int] = (128, 128)):  
         super(EncoderCNN, self).__init__()  
         # 定义卷积层结构
         self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)  
         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  
         self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  
         self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  
   
         self.fc1 = nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim)  
   
         # 批标准化层
         self.bn1 = nn.BatchNorm2d(32)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(128)  
         self.bn4 = nn.BatchNorm2d(256)  
   
     def _compute_conv_output(self, shape: Tuple[int, int, int]):  
         with torch.no_grad():  
             x = torch.randn(1, shape[0], shape[1], shape[2])  
             x = self.conv1(x)  
             x = self.conv2(x)  
             x = self.conv3(x)  
             x = self.conv4(x)  
   
             return x.shape[1] * x.shape[2] * x.shape[3]  
 
     def forward(self, x):  
         x = torch.relu(self.conv1(x))  
         x = self.bn1(x)  
         x = torch.relu(self.conv2(x))  
         x = self.bn2(x)  
   
         x = torch.relu(self.conv3(x))  
         x = self.bn3(x)  
   
         x = self.conv4(x)  
         x = self.bn4(x)  
   
         x = x.view(x.size(0), -1)  
         x = self.fc1(x)  
   
         return x

解码器实现

解码器遵循传统自编码器架构设计,其功能是将编码后的观察结果重建回原始观察空间。

 class DecoderCNN(nn.Module):  
     def __init__(self, hidden_size: int, state_size: int,  embedding_size: int,  
                  use_bn: bool = True, output_shape: Tuple[int, int] = (3, 128, 128)):  
         super(DecoderCNN, self).__init__()  
   
         self.output_shape = output_shape  
   
         self.embedding_size = embedding_size  
         # 全连接层进行特征变换
         self.fc1 = nn.Linear(hidden_size + state_size, embedding_size)  
         self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16))  
   
         # 反卷积层进行上采样
         self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)  
   
         # 批标准化层
         self.bn1 = nn.BatchNorm2d(128)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(32)  
   
         self.use_bn = use_bn  
 
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = self.fc1(x)  
         x = torch.relu(x)  
         x = self.fc2(x)  
   
         x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16)  
   
         if self.use_bn:  
             x = torch.relu(self.bn1(self.conv1(x)))  
             x = torch.relu(self.bn2(self.conv2(x)))  
             x = torch.relu(self.bn3(self.conv3(x)))  
   
         else:  
             x = torch.relu(self.conv1(x))  
             x = torch.relu(self.conv2(x))  
             x = torch.relu(self.conv3(x))  
   
         x = self.conv4(x)  
   
         return x    

奖励模型实现

奖励模型采用了一个三层前馈神经网络结构,用于将随机状态 s 和确定性状态 h 映射到正态分布参数,进而通过采样获得奖励预测。

 class RewardModel(nn.Module):  
     def __init__(self, hidden_dim: int, state_dim: int):  
         super(RewardModel, self).__init__()  
   
         self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim)  
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)  
         self.fc3 = nn.Linear(hidden_dim, 2)  
   
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = torch.relu(self.fc1(x))  
         x = torch.relu(self.fc2(x))  
         x = self.fc3(x)  
   
         return x

动态模型的实现

动态模型是 RSSM 架构中最复杂的组件,需要同时处理先验和后验状态转移模型:

  1. 后验转移模型:在能够访问真实观察的情况下使用(主要在训练阶段),用于在给定观察和历史状态的条件下近似随机状态的后验分布。

  2. 先验转移模型:用于近似先验状态分布,仅依赖于前一时刻状态,不依赖于观察。这在无法获取后验观察的推理阶段使用。

这两个模型均通过单层前馈网络进行参数化,输出各自正态分布的均值和对数方差,用于状态 s 的采样。该实现采用了简单的网络结构,但可以根据需要扩展为更复杂的架构。

确定性状态采用门控循环单元(GRU)实现。其输入包括:

  • 前一时刻的隐藏状态

  • 独热编码动作

  • 前一时刻随机状态 s(根据是否可以获取观察来选择使用后验或先验状态)

这些输入信息足以让模型了解动作历史和系统状态。以下是具体实现代码:

 class DynamicsModel(nn.Module):  
     def __init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int = 1):  
         super(DynamicsModel, self).__init__()  
   
         self.hidden_dim = hidden_dim  
           
         # 递归层实现,支持多层 GRU
         self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)])  
           
         # 状态动作投影层
         self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim)  
           
         # 先验网络:输出正态分布参数
         self.prior = nn.Linear(hidden_dim, state_dim * 2)  
         self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim)  
           
         # 后验网络:输出正态分布参数
         self.posterior = nn.Linear(hidden_dim, state_dim * 2)  
         self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim)  
   
         self.state_dim = state_dim  
         self.act_fn = nn.ReLU()  
   
     def forward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor,  
                 obs: torch.Tensor = None, dones: torch.Tensor = None):  
         """  
        动态模型的前向传播
        参数:  
            prev_hidden: RNN的前一隐藏状态,形状 (batch_size, hidden_dim)  
            prev_state: 前一随机状态,形状 (batch_size, state_dim)  
            actions: 独热编码动作序列,形状 (sequence_length, batch_size, action_dim)  
            obs: 编码器输出的观察嵌入,形状 (sequence_length, batch_size, embedding_dim)  
            dones: 终止状态标志
        """  
         B, T, _ = actions.size()  # 用于无观察访问时的推理
   
         # 初始化存储列表
         hiddens_list = []  
         posterior_means_list = []  
         posterior_logvars_list = []  
         prior_means_list = []  
         prior_logvars_list = []  
         prior_states_list = []  
         posterior_states_list = []  
           
         # 存储初始状态
         hiddens_list.append(prev_hidden.unsqueeze(1))    
         prior_states_list.append(prev_state.unsqueeze(1))  
         posterior_states_list.append(prev_state.unsqueeze(1))  
   
         # 时序展开
         for t in range(T - 1):  
             # 提取当前时刻状态和动作
             action_t = actions[:, t, :]  
             obs_t = obs[:, t, :] if obs is not None else torch.zeros(B, self.embedding_dim, device=actions.device)  
             state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :]  
             state_t = state_t if dones is None else state_t * (1 - dones[:, t, :])  
             hidden_t = hiddens_list[-1][:, 0, :]  
               
             # 状态动作组合
             state_action = torch.cat([state_t, action_t], dim=-1)  
             state_action = self.act_fn(self.project_state_action(state_action))  
   
             # RNN 状态更新
             for i in range(len(self.rnn)):  
                 hidden_t = self.rnn[i](state_action, hidden_t)  
   
             # 先验分布计算
             hidden_action = torch.cat([hidden_t, action_t], dim=-1)  
             hidden_action = self.act_fn(self.project_hidden_action(hidden_action))  
             prior_params = self.prior(hidden_action)  
             prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1)  
   
             # 从先验分布采样
             prior_dist = torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar)))  
             prior_state_t = prior_dist.rsample()  
   
             # 后验分布计算
             if obs is None:  
                 posterior_mean = prior_mean  
                 posterior_logvar = prior_logvar  
             else:  
                 hidden_obs = torch.cat([hidden_t, obs_t], dim=-1)  
                 hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs))  
                 posterior_params = self.posterior(hidden_obs)  
                 posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1)  
   
             # 从后验分布采样
             posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar)))  
             posterior_state_t = posterior_dist.rsample()  
   
             # 保存状态
             posterior_means_list.append(posterior_mean.unsqueeze(1))  
             posterior_logvars_list.append(posterior_logvar.unsqueeze(1))  
             prior_means_list.append(prior_mean.unsqueeze(1))  
             prior_logvars_list.append(prior_logvar.unsqueeze(1))  
             prior_states_list.append(prior_state_t.unsqueeze(1))  
             posterior_states_list.append(posterior_state_t.unsqueeze(1))  
             hiddens_list.append(hidden_t.unsqueeze(1))  
   
         # 合并时序数据
         hiddens = torch.cat(hiddens_list, dim=1)  
         prior_states = torch.cat(prior_states_list, dim=1)  
         posterior_states = torch.cat(posterior_states_list, dim=1)  
         prior_means = torch.cat(prior_means_list, dim=1)  
         prior_logvars = torch.cat(prior_logvars_list, dim=1)  
         posterior_means = torch.cat(posterior_means_list, dim=1)  
         posterior_logvars = torch.cat(posterior_logvars_list, dim=1)  
   
         return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

需要特别注意的是,这里的观察输入并非原始观察数据,而是经过编码器处理后的嵌入表示。这种设计能够有效降低计算复杂度并提升模型的泛化能力。

RSSM 整体架构

将前述组件整合为完整的 RSSM 模型。其核心是 generate_rollout 方法,负责调用动态模型并生成环境动态的潜在表示序列。对于没有历史潜在状态的情况(通常发生在轨迹开始时),该方法会进行必要的初始化。下面是完整的实现代码:

 class RSSM:  
     def __init__(self,  
                  encoder: EncoderCNN,  
                  decoder: DecoderCNN,  
                  reward_model: RewardModel,  
                  dynamics_model: nn.Module,  
                  hidden_dim: int,  
                  state_dim: int,  
                  action_dim: int,  
                  embedding_dim: int,  
                  device: str = "mps"):  
         """  
        循环状态空间模型(RSSM)实现
         
        参数:
            encoder: 确定性状态编码器
            decoder: 观察重构解码器
            reward_model: 奖励预测模型
            dynamics_model: 状态动态模型
            hidden_dim: RNN 隐藏层维度
            state_dim: 随机状态维度
            action_dim: 动作空间维度
            embedding_dim: 观察嵌入维度
            device: 计算设备
        """  
         super(RSSM, self).__init__()  
   
         # 模型组件初始化
         self.dynamics = dynamics_model  
         self.encoder = encoder  
         self.decoder = decoder  
         self.reward_model = reward_model  
   
         # 维度参数存储
         self.hidden_dim = hidden_dim  
         self.state_dim = state_dim  
         self.action_dim = action_dim  
         self.embedding_dim = embedding_dim  
   
         # 模型迁移至指定设备
         self.dynamics.to(device)  
         self.encoder.to(device)  
         self.decoder.to(device)  
         self.reward_model.to(device)  
   
     def generate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor = None, states: torch.Tensor = None,  
                          obs: torch.Tensor = None, dones: torch.Tensor = None):  
         """
        生成状态序列展开
         
        参数:
            actions: 动作序列
            hiddens: 初始隐藏状态(可选)
            states: 初始随机状态(可选)
            obs: 观察序列(可选)
            dones: 终止标志序列
             
        返回:
            完整的状态展开序列
        """
         # 状态初始化
         if hiddens is None:  
             hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(actions.device)  
   
         if states is None:  
             states = torch.zeros(actions.size(0), self.state_dim).to(actions.device)  
   
         # 执行动态模型展开
         dynamics_result = self.dynamics(hiddens, states, actions, obs, dones)  
         hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result  
   
         return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars  
   
     def train(self):  
         """启用训练模式"""
         self.dynamics.train()  
         self.encoder.train()  
         self.decoder.train()  
         self.reward_model.train()  
   
     def eval(self):  
         """启用评估模式"""
         self.dynamics.eval()  
         self.encoder.eval()  
         self.decoder.eval()  
         self.reward_model.eval()  
   
     def encode(self, obs: torch.Tensor):  
         """观察编码"""
         return self.encoder(obs)  
   
     def decode(self, state: torch.Tensor):  
         """状态解码为观察"""
         return self.decoder(state)  
   
     def predict_reward(self, h: torch.Tensor, s: torch.Tensor):  
         """奖励预测"""
         return self.reward_model(h, s)  
   
     def parameters(self):  
         """返回所有可训练参数"""
         return list(self.dynamics.parameters()) + list(self.encoder.parameters()) + \
                list(self.decoder.parameters()) + list(self.reward_model.parameters())  
   
     def save(self, path: str):  
         """模型状态保存"""
         torch.save({  
             "dynamics": self.dynamics.state_dict(),  
             "encoder": self.encoder.state_dict(),  
             "decoder": self.decoder.state_dict(),  
             "reward_model": self.reward_model.state_dict()  
        }, path)  
   
     def load(self, path: str):  
         """模型状态加载"""
         checkpoint = torch.load(path)  
         self.dynamics.load_state_dict(checkpoint["dynamics"])  
         self.encoder.load_state_dict(checkpoint["encoder"])  
         self.decoder.load_state_dict(checkpoint["decoder"])  
         self.reward_model.load_state_dict(checkpoint["reward_model"])

这个实现提供了一个完整的 RSSM 框架,包含了模型的训练、评估、状态保存和加载等基本功能。该框架可以作为基础结构,根据具体应用场景进行扩展和优化。

训练系统设计

RSSM 的训练系统主要包含两个核心组件:经验回放缓冲区(Experience Replay Buffer)和智能体(Agent)。其中,缓冲区负责存储历史经验数据用于训练,而智能体则作为环境与 RSSM 之间的接口,实现数据收集策略。

经验回放缓冲区实现

缓冲区采用循环队列结构,用于存储和管理观察、动作、奖励和终止状态等数据。通过 sample 方法可以随机采样训练序列。

 class Buffer:  
     def __init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device):  
         """
        经验回放缓冲区初始化
         
        参数:
            buffer_size: 缓冲区容量
            obs_shape: 观察数据维度
            action_shape: 动作数据维度
            device: 计算设备
        """
         self.buffer_size = buffer_size  
         self.obs_buffer = np.zeros((buffer_size, *obs_shape), dtype=np.float32)  
         self.action_buffer = np.zeros((buffer_size, *action_shape), dtype=np.int32)  
         self.reward_buffer = np.zeros((buffer_size, 1), dtype=np.float32)  
         self.done_buffer = np.zeros((buffer_size, 1), dtype=np.bool_)  
   
         self.device = device  
         self.idx = 0  
   
     def add(self, obs: torch.Tensor, action: int, reward: float, done: bool):  
         """
        添加单步经验数据
        """
         self.obs_buffer[self.idx] = obs  
         self.action_buffer[self.idx] = action  
         self.reward_buffer[self.idx] = reward  
         self.done_buffer[self.idx] = done  
         self.idx = (self.idx + 1) % self.buffer_size  
 
     def sample(self, batch_size: int, sequence_length: int):  
         """
        随机采样经验序列
         
        参数:
            batch_size: 批量大小
            sequence_length: 序列长度
             
        返回:
            经验数据元组 (observations, actions, rewards, dones)
        """
         # 随机选择序列起始位置
         starting_idxs = np.random.randint(0, (self.idx % self.buffer_size) - sequence_length, (batch_size,))  
         
         # 构建完整序列索引
         index_tensor = np.stack([np.arange(start, start + sequence_length) for start in starting_idxs])  
         
         # 提取数据序列
         obs_sequence = self.obs_buffer[index_tensor]  
         action_sequence = self.action_buffer[index_tensor]  
         reward_sequence = self.reward_buffer[index_tensor]  
         done_sequence = self.done_buffer[index_tensor]  
   
         return obs_sequence, action_sequence, reward_sequence, done_sequence  
 
     def save(self, path: str):  
         """保存缓冲区数据"""
         np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer,  
                  reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx)  
   
     def load(self, path: str):  
         """加载缓冲区数据"""
         data = np.load(path)  
         self.obs_buffer = data["obs_buffer"]  
         self.action_buffer = data["action_buffer"]  
         self.reward_buffer = data["reward_buffer"]  
         self.done_buffer = data["done_buffer"]  
         self.idx = data["idx"]

智能体设计

智能体实现了数据收集和规划功能。当前实现采用了简单的随机策略进行数据收集,但该框架支持扩展更复杂的策略。

 class Policy(ABC):  
     """策略基类"""
     @abstractmethod  
     def __call__(self, obs):  
         pass  
   
 class RandomPolicy(Policy):  
     """随机采样策略"""
     def __init__(self, env: Env):  
         self.env = env  
   
     def __call__(self, obs):  
         return self.env.action_space.sample()  
 
 class Agent:  
     def __init__(self, env: Env, rssm: RSSM, buffer_size: int = 100000,
                  collection_policy: str = "random", device="mps"):  
         """
        智能体初始化
         
        参数:
            env: 环境实例
            rssm: RSSM模型实例
            buffer_size: 经验缓冲区大小
            collection_policy: 数据收集策略类型
            device: 计算设备
        """
         self.env = env  
         # 策略选择
         match collection_policy:  
             case "random":  
                 self.rollout_policy = RandomPolicy(env)  
             case _:  
                 raise ValueError("Invalid rollout policy")  
   
         self.buffer = Buffer(buffer_size, env.observation_space.shape,
                            env.action_space.shape, device=device)  
         self.rssm = rssm  
   
     def data_collection_action(self, obs):  
         """执行数据收集动作"""
         return self.rollout_policy(obs)  
   
     def collect_data(self, num_steps: int):  
         """
        收集训练数据
         
        参数:
            num_steps: 收集步数
        """
         obs = self.env.reset()  
         done = False  
   
         iterator = tqdm(range(num_steps), desc="Data Collection")  
         for _ in iterator:  
             action = self.data_collection_action(obs)  
             next_obs, reward, done, _, _ = self.env.step(action)  
             self.buffer.add(next_obs, action, reward, done)  
             obs = next_obs  
             if done:  
                 obs = self.env.reset()  
   
     def imagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor,
                        actions: torch.Tensor):  
         """
        执行想象展开
         
        参数:
            prev_hidden: 前一隐藏状态
            prev_state: 前一随机状态
            actions: 动作序列
             
        返回:
            完整的模型输出,包括隐藏状态、先验状态、后验状态等
        """
         hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
         posterior_means, posterior_logvars = self.rssm.generate_rollout(
             actions, prev_hidden, prev_state)  
   
         # 在想象阶段使用先验状态预测奖励
         rewards = self.rssm.predict_reward(hiddens, prior_states)  
   
         return hiddens, prior_states, posterior_states, prior_means, \
                prior_logvars, posterior_means, posterior_logvars, rewards  
   
     def plan(self, num_steps: int, prev_hidden: torch.Tensor,
              prev_state: torch.Tensor, actions: torch.Tensor):  
         """
        执行规划
         
        参数:
            num_steps: 规划步数
            prev_hidden: 初始隐藏状态
            prev_state: 初始随机状态
            actions: 动作序列
             
        返回:
            规划得到的隐藏状态和先验状态序列
        """
         hidden_states = []  
         prior_states = []  
   
         hiddens = prev_hidden  
         states = prev_state  
   
         for _ in range(num_steps):  
             hiddens, states, _, _, _, _, _, _ = self.imagine_rollout(
                 hiddens, states, actions)  
             hidden_states.append(hiddens)  
             prior_states.append(states)  
   
         hidden_states = torch.stack(hidden_states)  
         prior_states = torch.stack(prior_states)  
   
         return hidden_states, prior_states

这部分实现提供了完整的数据管理和智能体交互框架。通过经验回放缓冲区,可以高效地存储和重用历史数据;通过智能体的抽象策略接口,可以方便地扩展不同的数据收集策略。同时智能体还实现了基于模型的想象展开和规划功能,为后续的决策制定提供了基础。

训练器实现与实验

训练器设计

训练器是 RSSM 实现中的最后一个关键组件,负责协调模型训练过程。训练器接收 RSSM 模型、智能体、优化器等组件,并实现具体的训练逻辑。

总结

本文详细介绍了基于 PyTorch 实现 RSSM 的完整过程。RSSM 的架构相比传统的 VAE 或 RNN 更为复杂,这主要源于其混合了随机和确定性状态的特性。通过手动实现这一架构,我们可以深入理解其背后的理论基础及其强大之处。RSSM 能够递归地生成未来潜在状态轨迹,这为智能体的行为规划提供了基础。

实现的优点在于其计算负载适中,可以在单个消费级 GPU 上进行训练,在有充足时间的情况下甚至可以在 CPU 上运行。这一工作基于论文《Learning Latent Dynamics for Planning from Pixels》,该论文为 RSSM 类动态模型奠定了基础。后续的研究工作如《Dream to Control: Learning Behaviors by Latent Imagination》进一步发展了这一架构。这些改进的架构将在未来的研究中深入探讨,因为它们对理解 MBRL 方法提供了重要的见解。

作者:Lukas Bierling


喜欢就关注一下吧:

点个 在看 你最好看!