2024年ICML文章,ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization精讀
該論文是清華項目組組內(nèi)博士師兄寫的文章,項目主頁為 ACE (ace-rl.github.io) ,于2024年7月發(fā)表在ICML期刊
因為最近組內(nèi)(其實只有我)需要從零開始做一個相關項目,前面的幾篇文章都是鋪墊
本文章為強化學習筆記第5篇
本文初編輯于2024.10.5,好像是這個時間,忘記了,前后寫了兩個多星期
CSDN主頁: https://blog.csdn.net/rvdgdsva
博客園主頁: https://www.cnblogs.com/hassle
博客園本文鏈接:
這篇強化學習論文主要介紹了一個名為 ACE 的算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization ,它通過引入因果關系分析和因果熵正則化來解決現(xiàn)有模型在不同動作維度上的不平等探索問題,旨在改進強化學習【注釋1】中探索效率和樣本效率的問題,特別是在高維度連續(xù)控制任務中的表現(xiàn)。
【注釋1】: 強化學習入門這一篇就夠了
在policy【注釋2】學習過程中,不同原始行為的不同意義被先前的model-free RL 算法所忽視。利用這一見解,我們探索了不同行動維度和獎勵之間的因果關系,以評估訓練過程中各種原始行為的重要性。我們引入了一個因果關系感知熵【注釋3】項(causality-aware entropy term),它可以有效地識別并優(yōu)先考慮具有高潛在影響的行為,以實現(xiàn)高效的探索。此外,為了防止過度關注特定的原始行為,我們分析了梯度休眠現(xiàn)象(gradientdormancyphenomenon),并引入了休眠引導的重置機制,以進一步增強我們方法的有效性。與無模型RL基線相比,我們提出的算法 ACE :Off-policy A ctor-criticwith C ausality-aware E ntropyregularization。在跨越7個域的29種不同連續(xù)控制任務中顯示出實質(zhì)性的性能優(yōu)勢,這強調(diào)了我們方法的有效性、多功能性和高效的樣本效率。 基準測試結果和視頻可在https://ace-rl.github.io/上獲得。
【注釋2】: 強化學習算法中on-policy和off-policy
【注釋3】: 最大熵 RL:從Soft Q-Learning到SAC - 知乎
【1】 因果關系分析 :通過引入因果政策-獎勵結構模型,評估不同動作維度(即原始行為)對獎勵的影響大。ǚQ為“因果權重”)。這些權重反映了每個動作維度在不同學習階段的相對重要性。
作出上述改進的原因是:考慮一個簡單的例子,一個機械手最初應該學習放下手臂并抓住物體,然后將注意力轉移到學習手臂朝著最終目標的運動方向上。因此,在策略學習的不同階段強調(diào)對最重要的原始行為的探索是 至關重要的。在探索過程中刻意關注各種原始行為,可以加速智能體在每個階段對基本原始行為的學習,從而提高掌握完整運動任務的效率。
此處可供學習的資料:
【2】 因果熵正則化 :在最大熵強化學習框架的基礎上(如SAC算法),加入了 因果加權的熵正則化項 。與傳統(tǒng)熵正則化不同,這一項根據(jù)各個原始行為的因果權重動態(tài)調(diào)整,強化對重要行為的探索,減少對不重要行為的探索。
作出上述改進的原因是:論文引入了一個因果策略-獎勵結構模型來計算行動空間上的因果權重(causal weights),因果權重會引導agent進行更有效的探索, 鼓勵對因果權重較大的動作維度進行探索,表明對獎勵的重要性更大,并減少對因果權重較小的行為維度的探 索。一般的最大熵目標缺乏對不同學習階段原始行為之間區(qū)別的重要性的認識,可能導致低效的探索。為了解決這一限制,論文引入了一個由因果權重加權的策略熵作為因果關系感知的熵最大化目標,有效地加強了對重要原始行為的探索,并導致了更有效的探索。
此處可供學習的資料:
【3】 梯度“休眠”現(xiàn)象(Gradient Dormancy) :論文觀察到,模型訓練時有些梯度會在某些階段不活躍(即“休眠”)。為了防止模型過度關注某些原始行為,論文引入了 梯度休眠導向的重置機制 。該機制通過周期性地對模型進行擾動(reset),避免模型陷入局部最優(yōu),促進更廣泛的探索。
作出上述改進的原因是:該機制通過一個由梯度休眠程度決定的因素間歇性地干擾智能體的神經(jīng)網(wǎng)絡。將因果關系感知探索與這種新穎的重置機制相結合,旨在促進更高效、更有效的探索,最終提高智能體的整體性能。
通過在多個連續(xù)控制任務中的實驗,ACE 展示出了顯著優(yōu)于主流強化學習算法(如SAC、TD3)的表現(xiàn):
論文中的對比實驗圖表顯示了 ACE 在多種任務下的顯著優(yōu)勢,尤其是在 稀疏獎勵和高維度任務 中,ACE 憑借其探索效率的提升,能更快達到最優(yōu)策略。
在ACE原論文的第21頁,這玩意兒應該寫在正篇的,害的我看了好久的代碼去排流程
不過說實話這偽代碼有夠簡潔的,代碼多少有點糊成一坨了
這是一個強化學習(RL)算法的框架,具體是一個結合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對每個模塊及其參數(shù)的說明:
源代碼上千行呢,這里只是貼上main_casual里面的部分代碼,并且刪掉了很大一部分代碼以便理清程序脈絡
def train_loop(config, msg = "default"):
# Agent
agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)
memory = ReplayMemory(config.replay_size, config.seed)
local_buffer = ReplayMemory(config.causal_sample_size, config.seed)
for i_episode in itertools.count(1):
done = False
state = env.reset()
while not done:
if config.start_steps > total_numsteps:
action = env.action_space.sample() # Sample random action
else:
action = agent.select_action(state) # Sample action from policy
if len(memory) > config.batch_size:
for i in range(config.updates_per_step):
#* Update parameters of causal weight
if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
print("Current Causal Weight is: ",causal_weight)
dormant_metrics = {}
# Update parameters of all the networks
critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)
updates += 1
next_state, reward, done, info = env.step(action) # Step
total_numsteps += 1
episode_steps += 1
episode_reward += reward
#* Ignore the "done" signal if it comes from hitting the time horizon.
if '_max_episode_steps' in dir(env):
mask = 1 if episode_steps == env._max_episode_steps else float(not done)
elif 'max_path_length' in dir(env):
mask = 1 if episode_steps == env.max_path_length else float(not done)
else:
mask = 1 if episode_steps == 1000 else float(not done)
memory.push(state, action, reward, next_state, mask) # Append transition to memory
local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
state = next_state
if total_numsteps > config.num_steps:
break
# test agent
if i_episode % config.eval_interval == 0 and config.eval is True:
eval_reward_list = []
for _ in range(config.eval_episodes):
state = env.reset()
episode_reward = []
done = False
while not done:
action = agent.select_action(state, evaluate=True)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward.append(reward)
eval_reward_list.append(sum(episode_reward))
avg_reward = np.average(eval_reward_list)
env.close()
初始化 :
config
設置環(huán)境和隨機種子。
ACE_agent
初始化強化學習智能體,該智能體會在后續(xù)過程中學習如何在環(huán)境中行動。
memory
用于存儲所有的歷史數(shù)據(jù),
local_buffer
則用于因果權重的更新。
主訓練循環(huán) :
采樣動作 :如果總步數(shù)較小,則從環(huán)境中隨機采樣動作,否則從策略中選擇動作。通過這種方式,確保早期探索和后期利用。
更新因果權重
:在特定間隔內(nèi),從局部緩沖區(qū)中采樣數(shù)據(jù),通過
get_sa2r_weight
函數(shù)使用DirectLiNGAM算法計算從動作到獎勵的因果權重。這個權重會作為額外信息,幫助智能體優(yōu)化策略。
更新網(wǎng)絡參數(shù)
:當
memory
中的數(shù)據(jù)足夠多時,開始通過采樣更新Q網(wǎng)絡和策略網(wǎng)絡,使用計算出的因果權重來修正損失函數(shù)。
記錄與保存模型 :每隔一定的步數(shù),算法會測試當前策略的性能,記錄并比較獎勵是否超過歷史最佳值,如果是,則保存模型的檢查點。
使用
wandb
記錄訓練過程中的指標,例如損失函數(shù)、獎勵和因果權重的計算時間,這些信息可以幫助調(diào)試和分析訓練過程。
因果發(fā)現(xiàn)模塊
主要通過
get_sa2r_weight
函數(shù)實現(xiàn),并且與
DirectLiNGAM
模型結合,負責計算因果權重。具體代碼在訓練循環(huán)中如下:
causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
在這個代碼段,
get_sa2r_weight
函數(shù)會基于當前環(huán)境、樣本數(shù)據(jù)(
local_buffer
)和因果模型(這里使用的是
DirectLiNGAM
),計算與行動相關的因果權重(
causal_weight
)。這些權重會影響后續(xù)的策略優(yōu)化和參數(shù)更新。關鍵邏輯包括:
total_numsteps % config.causal_sample_interval == 0
時觸發(fā),確保只在指定的步數(shù)間隔內(nèi)計算因果權重,避免每一步都進行因果計算,減輕計算負擔。
local_buffer
中存儲了足夠的樣本(
config.causal_sample_size
),這些樣本用于因果關系的發(fā)現(xiàn)。
DirectLiNGAM
是選擇的因果模型,用于從狀態(tài)、行動和獎勵之間推導出因果關系。
因果權重計算完成后,程序會將這些權重應用到策略優(yōu)化中,并且記錄權重及計算時間等信息。
def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
······
return weight, model._running_time
這個代碼的核心是利用DirectLiNGAM模型計算給定狀態(tài)、動作和獎勵之間的因果權重。接下來,用LaTeX公式詳細表述計算因果權重的過程:
數(shù)據(jù)預處理
:
將從
memory
中采樣的
states
(狀態(tài))、
actions
(動作)和
rewards
(獎勵)進行拼接,構建輸入數(shù)據(jù)矩陣
\(X_{\text{ori}}\)
:
其中, \(S\) 代表狀態(tài), \(A\) 代表動作, \(R\) 代表獎勵。接著,構建數(shù)據(jù)框 \(X\) 來進行因果分析。
因果模型擬合 :
將
X_ori
轉換為
X
是為了利用
pandas
數(shù)據(jù)框的便利性和靈活性
使用 DirectLiNGAM 模型對矩陣 \(X\) 進行擬合,得到因果關系的鄰接矩陣 \(A_{\text{model}}\) :
該鄰接矩陣表示狀態(tài)、動作、獎勵之間的因果結構,特別是從動作到獎勵的影響關系。
提取動作對獎勵的因果權重
:
通過鄰接矩陣提取動作對獎勵的因果權重
\(w_{\text{r}}\)
,該權重從鄰接矩陣的最后一行中選擇與動作對應的元素:
其中, \(d_s\) 是狀態(tài)的維度, \(d_a\) 是動作的維度。
因果權重的歸一化
:
對因果權重
\(w_{\text{r}}\)
進行Softmax歸一化,確保它們的總和為1:
調(diào)整權重的尺度
:
最后,因果權重根據(jù)動作的數(shù)量進行縮放:
最終輸出的權重 \(w\) 表示每個動作對獎勵的因果影響,經(jīng)過歸一化和縮放處理,可以用于進一步的策略調(diào)整或分析。
以下是對函數(shù)工作原理的逐步解釋:
策略優(yōu)化模塊
主要由
agent.update_parameters
函數(shù)實現(xiàn)。
agent.update_parameters
這個函數(shù)的主要目的是在強化學習中更新策略 (
policy
) 和價值網(wǎng)絡(critic)的參數(shù),以提升智能體的性能。這個函數(shù)實現(xiàn)了一個基于軟演員評論家(SAC, Soft Actor-Critic)的更新機制,并且加入了因果權重與"休眠"神經(jīng)元(dormant neurons)的處理,以提高模型的魯棒性和穩(wěn)定性。
critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)
通過
agent.update_parameters
函數(shù),程序會更新以下幾個部分:
critic_1_loss
和
critic_2_loss
分別是兩個 Critic 網(wǎng)絡的損失,用于評估當前策略的價值。
policy_loss
表示策略網(wǎng)絡的損失,用于優(yōu)化 agent 的行動選擇。
ent_loss
用來調(diào)節(jié)策略的隨機性,幫助 agent 在探索和利用之間找到平衡。
這些參數(shù)的更新在每次訓練循環(huán)中被調(diào)用,并使用
wandb.log
記錄損失和其他相關的訓練數(shù)據(jù)。
update_parameters
是
ACE_agent
類中的一個關鍵函數(shù),用于根據(jù)經(jīng)驗回放緩沖區(qū)中的樣本數(shù)據(jù)來更新模型的參數(shù)。下面是對其工作原理的詳細解釋:
首先,函數(shù)從
memory
中采樣一批樣本(
state_batch
、
action_batch
、
reward_batch
、
next_state_batch
、
mask_batch
),其中包括狀態(tài)、動作、獎勵、下一個狀態(tài)以及掩碼,用于表示是否為終止狀態(tài)。
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
state_batch
:當前的狀態(tài)。
action_batch
:在當前狀態(tài)下執(zhí)行的動作。
reward_batch
:執(zhí)行該動作后獲得的獎勵。
next_state_batch
:執(zhí)行動作后到達的下一個狀態(tài)。
mask_batch
:掩碼,用于表示是否為終止狀態(tài)(1 表示非終止,0 表示終止)。
利用當前策略(policy)網(wǎng)絡,采樣下一個狀態(tài)的動作
next_state_action
和其對應的概率分布對數(shù)
next_state_log_pi
。然后利用目標 Q 網(wǎng)絡
critic_target
估計下一時刻的最小 Q 值,并結合獎勵和折扣因子
\(\gamma\)
計算下一個 Q 值:
with torch.no_grad():
next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
通過策略網(wǎng)絡
self.policy
為下一個狀態(tài)
next_state_batch
采樣動作
next_state_action
和相應的策略熵
next_state_log_pi
。
使用目標 Q 網(wǎng)絡計算
qf1_next_target
和
qf2_next_target
,并取兩者的最小值來減少估計偏差。
最終使用貝爾曼方程計算
next_q_value
,即當前的獎勵加上折扣因子
\(\gamma\)
乘以下一個狀態(tài)的 Q 值。
這里,
\(\alpha\)
是熵項的權重,用于平衡探索和利用的權衡,而
mask_batch
是為了處理終止狀態(tài)的情況。
使用無偏估計來計算目標 Q 值。通過目標網(wǎng)絡 (
critic_target
) 計算出下一個狀態(tài)和動作的 Q 值,并使用獎勵和掩碼更新當前 Q 值
接著,使用當前 Q 網(wǎng)絡
critic
估計當前狀態(tài)和動作下的 Q 值
\(Q_1\)
和
\(Q_2\)
,并計算它們與目標 Q 值的均方誤差損失:
最終 Q 網(wǎng)絡的總損失是兩個 Q 網(wǎng)絡損失之和:
然后,通過反向傳播
qf_loss
來更新 Q 網(wǎng)絡的參數(shù)。
qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss
self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
qf1
和
qf2
是兩個 Q 網(wǎng)絡的輸出,用于減少正向估計偏差。
qf1_loss
和
qf2_loss
分別計算兩個 Q 網(wǎng)絡的誤差,最后將兩者相加為總的 Q 損失
qf_loss
。
self.critic_optim
優(yōu)化器對損失進行反向傳播和參數(shù)更新。
每隔若干步(通過
target_update_interval
控制),開始更新策略網(wǎng)絡
policy
。首先,重新采樣當前狀態(tài)下的策略
\(\pi(a|s)\)
,并計算 Q 值和熵權重下的策略損失:
這個損失通過反向傳播更新策略網(wǎng)絡。
if updates % self.target_update_interval == 0:
pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()
state_batch
進行采樣,得到動作
pi
及其對應的策略熵
log_pi
。
policy_loss
,即
\(\alpha\)
倍的策略熵減去最小的 Q 值。
self.policy_optim
優(yōu)化器對策略損失進行反向傳播和參數(shù)更新。
如果開啟了自動熵項調(diào)整(
automatic_entropy_tuning
),則會進一步更新熵項
\(\alpha\)
的損失:
并通過梯度下降更新 \(\alpha\) 。
如果
automatic_entropy_tuning
為真,則會更新熵項:
if self.automatic_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.exp()
alpha_tlogs = self.alpha.clone()
else:
alpha_loss = torch.tensor(0.).to(self.device)
alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
alpha_loss
更新
self.alpha
,調(diào)整策略的探索-利用平衡。
qf1_loss
,
qf2_loss
: 兩個 Q 網(wǎng)絡的損失
policy_loss
: 策略網(wǎng)絡的損失
alpha_loss
: 熵權重的損失
alpha_tlogs
: 用于日志記錄的熵權重
next_q_value
: 平均下一個 Q 值
dormant_metrics
: 休眠神經(jīng)元的相關度量
重置機制模塊在代碼中主要體現(xiàn)在
update_parameters
函數(shù)中,并通過
梯度主導度
(dominant metrics) 和
擾動函數(shù)
(perturbation functions) 實現(xiàn)對策略網(wǎng)絡和 Q 網(wǎng)絡的重置。
函數(shù)根據(jù)設定的
reset_interval
判斷是否需要對策略網(wǎng)絡和 Q 網(wǎng)絡進行擾動和重置。這里使用了"休眠"神經(jīng)元的概念,即一些在梯度更新中影響較小的神經(jīng)元,可能會被調(diào)整或重置。
函數(shù)計算了休眠度量
dormant_metrics
和因果權重差異
causal_diff
,通過擾動因子
perturb_factor
來決定是否對網(wǎng)絡進行部分或全部的擾動與重置。
重置機制主要由以下部分組成:
在更新策略時,計算
主導梯度
,即某些特定神經(jīng)元或參數(shù)在更新中主導作用的比率。代碼中通過調(diào)用
cal_dormant_grad(self.policy, type='policy', percentage=0.05)
實現(xiàn)這一計算,代表提取出 5% 的主導梯度來作為判斷因子。
dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)
根據(jù)主導度 ($ \beta_\gamma$ ) 和權重 ($ w$ ),可以得到因果效應的差異。代碼里用
causal_diff
來表示因果差異:
軟重置機制通過平滑更新策略網(wǎng)絡和 Q 網(wǎng)絡,避免過大的權重更新導致的網(wǎng)絡不穩(wěn)定。這在代碼中由
soft_update
實現(xiàn):
soft_update(self.critic_target, self.critic, self.tau)
具體來說,軟更新的公式為:
其中,( \(\tau\) ) 是一個較小的常數(shù),通常介于 ( [0, 1] ) 之間,確保目標網(wǎng)絡的更新是緩慢的,以提高學習的穩(wěn)定性。
每當經(jīng)過一定的重置間隔時,判斷是否需要擾動策略和 Q 網(wǎng)絡。通過調(diào)用
perturb()
和
dormant_perturb()
實現(xiàn)對網(wǎng)絡的擾動(perturbation)。擾動因子由梯度主導度和因果差異共同決定。
策略與 Q 網(wǎng)絡的擾動會在以下兩種情況下發(fā)生:
代碼中每當更新次數(shù)
updates
達到設定的重置間隔
self.reset_interval
,并且
updates > 5000
時,才會觸發(fā)策略與 Q 網(wǎng)絡的重置邏輯。這是為了確保擾動不是頻繁發(fā)生,而是在經(jīng)過一段較長的訓練時間后進行。
具體判斷條件:
if updates % self.reset_interval == 0 and updates > 5000:
在達到了重置間隔后,首先會計算
梯度主導度
或
因果效應的差異
。這可以通過計算因果差異
causal_diff
或梯度主導度
dormant_metrics['policy_grad_dormant_ratio']
來決定是否需要擾動。
梯度主導度
計算方式通過
cal_dormant_grad()
函數(shù)實現(xiàn),如果梯度主導度較低,意味著網(wǎng)絡中的某些神經(jīng)元更新幅度過小,則需要對網(wǎng)絡進行擾動。
因果效應差異
通過計算
causal_diff = np.max(causal_weight) - np.min(causal_weight)
得到,如果差異過大,則可能需要重置。
然后根據(jù)這些值通過擾動因子
factor
進行判斷:
factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])
如果擾動因子 ( \(\text{factor} < 1\) ),網(wǎng)絡會進行擾動:
if factor < 1:
if self.reset == 'reset' or self.reset == 'causal_reset':
perturb(self.policy, self.policy_optim, factor)
perturb(self.critic, self.critic_optim, factor)
perturb(self.critic_target, self.critic_optim, factor)
updates > 5000
)。
這兩種條件同時滿足時,策略和 Q 網(wǎng)絡將被擾動或重置。
在這段代碼中,
factor
是基于網(wǎng)絡中梯度主導度或者因果效應差異計算出來的擾動因子。擾動因子通過函數(shù)
perturb_factor()
進行計算,該函數(shù)會根據(jù)神經(jīng)元的梯度主導度(
dormant_ratio
)或因果效應差異(
causal_diff
)來調(diào)整
factor
的大小。
擾動因子
factor
的計算公式如下:
其中:
( \(\text{dormant\_ratio}\) ) 是網(wǎng)絡中梯度主導度,即表示有多少神經(jīng)元的梯度變化較。ɑ蛘呓咏悖幱凇靶菝摺睜顟B(tài)。
(
\(\text{min\_perturb\_factor}\)
) 是最小擾動因子值,代碼中設定為
0.2
。
(
\(\text{max\_perturb\_factor}\)
) 是最大擾動因子值,代碼中設定為
0.9
。
dormant_ratio :
dormant_ratio
越大,表示越多神經(jīng)元的梯度變化很小,說明網(wǎng)絡更新不充分,需要擾動。
max_perturb_factor :
min_perturb_factor :
在計算因果效應的部分,擾動因子
factor
還會根據(jù)因果效應差異
causal_diff
來調(diào)整。
causal_diff
是通過計算因果效應的最大值與最小值的差異來獲得的:
計算出的
causal_diff
會影響
causal_factor
,并進一步對
factor
進行調(diào)整:
最后,如果選擇了因果重置(
causal_reset
),擾動因子將使用因果差異計算出的
causal_factor
進行二次調(diào)整:
綜上所述,
factor
的最終值是由梯度主導度或因果效應差異來控制的,當休眠神經(jīng)元比例較大或因果效應差異較大時,
factor
會減小,導致網(wǎng)絡進行擾動。
這段代碼主要實現(xiàn)了在強化學習(RL)訓練過程中,定期評估智能體(agent)的性能,并在某些條件下保存最佳模型的檢查點。我們可以分段解釋該代碼:
if i_episode % config.eval_interval == 0 and config.eval is True:
這部分代碼用于判斷是否應該執(zhí)行智能體的評估。條件為:
i_episode % config.eval_interval == 0
:表示每隔
config.eval_interval
個訓練回合(
i_episode
是當前回合數(shù))進行一次評估。
config.eval is True
:確保
eval
設置為
True
,也就是說,評估功能開啟。
如果滿足這兩個條件,代碼將開始執(zhí)行評估操作。
eval_reward_list = []
用于存儲每個評估回合(episode)的累計獎勵,以便之后計算平均獎勵。
for _ in range(config.eval_episodes):
評估階段將運行多個回合(由
config.eval_episodes
指定的回合數(shù)),以獲得智能體的表現(xiàn)。
state = env.reset()
episode_reward = []
done = False
env.reset()
:重置環(huán)境,獲得初始狀態(tài)
state
。
episode_reward
:初始化一個列表,用于存儲當前回合中智能體獲得的所有獎勵。
done = False
:用
done
來跟蹤當前回合是否結束。
while not done:
action = agent.select_action(state, evaluate=True)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward.append(reward)
動作選擇
:
agent.select_action(state, evaluate=True)
在評估模式下根據(jù)當前狀態(tài)
state
選擇動作。
evaluate=True
表示該選擇是在評估模式下,通常意味著探索行為被關閉(即不進行隨機探索,而是選擇最優(yōu)動作)。
環(huán)境反饋
:
next_state, reward, done, info = env.step(action)
通過執(zhí)行動作
action
,環(huán)境返回下一個狀態(tài)
next_state
,當前獎勵
reward
,回合是否結束的標志
done
,以及附加信息
info
。
狀態(tài)更新
:當前狀態(tài)被更新為
next_state
,并將獲得的獎勵
reward
存儲在
episode_reward
列表中。
循環(huán)持續(xù),直到回合結束(即
done == True
)。
eval_reward_list.append(sum(episode_reward))
當前回合結束后,累計獎勵(
sum(episode_reward)
)被添加到
eval_reward_list
,用于后續(xù)計算平均獎勵。
avg_reward = np.average(eval_reward_list)
在所有評估回合結束后,計算
eval_reward_list
的平均值
avg_reward
。這是當前評估階段智能體的表現(xiàn)指標。
if config.save_checkpoint:
if avg_reward >= best_reward:
best_reward = avg_reward
agent.save_checkpoint(checkpoint_path, 'best')
config.save_checkpoint
為
True
,則表示需要檢查是否保存模型。
avg_reward
是否超過了之前的最佳獎勵
best_reward
,如果是,則更新
best_reward
,并保存當前模型的檢查點。
agent.save_checkpoint(checkpoint_path, 'best')
這行代碼會將智能體的狀態(tài)保存到指定的路徑
checkpoint_path
,并標記為
"best"
,表示這是性能最佳的模型。
咳咳,可以發(fā)現(xiàn)程序只記錄了 0~1000 的數(shù)據(jù),從 1001 開始的每一個數(shù)據(jù)都顯示報錯所以被舍棄掉了。
后面重新下載了github代碼包,發(fā)生了同樣的報錯信息
報錯信息是:你在 X+1 輪次中嘗試記載 X 輪次中的信息,所以這個數(shù)據(jù)被舍棄掉了
大概是主程序哪里有問題吧,我自己也沒調(diào) bug
不過這個項目結題了,主要負責這個項目的博士師兄也畢業(yè)了,也不好說些什么(雖然我有他微信),至少論文里面的模塊挺有用的。ㄊ謩踊
本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權,請發(fā)郵件[email protected]
湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2025 haote.com 好特網(wǎng)