1. 이 글에서 다루는 내용
- CartPole용 공통 Actor-Critic 네트워크 구조
- A3C 구현 코드 흐름 (멀티프로세스 + Global / Local 모델)
- A2C 구현 코드 흐름 (ParallelEnv + 동기식 업데이트)
- 각 부분이 어떤 역할을 하는지를 중심으로 설명합니다.
알고리즘 개념(A3C가 뭔지, A2C가 왜 나왔는지)은 블로그에 따로 다뤘으니,
여기서는 코드 레벨에서 어떻게 돌아가는지에 집중합니다.
2. 공통 Actor-Critic 네트워크 구조
A3C와 A2C 모두 CartPole 상태(4차원)를 입력으로 받는 같은 형태의 Actor-Critic 네트워크를 사용합니다.
class ActorCritic(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(4, 256)
self.fc_pi = nn.Linear(256, 2)
self.fc_v = nn.Linear(256, 1)
def pi(self, x, softmax_dim=0):
x = F.relu(self.fc1(x))
x = self.fc_pi(x)
return F.softmax(x, dim=softmax_dim)
def v(self, x):
x = F.relu(self.fc1(x))
return self.fc_v(x)
- 입력: CartPole 상태
[카트 위치, 속도, 막대 각도, 각속도]→ 길이 4 - 공통 은닉층(
fc1): 상태에서 의미 있는 표현을 추출 fc_pi(Actor):- 출력 차원 2 – 행동 2개(왼쪽, 오른쪽)에 대한 로짓
softmax를 통과하면 정책 확률 (\pi(a|s))가 됩니다.
fc_v(Critic):- 출력 차원 1 – 현재 상태 가치 (V(s))
이 네트워크를 A3C에서는 Global / Local 모델로, A2C에서는 하나의 공유 모델로 사용합니다.
3. A3C 구현 – Global / Local 모델과 멀티프로세싱
3.1 전체 구조 개요
if __name__ == "__main__": 부분을 보면 A3C의 전체 구조가 한눈에 보입니다.
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
global_model = ActorCritic()
global_model.share_memory()
processes = []
# 학습 프로세스 3개, 테스트 프로세스 1개
for rank in range(n_train_processes + 1):
if rank == 0:
p = mp.Process(target=test, args=(global_model,))
else:
p = mp.Process(target=train, args=(global_model, rank))
p.start()
processes.append(p)
for p in processes:
p.join()
print("process exitcode:", p.exitcode, flush=True)
mp.set_start_method("spawn", force=True):- 멀티프로세싱 시작 방식을 spawn 으로 설정 (새 프로세스를 깨끗하게 띄움)
global_model = ActorCritic(); global_model.share_memory():- 전역 모델을 하나 만들고, 공유 메모리에 올립니다.
- 여러 프로세스가 같은 모델 파라미터를 함께 쓰고 업데이트할 수 있게 하는 설정입니다.
rank == 0:- 테스트 전용 프로세스 (
test) 실행
- 테스트 전용 프로세스 (
rank >= 1:- 학습용 worker 프로세스 (
train) 실행
- 학습용 worker 프로세스 (
즉,
1개의 Global 모델 + N개의 학습 worker + 1개의 테스트 프로세스
가 동시에 돌아가는 구조입니다.
3.2 train(worker) 함수 – Local 모델과 rollout
def train(global_model, rank):
local_model = ActorCritic()
local_model.load_state_dict(global_model.state_dict())
optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)
env = gym.make("CartPole-v1")
- 각 worker는 시작할 때:
- Global 모델의 파라미터를 복사해서 Local 모델을 만듭니다.
- Optimizer는 Global 모델 파라미터를 대상으로 합니다.
에피소드 루프 안에서는 이렇게 동작합니다.
for n_epi in range(max_train_ep):
done = False
s, _ = env.reset()
while not done:
s_lst, a_lst, r_lst = [], [], []
for _ in range(update_interval):
prob = local_model.pi(torch.from_numpy(s).float())
m = Categorical(prob)
a = m.sample().item()
s_prime, r, terminated, truncated, info = env.step(a)
done = terminated or truncated
s_lst.append(s)
a_lst.append([a])
r_lst.append(r / 100.0)
s = s_prime
if done:
break
- Local 모델로 정책 확률
prob = π(a|s)를 구하고,Categorical(prob)에서 행동a를 샘플링합니다.
- 환경 한 스텝 진행 →
(s_prime, r, done) - 상태/행동/보상을
update_interval스텝만큼 모으거나,- 에피소드가 끝날 때까지 리스트에 쌓습니다.
3.3 TD 타깃과 Advantage 계산
s_final = torch.tensor(s_prime, dtype=torch.float)
R = 0.0 if done else local_model.v(s_final).item()
td_target_lst = []
for reward in r_lst[::-1]:
R = gamma * R + reward
td_target_lst.append([R])
td_target_lst.reverse()
s_batch = torch.from_numpy(np.array(s_lst)).float()
a_batch = torch.tensor(a_lst)
td_target = torch.tensor(td_target_lst, dtype=torch.float)
advantage = td_target - local_model.v(s_batch)
- 에피소드가 끝나지 않았다면 마지막 상태 가치
V(s_final)를 초기값으로 사용하고,- 보상 리스트를 뒤에서부터 돌면서 n-step TD 타깃을 계산합니다.
- 그 후,
\[
\text{advantage} = \text{td_target} - V(s)
\]
을 이용해 Advantage를 계산합니다.
3.4 손실 계산과 Global 모델 업데이트
pi = local_model.pi(s_batch, softmax_dim=1)
pi_a = pi.gather(1, a_batch)
value_loss = F.smooth_l1_loss(local_model.v(s_batch), td_target.detach())
policy_loss = -torch.log(pi_a) * advantage.detach()
loss = policy_loss + value_loss
local_model.zero_grad()
optimizer.zero_grad()
loss.mean().backward()
for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
global_param._grad = local_param.grad
optimizer.step()
local_model.load_state_dict(global_model.state_dict())
pi: 각 상태에서 [왼쪽, 오른쪽] 행동 확률pi_a: 실제 선택한 행동의 확률 (\pi(a|s))- Policy Loss:
-log(pi_a) * advantage- Advantage > 0 → 해당 행동 확률 ↑
- Advantage < 0 → 해당 행동 확률 ↓
- Value Loss: SmoothL1로 V(s)를 TD 타깃에 맞추도록 학습
핵심은 아래 부분입니다.
for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
global_param._grad = local_param.grad
- Local 모델에서 계산된 gradient를
- Global 모델의
_grad에 복사하고 optimizer.step()은 Global 모델만 업데이트합니다.
- Global 모델의
- 그 다음, Global 모델의 최신 파라미터를 다시 Local 모델에 복사해 동기화합니다.
이 과정을 여러 프로세스가 동시에 수행하는 것이 A3C 구현의 핵심입니다.
3.5 테스트 프로세스
def test(global_model):
env = gym.make("CartPole-v1")
score = 0.0
print_interval = 5
for n_epi in range(max_test_ep):
done = False
s, _ = env.reset()
while not done:
prob = global_model.pi(torch.from_numpy(s).float())
a = Categorical(prob).sample().item()
s_prime, r, terminated, truncated, info = env.step(a)
done = terminated or truncated
s = s_prime
score += r
if n_epi % print_interval == 0 and n_epi != 0:
print(f"[TEST] episode={n_epi}, avg score={score / print_interval:.1f}", flush=True)
score = 0.0
time.sleep(0.5)
- Global 모델이 현재 어느 정도 성능인지 주기적으로 출력하는 코드입니다.
- 학습과 테스트가 동시에 돌아가므로, 학습이 진행될수록 평균 점수가 올라가는지 확인할 수 있습니다.
주의: 멀티프로세싱 코드이기 때문에
Colab보다는 로컬(PyCharm, VSCode, 터미널) 환경에서 실행하는 것이 안정적입니다.
4. A2C 구현 – ParallelEnv와 동기식 업데이트
두 번째 절에서는 A2C 코드를 구현합니다.
핵심은 여러 환경을 동시에 돌리되, 업데이트는 한 번에 동기적으로 하는 것입니다.
4.1 ParallelEnv – 여러 환경을 한 번에 관리
먼저 각 worker 프로세스에서 실행할 worker()와, 이를 감싸는 ParallelEnv 클래스를 정의합니다.
def worker(worker_id, master_end, worker_end):
master_end.close()
env = gym.make("CartPole-v1")
obs, _ = env.reset(seed=worker_id)
while True:
cmd, data = worker_end.recv()
if cmd == "step":
obs, reward, terminated, truncated, info = env.step(int(data))
done = terminated or truncated
if done:
obs, _ = env.reset()
worker_end.send((obs, reward, done, info))
elif cmd == "reset":
obs, _ = env.reset()
worker_end.send(obs)
elif cmd == "close":
env.close()
worker_end.close()
break
elif cmd == "get_spaces":
worker_end.send((env.observation_space, env.action_space))
else:
raise NotImplementedError
- 각 worker는
step,reset,close같은 명령을 파이프로 받아 환경을 조작합니다.
이를 여러 개 모아서 관리하는 것이 ParallelEnv입니다.
class ParallelEnv:
def __init__(self, n_train_processes):
self.nenvs = n_train_processes
self.waiting = False
self.closed = False
self.workers = []
master_ends, worker_ends = zip(*[mp.Pipe() for _ in range(self.nenvs)])
self.master_ends = master_ends
self.worker_ends = worker_ends
for worker_id, (master_end, worker_end) in enumerate(zip(master_ends, worker_ends)):
p = mp.Process(target=worker, args=(worker_id, master_end, worker_end))
p.daemon = True
p.start()
self.workers.append(p)
for worker_end in worker_ends:
worker_end.close()
self.master_ends를 통해 여러 환경에 동시에 명령을 보낼 수 있습니다.
주요 메서드는 다음과 같습니다.
def reset(self):
for master_end in self.master_ends:
master_end.send(("reset", None))
return np.stack([master_end.recv() for master_end in self.master_ends]).astype(np.float32)
def step_async(self, actions):
for master_end, action in zip(self.master_ends, actions):
master_end.send(("step", int(action)))
self.waiting = True
def step_wait(self):
results = [master_end.recv() for master_end in self.master_ends]
self.waiting = False
obs, rews, dones, infos = zip(*results)
return (
np.stack(obs).astype(np.float32),
np.array(rews, dtype=np.float32),
np.array(dones, dtype=np.bool_),
infos,
)
def step(self, actions):
self.step_async(actions)
return self.step_wait()
reset():- 모든 환경을 한 번에 초기화
step(actions):actions배열을 받아 각 환경에 행동을 보내고,- 그 결과를 한 번에 모아서 반환
이제 envs = ParallelEnv(n_train_processes)로 여러 CartPole 환경을 동시에 다룰 수 있습니다.
4.2 A2C 학습 루프
def main():
envs = ParallelEnv(n_train_processes)
model = ActorCritic()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
step_idx = 0
s = envs.reset()
while step_idx < max_train_steps:
s_lst, a_lst, r_lst, mask_lst = [], [], [], []
for _ in range(update_interval):
with torch.no_grad():
prob = model.pi(torch.from_numpy(s).float(), softmax_dim=1)
a = Categorical(prob).sample().numpy()
s_prime, r, done, info = envs.step(a)
s_lst.append(s.copy())
a_lst.append(a.copy())
r_lst.append(r / 100.0)
mask_lst.append(1.0 - done.astype(np.float32))
s = s_prime
step_idx += 1
s: shape(nenvs, 4)인 상태 배열a: shape(nenvs,)인 행동 배열 – 각 환경마다 하나씩r,done도 각각nenvs크기의 배열로 들어옵니다.mask_lst:- done이면 0, 아니면 1 – TD 타깃 계산에서 종료 여부를 반영하는 용도
4.3 n-step TD Target 계산
def compute_target(v_final, r_lst, mask_lst):
G = v_final.reshape(-1)
td_target = []
for r, mask in zip(r_lst[::-1], mask_lst[::-1]):
G = r + gamma * G * mask
td_target.append(G)
td_target.reverse()
return torch.tensor(np.array(td_target), dtype=torch.float32)
v_final: 마지막 상태들의 가치 (V(s_{T})) – shape(nenvs, 1)r_lst,mask_lst: 길이update_interval인 리스트- 각 원소는
(nenvs,)모양의 보상/마스크 배열
- 각 원소는
- 뒤에서부터 거꾸로 돌면서
- n-step Return을 계산해
td_target을 만듭니다.
- n-step Return을 계산해
학습 루프에서는 이렇게 사용합니다.
s_final = torch.from_numpy(s_prime).float()
with torch.no_grad():
v_final = model.v(s_final).cpu().numpy()
td_target = compute_target(v_final, r_lst, mask_lst)
td_target_vec = td_target.reshape(-1)
s_vec = torch.tensor(np.array(s_lst), dtype=torch.float32).reshape(-1, 4)
a_vec = torch.tensor(np.array(a_lst), dtype=torch.long).reshape(-1).unsqueeze(1)
values = model.v(s_vec).reshape(-1)
advantage = td_target_vec - values
pi = model.pi(s_vec, softmax_dim=1)
pi_a = pi.gather(1, a_vec).reshape(-1)
loss = -(torch.log(pi_a + 1e-8) * advantage.detach()).mean() + \
F.smooth_l1_loss(values, td_target_vec)
s_vec,a_vec,td_target_vec:- 여러 환경 × 여러 step 데이터를 한 번에 큰 배치로 펼친 것
- Advantage와 정책/가치 손실은 A3C와 거의 동일하지만,
- 모든 worker에서 모은 데이터를 한꺼번에 업데이트한다는 점이 다릅니다.
이게 바로 A2C의 동기식(batch) 업데이트입니다.
[ 오늘의 정리 ] – 코드 관점에서 본 A3C vs A2C
- A3C 코드 포인트
- torch.multiprocessing으로 여러 train 프로세스를 띄우고,
- 각 프로세스는 Local 모델로 rollout → Global 모델에 gradient 반영 → Local 재동기화
- 업데이트는 비동기(asynchronous) – 각 프로세스가 제멋대로 타이밍에 맞춰 optimizer.step() 호출
- A2C 코드 포인트
- ParallelEnv로 여러 환경을 동시에 돌리되,
- 하나의 Actor-Critic 모델로 모든 환경을 한꺼번에 보고,
- 일정 step 마다 큰 배치로 한 번에 업데이트
- GPU에서 쓰기 좋고, 재현성이 높으며, 구현이 단순
CartPole 예제를 통해 이 두 구현을 직접 실행해 보면,
“멀티프로세스로 경험을 모으는 방식은 비슷하지만,
A3C는 비동기 업데이트, A2C는 동기 batch 업데이트”
라는 차이를 코드 수준에서 자연스럽게 이해할 수 있습니다.
'개발 기록실 > 실험 & 구현' 카테고리의 다른 글
| 데이콘 구조물 안정성 대회 도전기 – 전처리부터 Pseudo Labeling까지 (0) | 2026.03.30 |
|---|---|
| 랜덤 벽 GridWorld에서 TD Learning으로 상태가치 함수 배우기 (0) | 2026.03.13 |
| REINFORCE로 CartPole-v1 학습하기 – 정책 기반 에이전트 실습 (0) | 2026.03.13 |
| [YOLOv8 + RNN] 편의점/매장 이상행동(전도·파손) 탐지 파이프라인 만들기 (0) | 2026.03.09 |
| [OpenCV + Machine Learning] Kaggle 주조 제품 불량 이미지를 이용한 Random Forest 분류기 만들기 (0) | 2026.03.09 |