본문 바로가기

A3C & A2C – CartPole 구현 코드 뜯어보기 (PyTorch + 멀티프로세싱)

@eunyoung-study2026. 3. 17. 15:29

1. 이 글에서 다루는 내용

  • CartPole용 공통 Actor-Critic 네트워크 구조
  • A3C 구현 코드 흐름 (멀티프로세스 + Global / Local 모델)
  • A2C 구현 코드 흐름 (ParallelEnv + 동기식 업데이트)
  • 각 부분이 어떤 역할을 하는지를 중심으로 설명합니다.

알고리즘 개념(A3C가 뭔지, A2C가 왜 나왔는지)은 블로그에 따로 다뤘으니,

여기서는 코드 레벨에서 어떻게 돌아가는지에 집중합니다.


2. 공통 Actor-Critic 네트워크 구조

A3C와 A2C 모두 CartPole 상태(4차원)를 입력으로 받는 같은 형태의 Actor-Critic 네트워크를 사용합니다.

class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)

    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        return F.softmax(x, dim=softmax_dim)

    def v(self, x):
        x = F.relu(self.fc1(x))
        return self.fc_v(x)
  • 입력: CartPole 상태 [카트 위치, 속도, 막대 각도, 각속도] → 길이 4
  • 공통 은닉층(fc1): 상태에서 의미 있는 표현을 추출
  • fc_pi (Actor):
    • 출력 차원 2 – 행동 2개(왼쪽, 오른쪽)에 대한 로짓
    • softmax를 통과하면 정책 확률 (\pi(a|s))가 됩니다.
  • fc_v (Critic):
    • 출력 차원 1 – 현재 상태 가치 (V(s))

이 네트워크를 A3C에서는 Global / Local 모델로, A2C에서는 하나의 공유 모델로 사용합니다.


3. A3C 구현 – Global / Local 모델과 멀티프로세싱

3.1 전체 구조 개요

if __name__ == "__main__": 부분을 보면 A3C의 전체 구조가 한눈에 보입니다.

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    global_model = ActorCritic()
    global_model.share_memory()

    processes = []
    # 학습 프로세스 3개, 테스트 프로세스 1개
    for rank in range(n_train_processes + 1):
        if rank == 0:
            p = mp.Process(target=test, args=(global_model,))
        else:
            p = mp.Process(target=train, args=(global_model, rank))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        print("process exitcode:", p.exitcode, flush=True)
  • mp.set_start_method("spawn", force=True):
    • 멀티프로세싱 시작 방식을 spawn 으로 설정 (새 프로세스를 깨끗하게 띄움)
  • global_model = ActorCritic(); global_model.share_memory():
    • 전역 모델을 하나 만들고, 공유 메모리에 올립니다.
    • 여러 프로세스가 같은 모델 파라미터를 함께 쓰고 업데이트할 수 있게 하는 설정입니다.
  • rank == 0:
    • 테스트 전용 프로세스 (test) 실행
  • rank >= 1:
    • 학습용 worker 프로세스 (train) 실행

즉,

1개의 Global 모델 + N개의 학습 worker + 1개의 테스트 프로세스
가 동시에 돌아가는 구조입니다.


3.2 train(worker) 함수 – Local 모델과 rollout

def train(global_model, rank):
    local_model = ActorCritic()
    local_model.load_state_dict(global_model.state_dict())

    optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)
    env = gym.make("CartPole-v1")
  • 각 worker는 시작할 때:
    • Global 모델의 파라미터를 복사해서 Local 모델을 만듭니다.
    • Optimizer는 Global 모델 파라미터를 대상으로 합니다.

에피소드 루프 안에서는 이렇게 동작합니다.

    for n_epi in range(max_train_ep):
        done = False
        s, _ = env.reset()

        while not done:
            s_lst, a_lst, r_lst = [], [], []

            for _ in range(update_interval):
                prob = local_model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()

                s_prime, r, terminated, truncated, info = env.step(a)
                done = terminated or truncated

                s_lst.append(s)
                a_lst.append([a])
                r_lst.append(r / 100.0)

                s = s_prime
                if done:
                    break
  • Local 모델로 정책 확률 prob = π(a|s)를 구하고,
    • Categorical(prob)에서 행동 a를 샘플링합니다.
  • 환경 한 스텝 진행 → (s_prime, r, done)
  • 상태/행동/보상을 update_interval 스텝만큼 모으거나,
    • 에피소드가 끝날 때까지 리스트에 쌓습니다.

3.3 TD 타깃과 Advantage 계산

            s_final = torch.tensor(s_prime, dtype=torch.float)
            R = 0.0 if done else local_model.v(s_final).item()

            td_target_lst = []
            for reward in r_lst[::-1]:
                R = gamma * R + reward
                td_target_lst.append([R])
            td_target_lst.reverse()

            s_batch = torch.from_numpy(np.array(s_lst)).float()
            a_batch = torch.tensor(a_lst)
            td_target = torch.tensor(td_target_lst, dtype=torch.float)

            advantage = td_target - local_model.v(s_batch)
  • 에피소드가 끝나지 않았다면 마지막 상태 가치 V(s_final)를 초기값으로 사용하고,
    • 보상 리스트를 뒤에서부터 돌면서 n-step TD 타깃을 계산합니다.
  • 그 후,

\[
\text{advantage} = \text{td_target} - V(s)
\]

을 이용해 Advantage를 계산합니다.


3.4 손실 계산과 Global 모델 업데이트

            pi = local_model.pi(s_batch, softmax_dim=1)
            pi_a = pi.gather(1, a_batch)

            value_loss = F.smooth_l1_loss(local_model.v(s_batch), td_target.detach())
            policy_loss = -torch.log(pi_a) * advantage.detach()
            loss = policy_loss + value_loss

            local_model.zero_grad()
            optimizer.zero_grad()
            loss.mean().backward()

            for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
                global_param._grad = local_param.grad

            optimizer.step()
            local_model.load_state_dict(global_model.state_dict())
  • pi: 각 상태에서 [왼쪽, 오른쪽] 행동 확률
  • pi_a: 실제 선택한 행동의 확률 (\pi(a|s))
  • Policy Loss: -log(pi_a) * advantage
    • Advantage > 0 → 해당 행동 확률 ↑
    • Advantage < 0 → 해당 행동 확률 ↓
  • Value Loss: SmoothL1로 V(s)를 TD 타깃에 맞추도록 학습

핵심은 아래 부분입니다.

for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
    global_param._grad = local_param.grad
  • Local 모델에서 계산된 gradient를
    • Global 모델의 _grad에 복사하고
    • optimizer.step()은 Global 모델만 업데이트합니다.
  • 그 다음, Global 모델의 최신 파라미터를 다시 Local 모델에 복사해 동기화합니다.

이 과정을 여러 프로세스가 동시에 수행하는 것이 A3C 구현의 핵심입니다.


3.5 테스트 프로세스

def test(global_model):
    env = gym.make("CartPole-v1")
    score = 0.0
    print_interval = 5

    for n_epi in range(max_test_ep):
        done = False
        s, _ = env.reset()

        while not done:
            prob = global_model.pi(torch.from_numpy(s).float())
            a = Categorical(prob).sample().item()

            s_prime, r, terminated, truncated, info = env.step(a)
            done = terminated or truncated

            s = s_prime
            score += r

        if n_epi % print_interval == 0 and n_epi != 0:
            print(f"[TEST] episode={n_epi}, avg score={score / print_interval:.1f}", flush=True)
            score = 0.0
            time.sleep(0.5)
  • Global 모델이 현재 어느 정도 성능인지 주기적으로 출력하는 코드입니다.
  • 학습과 테스트가 동시에 돌아가므로, 학습이 진행될수록 평균 점수가 올라가는지 확인할 수 있습니다.

주의: 멀티프로세싱 코드이기 때문에
Colab보다는 로컬(PyCharm, VSCode, 터미널) 환경에서 실행하는 것이 안정적입니다.


4. A2C 구현 – ParallelEnv와 동기식 업데이트

두 번째 절에서는 A2C 코드를 구현합니다.
핵심은 여러 환경을 동시에 돌리되, 업데이트는 한 번에 동기적으로 하는 것입니다.

4.1 ParallelEnv – 여러 환경을 한 번에 관리

먼저 각 worker 프로세스에서 실행할 worker()와, 이를 감싸는 ParallelEnv 클래스를 정의합니다.

def worker(worker_id, master_end, worker_end):
    master_end.close()
    env = gym.make("CartPole-v1")
    obs, _ = env.reset(seed=worker_id)

    while True:
        cmd, data = worker_end.recv()

        if cmd == "step":
            obs, reward, terminated, truncated, info = env.step(int(data))
            done = terminated or truncated
            if done:
                obs, _ = env.reset()
            worker_end.send((obs, reward, done, info))

        elif cmd == "reset":
            obs, _ = env.reset()
            worker_end.send(obs)

        elif cmd == "close":
            env.close()
            worker_end.close()
            break

        elif cmd == "get_spaces":
            worker_end.send((env.observation_space, env.action_space))

        else:
            raise NotImplementedError
  • 각 worker는
    • step, reset, close 같은 명령을 파이프로 받아 환경을 조작합니다.

이를 여러 개 모아서 관리하는 것이 ParallelEnv입니다.

class ParallelEnv:
    def __init__(self, n_train_processes):
        self.nenvs = n_train_processes
        self.waiting = False
        self.closed = False
        self.workers = []

        master_ends, worker_ends = zip(*[mp.Pipe() for _ in range(self.nenvs)])
        self.master_ends = master_ends
        self.worker_ends = worker_ends

        for worker_id, (master_end, worker_end) in enumerate(zip(master_ends, worker_ends)):
            p = mp.Process(target=worker, args=(worker_id, master_end, worker_end))
            p.daemon = True
            p.start()
            self.workers.append(p)

        for worker_end in worker_ends:
            worker_end.close()
  • self.master_ends를 통해 여러 환경에 동시에 명령을 보낼 수 있습니다.

주요 메서드는 다음과 같습니다.

    def reset(self):
        for master_end in self.master_ends:
            master_end.send(("reset", None))
        return np.stack([master_end.recv() for master_end in self.master_ends]).astype(np.float32)

    def step_async(self, actions):
        for master_end, action in zip(self.master_ends, actions):
            master_end.send(("step", int(action)))
        self.waiting = True

    def step_wait(self):
        results = [master_end.recv() for master_end in self.master_ends]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return (
            np.stack(obs).astype(np.float32),
            np.array(rews, dtype=np.float32),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def step(self, actions):
        self.step_async(actions)
        return self.step_wait()
  • reset():
    • 모든 환경을 한 번에 초기화
  • step(actions):
    • actions 배열을 받아 각 환경에 행동을 보내고,
    • 그 결과를 한 번에 모아서 반환

이제 envs = ParallelEnv(n_train_processes)로 여러 CartPole 환경을 동시에 다룰 수 있습니다.


4.2 A2C 학습 루프

def main():
    envs = ParallelEnv(n_train_processes)
    model = ActorCritic()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    step_idx = 0
    s = envs.reset()

    while step_idx < max_train_steps:
        s_lst, a_lst, r_lst, mask_lst = [], [], [], []

        for _ in range(update_interval):
            with torch.no_grad():
                prob = model.pi(torch.from_numpy(s).float(), softmax_dim=1)

            a = Categorical(prob).sample().numpy()
            s_prime, r, done, info = envs.step(a)

            s_lst.append(s.copy())
            a_lst.append(a.copy())
            r_lst.append(r / 100.0)
            mask_lst.append(1.0 - done.astype(np.float32))

            s = s_prime
            step_idx += 1
  • s: shape (nenvs, 4)인 상태 배열
  • a: shape (nenvs,)인 행동 배열 – 각 환경마다 하나씩
  • r, done도 각각 nenvs 크기의 배열로 들어옵니다.
  • mask_lst:
    • done이면 0, 아니면 1 – TD 타깃 계산에서 종료 여부를 반영하는 용도

4.3 n-step TD Target 계산

def compute_target(v_final, r_lst, mask_lst):
    G = v_final.reshape(-1)
    td_target = []

    for r, mask in zip(r_lst[::-1], mask_lst[::-1]):
        G = r + gamma * G * mask
        td_target.append(G)

    td_target.reverse()
    return torch.tensor(np.array(td_target), dtype=torch.float32)
  • v_final: 마지막 상태들의 가치 (V(s_{T})) – shape (nenvs, 1)
  • r_lst, mask_lst: 길이 update_interval 인 리스트
    • 각 원소는 (nenvs,) 모양의 보상/마스크 배열
  • 뒤에서부터 거꾸로 돌면서
    • n-step Return을 계산해 td_target을 만듭니다.

학습 루프에서는 이렇게 사용합니다.

        s_final = torch.from_numpy(s_prime).float()
        with torch.no_grad():
            v_final = model.v(s_final).cpu().numpy()

        td_target = compute_target(v_final, r_lst, mask_lst)

        td_target_vec = td_target.reshape(-1)
        s_vec = torch.tensor(np.array(s_lst), dtype=torch.float32).reshape(-1, 4)
        a_vec = torch.tensor(np.array(a_lst), dtype=torch.long).reshape(-1).unsqueeze(1)

        values = model.v(s_vec).reshape(-1)
        advantage = td_target_vec - values

        pi = model.pi(s_vec, softmax_dim=1)
        pi_a = pi.gather(1, a_vec).reshape(-1)

        loss = -(torch.log(pi_a + 1e-8) * advantage.detach()).mean() + \
            F.smooth_l1_loss(values, td_target_vec)
  • s_vec, a_vec, td_target_vec:
    • 여러 환경 × 여러 step 데이터를 한 번에 큰 배치로 펼친 것
  • Advantage와 정책/가치 손실은 A3C와 거의 동일하지만,
    • 모든 worker에서 모은 데이터를 한꺼번에 업데이트한다는 점이 다릅니다.

이게 바로 A2C의 동기식(batch) 업데이트입니다.

더보기

[ 오늘의 정리 ] – 코드 관점에서 본 A3C vs A2C

  • A3C 코드 포인트
    • torch.multiprocessing으로 여러 train 프로세스를 띄우고,
    • 각 프로세스는 Local 모델로 rollout → Global 모델에 gradient 반영 → Local 재동기화
    • 업데이트는 비동기(asynchronous) – 각 프로세스가 제멋대로 타이밍에 맞춰 optimizer.step() 호출
  • A2C 코드 포인트
    • ParallelEnv로 여러 환경을 동시에 돌리되,
    • 하나의 Actor-Critic 모델로 모든 환경을 한꺼번에 보고,
    • 일정 step 마다 큰 배치로 한 번에 업데이트
    • GPU에서 쓰기 좋고, 재현성이 높으며, 구현이 단순

CartPole 예제를 통해 이 두 구현을 직접 실행해 보면,

“멀티프로세스로 경험을 모으는 방식은 비슷하지만,
A3C는 비동기 업데이트, A2C는 동기 batch 업데이트”
라는 차이를 코드 수준에서 자연스럽게 이해할 수 있습니다.

eunyoung-study
@eunyoung-study :: 은영의 이해 노트

개념을 이해하고, 논문을 풀어보고, 코드로 확인하는 기록 ! 오늘도 파이팅 😉

공감하셨다면 ❤️ 구독도 환영합니다! 🤗

목차