Tags: reinforcement-learning, control-systems, python, stable-baselines3, ppo, gymnasium, ai-agent, robotics

Control a CartPole with Reinforcement Learning

Use the Stable Baselines3 library to train a PPO agent that solves the classic CartPole control problem. This guide covers environment setup, agent training, and policy evaluation in a few lines of Python.

Beginner · 15 min · 4 steps
The play
  1. Install Dependencies
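    Install `stable-baselines3` for the reinforcement learning algorithms and `gymnasium` for the control environment. PyTorch is a core dependency of Stable Baselines3, so it is installed either way; the optional `[extra]` variant adds tooling such as TensorBoard logging. Rendering CartPole in a window (step 4) also needs `pygame`.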
  2. Create the Control Environment
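    Import `gymnasium` and create an instance of the `CartPole-v1` environment, which simulates a cart balancing a pole, a classic reinforcement learning control problem. For training, `make_vec_env` builds several copies of the environment behind one vectorized interface, so each rollout gathers experience from all of them. By default the copies are stepped one after another in a single process (`DummyVecEnv`), not in parallel.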
  3. Define and Train the Agent
    Import the Proximal Policy Optimization (PPO) algorithm, a robust choice for control tasks. Instantiate the model with `"MlpPolicy"` (a feedforward multilayer-perceptron policy) and the vectorized environment, then call `.learn()` to start training.
  4. Save and Evaluate the Trained Policy
    Save your trained model for later use. To see your agent in action, create a new, single environment with `render_mode='human'` and use `model.predict()` to choose actions in a loop.
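
To make the steps above concrete, here is a minimal sketch of what the agent sees and does in `CartPole-v1`: an observation of four numbers and a choice between two actions. The `heuristic_action` function and its 0.5 weighting are illustrative assumptions, not part of Stable Baselines3 or Gymnasium; a hand-written rule like this can serve as a rough baseline to compare a trained policy against.

```python
# Sketch of the CartPole-v1 interface. The heuristic is an illustrative
# baseline (an assumption for this example), not a library function.

# Observation layout documented by Gymnasium for CartPole-v1:
# [cart position, cart velocity, pole angle (rad), pole angular velocity]
PUSH_LEFT, PUSH_RIGHT = 0, 1  # the two discrete actions


def heuristic_action(obs):
    """Push the cart toward the side the pole is falling to.

    The 0.5 weight on angular velocity is an arbitrary illustrative choice.
    """
    _cart_pos, _cart_vel, pole_angle, pole_ang_vel = obs
    lean = pole_angle + 0.5 * pole_ang_vel
    return PUSH_RIGHT if lean > 0 else PUSH_LEFT


# Pole leaning right and still rotating right: push right.
action_right = heuristic_action([0.0, 0.0, 0.05, 0.2])
# Pole leaning slightly right but swinging back left quickly: push left.
action_left = heuristic_action([0.0, 0.0, 0.02, -0.3])
print(action_right, action_left)  # 1 0
```

To try the rule in the real environment, pass `heuristic_action(obs)` to `eval_env.step()` in place of the model's prediction.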
Starter code
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Step 1: Create a vectorized environment for parallel training
print("Creating vectorized environment...")
vec_env = make_vec_env("CartPole-v1", n_envs=4)

# Step 2: Define the Reinforcement Learning agent (PPO)
# 'MlpPolicy' is a standard feedforward neural network policy.
model = PPO("MlpPolicy", vec_env, verbose=1)

# Step 3: Train the agent
# The agent will interact with the environment to learn a policy.
print("\nTraining the agent...")
model.learn(total_timesteps=25000)
print("Training complete.")

# Step 4: Save the trained model
model.save("ppo_cartpole")
print("Model saved to ppo_cartpole.zip")

# Step 5: Evaluate the trained agent visually
print("\nLoading and evaluating the trained agent...")
model = PPO.load("ppo_cartpole")  # Reload from disk to confirm the saved file works

# Create a single environment for rendering
eval_env = gym.make("CartPole-v1", render_mode="human")
obs, info = eval_env.reset()

# Run the policy for 1000 timesteps in total, across as many episodes as that takes
episode_steps = 0
for _ in range(1000):
    # Use the model to predict the best action
    action, _states = model.predict(obs, deterministic=True)

    # Take the action in the environment
    obs, reward, terminated, truncated, info = eval_env.step(action)
    episode_steps += 1

    # If the episode is over, report its own length and reset the environment
    if terminated or truncated:
        print(f"Episode finished after {episode_steps} timesteps.")
        episode_steps = 0
        obs, info = eval_env.reset()

# Close the environment window
eval_env.close()
print("Evaluation finished.")
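
With `verbose=1`, the training log may report more than the 25,000 timesteps requested. The arithmetic sketch below shows why, assuming PPO's default `n_steps=2048` (the starter code does not override it) and that training only stops between whole rollouts. `planned_timesteps` is a hypothetical helper written for this illustration, not a Stable Baselines3 function.

```python
import math


def planned_timesteps(total_timesteps, n_envs, n_steps=2048):
    """Estimate how many timesteps PPO collects when it stops only between rollouts.

    Assumes n_steps=2048 (the SB3 PPO default) unless overridden.
    """
    per_rollout = n_envs * n_steps  # steps gathered before each policy update
    n_rollouts = math.ceil(total_timesteps / per_rollout)
    return per_rollout, n_rollouts, per_rollout * n_rollouts


# Values from the starter code: 4 environments, 25,000 requested timesteps.
per_rollout, n_rollouts, collected = planned_timesteps(25_000, n_envs=4)
print(per_rollout, n_rollouts, collected)  # 8192 4 32768
```

Under these assumptions, three rollouts (24,576 steps) fall just short of the target, so a fourth runs and training ends at 32,768 steps.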