reinforcement-learning, control-systems, python, stable-baselines3, ppo, gymnasium, ai-agent, robotics
Control a CartPole with Reinforcement Learning
Use the Stable Baselines3 library to train a PPO agent that solves the classic CartPole control problem. This guide covers environment setup, agent training, and policy evaluation in a few lines of Python.
Beginner, 15 min, 4 steps
The play
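- Install Dependencies: Install `stable-baselines3` for the reinforcement learning algorithms. It pulls in `gymnasium`, which provides the control environment, and PyTorch, its required backend, so both are installed either way. The `[extra]` option adds optional tooling such as TensorBoard logging and progress bars. Rendering CartPole in a window also needs `pygame`, which `gymnasium[classic-control]` provides.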
- Create the Control Environment: Import `gymnasium` and create the `CartPole-v1` environment, which simulates a cart balancing a pole, a classic reinforcement learning control problem. We use `make_vec_env` to create several copies of the environment, so each training rollout collects experience from all of them. By default the copies are stepped one after another in a single process (`DummyVecEnv`) rather than truly in parallel.
- Define and Train the Agent: Import the Proximal Policy Optimization (PPO) algorithm, a robust choice for control tasks. Instantiate the model with a standard policy (`"MlpPolicy"`) and the vectorized environment, then call `.learn()` to start training.
- Save and Evaluate the Trained Policy: Save your trained model for later use. To see your agent in action, create a new, single environment with `render_mode='human'` and use `model.predict()` to choose actions in a loop.
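Before working through the steps above, it may help to confirm that the packages from the first step are actually installed. The following is a minimal sketch using only the standard library. The `PACKAGES` list and the `installed_versions` helper are illustrative names, not part of any library, and listing `pygame` assumes you want the on-screen rendering from the last step.

```python
from importlib import metadata

# Distribution names as published on PyPI. Including "pygame" is an
# assumption: it is only needed for render_mode="human".
PACKAGES = ["stable-baselines3", "gymnasium", "torch", "pygame"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Print one line per package so a missing dependency is easy to spot.
for name, version in installed_versions(PACKAGES).items():
    print(f"{name:20s} {version or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` should be added with `pip` before running the starter code.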
Starter code
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
# Step 1: Create a vectorized environment for parallel training
print("Creating vectorized environment...")
vec_env = make_vec_env("CartPole-v1", n_envs=4)
# Step 2: Define the Reinforcement Learning agent (PPO)
# 'MlpPolicy' is a standard feedforward neural network policy.
model = PPO("MlpPolicy", vec_env, verbose=1)
# Step 3: Train the agent
# The agent will interact with the environment to learn a policy.
print("\nTraining the agent...")
model.learn(total_timesteps=25000)
print("Training complete.")
# Step 4: Save the trained model
model.save("ppo_cartpole")
print("Model saved to ppo_cartpole.zip")
# Step 5: Load the saved model and evaluate it visually
print("\nLoading and evaluating the trained agent...")
model = PPO.load("ppo_cartpole")
# Create a single environment for rendering
eval_env = gym.make("CartPole-v1", render_mode="human")
obs, info = eval_env.reset()
# Run the policy in the environment for 1000 timesteps
episode_steps = 0
for _ in range(1000):
    # Use the model to predict the best action
    action, _states = model.predict(obs, deterministic=True)
    # Take the action in the environment
    obs, reward, terminated, truncated, info = eval_env.step(action)
    episode_steps += 1
    # If the episode is over, report its length and reset the environment
    if terminated or truncated:
        print(f"Episode finished after {episode_steps} timesteps.")
        obs, info = eval_env.reset()
        episode_steps = 0
# Close the environment window
eval_env.close()
print("Evaluation finished.")
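One detail worth knowing about the training call: `total_timesteps=25000` acts as a lower bound rather than an exact count, because PPO collects experience in whole rollouts. The sketch below works through the arithmetic for the starter code, assuming PPO's default rollout length of `n_steps=2048` per environment and a training loop that keeps gathering full rollouts until the requested total is reached. The helper `ppo_training_budget` is a hypothetical name used only for illustration.

```python
import math


def ppo_training_budget(total_timesteps, n_envs, n_steps=2048):
    """Estimate rollout size, rollout count, and timesteps actually collected.

    Assumes training gathers full rollouts of n_steps * n_envs transitions
    until the running count reaches total_timesteps.
    """
    rollout_size = n_steps * n_envs
    n_rollouts = math.ceil(total_timesteps / rollout_size)
    return rollout_size, n_rollouts, n_rollouts * rollout_size


# Values from the starter code: 25,000 requested timesteps, 4 environments.
rollout_size, n_rollouts, collected = ppo_training_budget(25_000, n_envs=4)
print(rollout_size, n_rollouts, collected)  # 8192 4 32768
```

Under these assumptions the agent trains on about 32,768 timesteps rather than 25,000, which is why the `verbose=1` log is expected to report totals in multiples of 8,192.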