Getting Started

Quick Example: PPO on CartPole

julia
using Drill
using Zygote  # Required for automatic differentiation
using ClassicControlEnvironments

# Create parallel environments
env = BroadcastedParallelEnv([CartPoleEnv() for _ in 1:4])

# Actor-Critic network
model = ActorCriticLayer(observation_space(env), action_space(env))

# Algorithm with default hyperparameters
ppo = PPO()

# Create agent
agent = Agent(model, ppo; verbose=2)

# Train for 100k steps
learn_stats, timer = train!(agent, env, ppo, 100_000)

# Extract deployment policy
policy = extract_policy(agent)

Quick Example: SAC on Continuous Control

julia
using Drill
using Zygote  # Required for automatic differentiation

# Continuous action environment (YourContinuousEnv is a placeholder for your own environment type)
env = BroadcastedParallelEnv([YourContinuousEnv() for _ in 1:4])

# SAC-specific policy
model = SACLayer(observation_space(env), action_space(env))

# SAC algorithm
sac = SAC(learning_rate=3f-4, buffer_capacity=1_000_000)

# Create agent
agent = Agent(model, sac; verbose=1)

# Train
agent, buffer, stats, timer = train!(agent, env, sac, 500_000)

Key Hyperparameters

PPO

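| Parameter  | Default | Description         |
|------------|---------|---------------------|
| gamma      | 0.99    | Discount factor     |
| gae_lambda | 0.95    | GAE lambda          |
| clip_range | 0.2     | PPO clipping range  |
| ent_coef   | 0.0     | Entropy coefficient |
| n_steps    | 2048    | Steps per rollout   |
| batch_size | 64      | Minibatch size      |
| epochs     | 10      | Epochs per update   |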
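
To make the roles of gamma, gae_lambda, and clip_range more concrete, here is a minimal, self-contained sketch of the standard GAE recursion and the PPO clipped surrogate term. It is not taken from Drill's source: the function names `gae_advantages` and `clipped_surrogate` are hypothetical, and the trajectory is assumed to contain no episode termination.

```julia
# Generalized Advantage Estimation for one trajectory without terminations.
# rewards[t] and values[t] cover steps 1..T; last_value is the value estimate
# of the state after step T, used for bootstrapping.
function gae_advantages(rewards, values, last_value; gamma=0.99, gae_lambda=0.95)
    T = length(rewards)
    advantages = zeros(Float64, T)
    running = 0.0
    for t in T:-1:1
        next_value = t == T ? last_value : values[t + 1]
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        running = delta + gamma * gae_lambda * running       # discounted sum of residuals
        advantages[t] = running
    end
    return advantages
end

# PPO clipped surrogate for a single sample.
# ratio is pi_new(a|s) / pi_old(a|s); the objective is maximized during training.
function clipped_surrogate(ratio, advantage; clip_range=0.2)
    clipped_ratio = clamp(ratio, 1 - clip_range, 1 + clip_range)
    return min(ratio * advantage, clipped_ratio * advantage)
end
```

With gae_lambda close to 1 the advantages approach Monte Carlo returns minus the value baseline (lower bias, higher variance), while gae_lambda = 0 reduces them to one-step TD residuals. The clip_range bounds how far a single update can move the policy away from the one that collected the rollout.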

SAC

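| Parameter       | Default | Description           |
|-----------------|---------|-----------------------|
| gamma           | 0.99    | Discount factor       |
| tau             | 0.005   | Soft update rate      |
| buffer_capacity | 1M      | Replay buffer size    |
| batch_size      | 256     | Batch size            |
| train_freq      | 1       | Steps between updates |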
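
As an illustration of what the soft update rate tau controls, the following sketch shows the usual Polyak averaging step applied to target-network parameters. It assumes parameters stored as plain arrays; `soft_update!` is a hypothetical helper for this example, not a function from Drill.

```julia
# Polyak (soft) update: move each target parameter a fraction tau
# of the way toward the corresponding online parameter, in place.
function soft_update!(target::AbstractArray, online::AbstractArray; tau=0.005)
    @. target = tau * online + (1 - tau) * target
    return target
end

# Example: target parameters drift slowly toward the online parameters.
target_params = zeros(3)
online_params = ones(3)
soft_update!(target_params, online_params; tau=0.005)
```

A small tau keeps the target network nearly fixed between updates, which tends to stabilize the critic targets; tau = 1 would copy the online parameters outright.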

Evaluation

julia
using Statistics  # provides mean

eval_env = CartPoleEnv(max_steps=500)
stats = evaluate_agent(agent, eval_env; n_episodes=10, deterministic=true)
println("Mean return: $(mean(stats.episodic_returns))")
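
A mean over only 10 episodes can be noisy, so it may help to report the spread as well. The sketch below summarizes a non-empty vector of per-episode returns using only the Statistics standard library; `summarize_returns` is a hypothetical helper, not part of Drill.

```julia
using Statistics

# Summarize per-episode returns: mean, sample standard deviation,
# standard error of the mean, and the observed range.
function summarize_returns(returns::AbstractVector{<:Real})
    n = length(returns)
    m = mean(returns)
    s = n > 1 ? std(returns) : 0.0  # sample std is undefined for one episode
    return (mean = m, std = s, stderr = s / sqrt(n),
            min = minimum(returns), max = maximum(returns))
end
```

Assuming stats.episodic_returns is a vector of real numbers, it could be passed directly as summarize_returns(stats.episodic_returns).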

Extracting a policy from the agent

The agent holds everything needed to produce actions from environment observations, plus the state needed for training. To extract only the structures required for inference, use the extract_policy function. The result is a NeuralPolicy, which implements the minimal policy interface:

julia
# Extract lightweight policy (actor only)
policy = extract_policy(agent)

observation = observe(env)
# Get action for observation
actions = policy(observation; deterministic=true)

If the training environment is wrapped in NormalizeWrapperEnv, pass the wrapped environment as the second positional argument to extract_policy:

julia
policy = extract_policy(agent, norm_wrapped_env)

This returns a NeuralPolicy wrapped in a NormWrapperPolicy, which normalizes the observations you pass to the policy:

julia
normwrapped_policy = extract_policy(agent, norm_wrapped_training_env)

raw_observation = observe(single_test_time_env)
# Get action for the raw (unnormalized) observation
actions = normwrapped_policy(raw_observation; deterministic=true)

Here, single_test_time_env has no normalization wrapper; the policy itself applies the required normalization before passing the observation to the layer.
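
The idea of a policy that carries its own normalization can be sketched with a toy callable wrapper. The types below (`FrozenNormalizer`, `NormalizingPolicy`) and the helper `normalize_obs` are hypothetical stand-ins written for this example; they are not Drill's NormWrapperPolicy, and the clipping and epsilon values are assumptions.

```julia
# Frozen observation statistics, as they might be collected during training.
struct FrozenNormalizer
    mean::Vector{Float32}
    std::Vector{Float32}
    clip::Float32
end

# Standardize an observation and clip it to [-clip, clip].
normalize_obs(n::FrozenNormalizer, obs) =
    clamp.((obs .- n.mean) ./ (n.std .+ 1f-8), -n.clip, n.clip)

# Wraps any callable policy and normalizes observations before forwarding them.
struct NormalizingPolicy{P}
    policy::P
    normalizer::FrozenNormalizer
end

(p::NormalizingPolicy)(obs; deterministic=true) =
    p.policy(normalize_obs(p.normalizer, obs); deterministic=deterministic)

# Toy inner policy for demonstration: returns the sum of its (normalized) input.
toy_policy = (obs; deterministic=true) -> sum(obs)
wrapped = NormalizingPolicy(toy_policy, FrozenNormalizer(Float32[1, 2], Float32[2, 4], 10f0))
action = wrapped(Float32[3, 10]; deterministic=true)
```

Because the statistics travel with the policy, the caller can feed raw observations from an unwrapped environment and still match the input distribution seen during training.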