API Reference

Environments

DrillInterface.AbstractEnv Type
julia
AbstractEnv

Abstract base type for all reinforcement learning environments.

Subtypes must implement the following methods:

  • reset!(env) - Reset the environment

  • act!(env, action) - Take an action and return the reward

  • observe(env) - Get current observation

  • terminated(env) - Check if episode terminated

  • truncated(env) - Check if episode was truncated

  • action_space(env) - Get the action space

  • observation_space(env) - Get the observation space

source
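
For illustration, a minimal environment implementing this contract might look like the following sketch; CounterEnv and its dynamics are invented for the example:

```julia
import DrillInterface: reset!, act!, observe, terminated, truncated,
                       action_space, observation_space

# Hypothetical toy environment: counts steps, terminates after 10.
mutable struct CounterEnv <: AbstractEnv
    t::Int
end
CounterEnv() = CounterEnv(0)

reset!(env::CounterEnv) = (env.t = 0; nothing)
act!(env::CounterEnv, action) = (env.t += 1; 1.0f0)   # constant reward per step
observe(env::CounterEnv) = Float32[env.t]
terminated(env::CounterEnv) = env.t >= 10
truncated(env::CounterEnv) = false
action_space(env::CounterEnv) = Discrete(2)
observation_space(env::CounterEnv) = Box(Float32[0], Float32[10])
```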
DrillInterface.AbstractParallelEnv Type
julia
AbstractParallelEnv <: AbstractEnv

Abstract type for vectorized/parallel environments that manage multiple environment instances.

Key Differences from AbstractEnv

Method       Single Env                Parallel Env
observe      Returns one observation   Returns vector of observations
act!         Returns reward            Returns (rewards, terminateds, truncateds, infos)
terminated   Returns Bool              Returns Vector{Bool}
truncated    Returns Bool              Returns Vector{Bool}

Auto-Reset Behavior

Parallel environments automatically reset individual sub-environments when they terminate or truncate. The terminal observation is stored in infos[i]["terminal_observation"] before reset.

source
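
In a stepping loop this looks like the following sketch, assuming penv is an AbstractParallelEnv and actions is a batch of actions:

```julia
rewards, terminateds, truncateds, infos = act!(penv, actions)
for i in 1:number_of_envs(penv)
    if terminateds[i] || truncateds[i]
        # Observation from before the auto-reset; observe(penv) already
        # returns the freshly reset state for this sub-environment.
        final_obs = infos[i]["terminal_observation"]
    end
end
```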
DrillInterface.AbstractParallelEnvWrapper Type
julia
AbstractParallelEnvWrapper{E}

Wraps a vectorized AbstractParallelEnv (e.g. normalization or monitoring) while remaining an AbstractParallelEnv.

source
DrillInterface.reset! Function
julia
reset!(env::AbstractEnv) -> Nothing

Reset the environment to its initial state.

Arguments

  • env::AbstractEnv: The environment to reset

Returns

  • Nothing
source
DrillInterface.act! Function
julia
act!(env::AbstractEnv, action) -> reward

Take an action in the environment and return the reward.

Arguments

  • env::AbstractEnv: The environment to act in

  • action: The action to take (type depends on environment's action space)

Returns

  • reward: Numerical reward from taking the action
source
DrillInterface.observe Function
julia
observe(env::AbstractEnv) -> observation

Get the current observation from the environment.

Arguments

  • env::AbstractEnv: The environment to observe

Returns

  • observation: Current state observation (type/shape depends on environment's observation space)
source
DrillInterface.terminated Function
julia
terminated(env::AbstractEnv) -> Bool

Check if the environment episode has terminated due to reaching a terminal state.

Arguments

  • env::AbstractEnv: The environment to check

Returns

  • Bool: true if episode is terminated, false otherwise
source
DrillInterface.truncated Function
julia
truncated(env::AbstractEnv) -> Bool

Check if the environment episode has been truncated (e.g., time limit reached).

Arguments

  • env::AbstractEnv: The environment to check

Returns

  • Bool: true if episode is truncated, false otherwise
source
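
Together, these methods support a generic episode loop; env and select_action below are placeholders for your own environment and policy:

```julia
reset!(env)
total_reward = 0.0
while !(terminated(env) || truncated(env))
    obs = observe(env)
    total_reward += act!(env, select_action(obs))  # select_action: your policy
end
```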
DrillInterface.action_space Function
julia
action_space(env::AbstractEnv) -> AbstractSpace

Get the action space specification for the environment.

Arguments

  • env::AbstractEnv: The environment

Returns

  • AbstractSpace: The action space (e.g., Box, Discrete)
source
DrillInterface.observation_space Function
julia
observation_space(env::AbstractEnv) -> AbstractSpace

Get the observation space specification for the environment.

Arguments

  • env::AbstractEnv: The environment

Returns

  • AbstractSpace: The observation space (e.g., Box, Discrete)
source
DrillInterface.get_info Function
julia
get_info(env::AbstractEnv) -> Dict

Get additional environment information (metadata, debug info, etc.).

Arguments

  • env::AbstractEnv: The environment

Returns

  • Dict: Dictionary containing environment-specific information
source
DrillInterface.number_of_envs Function
julia
number_of_envs(env::AbstractParallelEnv) -> Int

Get the number of parallel environments in a parallel environment wrapper.

Arguments

  • env::AbstractParallelEnv: The parallel environment

Returns

  • Int: Number of parallel environments
source

Spaces

DrillInterface.AbstractSpace Type
julia
AbstractSpace

Abstract base type for all observation and action spaces in Drill.jl. Concrete subtypes include Box (continuous) and Discrete (finite actions).

source
DrillInterface.Box Type
julia
Box{T <: Number} <: AbstractSpace

A continuous space with lower and upper bounds per dimension.

Fields

  • low::Array{T}: Lower bounds for each dimension

  • high::Array{T}: Upper bounds for each dimension

  • shape::Tuple{Vararg{Int}}: Shape of the space

Example

julia
# 2D box with different bounds per dimension
space = Box(Float32[-1, -2], Float32[1, 3])

# Uniform bounds
space = Box(-1.0f0, 1.0f0, (4,))
source
DrillInterface.Discrete Type
julia
Discrete{T <: Integer} <: AbstractSpace

A discrete space representing a finite set of integer actions.

Fields

  • n::T: Number of discrete actions

  • start::T: Lowest action value

Example

julia
space = Discrete(4)     # Actions: 1, 2, 3, 4
space = Discrete(4, 0)  # Actions: 0, 1, 2, 3
source

Algorithms

Drill.PPO Type
julia
PPO{T <: AbstractFloat} <: OnPolicyAlgorithm

Proximal Policy Optimization algorithm.

Fields

  • gamma: Discount factor (default: 0.99)

  • gae_lambda: GAE lambda for advantage estimation (default: 0.95)

  • clip_range: PPO clipping parameter (default: 0.2)

  • ent_coef: Entropy coefficient (default: 0.0)

  • vf_coef: Value function coefficient (default: 0.5)

  • max_grad_norm: Maximum gradient norm for clipping (default: 0.5)

  • n_steps: Steps per rollout before update (default: 2048)

  • batch_size: Minibatch size (default: 64)

  • epochs: Number of epochs per update (default: 10)

  • learning_rate: Optimizer learning rate (default: 3e-4)

Example

julia
ppo = PPO(gamma=0.99f0, n_steps=2048, epochs=10)
agent = Agent(model, ppo)
train!(agent, env, ppo, 100_000)
source
Drill.SAC Type
julia
SAC{T <: AbstractFloat, E <: AbstractEntropyCoefficient} <: OffPolicyAlgorithm

Soft Actor-Critic algorithm with automatic entropy tuning.

Fields

  • learning_rate: Optimizer learning rate (default: 3e-4)

  • buffer_capacity: Replay buffer size (default: 1M)

  • start_steps: Random exploration steps before training (default: 100)

  • batch_size: Batch size for updates (default: 256)

  • tau: Soft update coefficient for target networks (default: 0.005)

  • gamma: Discount factor (default: 0.99)

  • train_freq: Steps between gradient updates (default: 1)

  • gradient_steps: Gradient steps per update, -1 for auto (default: 1)

  • ent_coef: Entropy coefficient (AutoEntropyCoefficient or FixedEntropyCoefficient)

Example

julia
sac = SAC(learning_rate=3f-4, buffer_capacity=1_000_000)
model = SACLayer(obs_space, act_space)
agent = Agent(model, sac)
train!(agent, env, sac, 500_000)
source
Drill.train! Function
julia
train!(agent, env, alg::SAC, max_steps; kwargs...)
train!(agent, replay_buffer, env, alg::SAC, max_steps; callbacks=nothing, ad_type=AutoZygote())

Train a SAC agent on env. The four-argument form allocates a ReplayBuffer from the environment spaces and capacity alg.buffer_capacity.

The five-argument form reuses an existing replay_buffer (same observation/action spaces as env).

Keyword arguments

  • callbacks: Optional vector of AbstractCallback hooks.

  • ad_type: Lux AD backend for gradient computation (default AutoZygote()).

Returns (agent, replay_buffer, training_stats).

source
julia
train!(agent, env, alg::PPO, max_steps; ad_type=AutoZygote(), callbacks=nothing)

Run PPO training on a parallel environment for up to max_steps environment steps (total across all sub-environments).

Rollouts use alg.n_steps steps per sub-environment per iteration. Training stops early if a callback returns false.

Keyword arguments

  • ad_type: Lux AD backend for gradient computation (default AutoZygote()).

  • callbacks: Optional vector of AbstractCallback hooks.

Returns nothing (mutates agent in place). If a callback returns false, training exits early without completing the full schedule and still returns nothing.

source
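
Putting the pieces together, a PPO run over vectorized environments might look like this sketch; make_env is a placeholder for your own environment constructor:

```julia
envs  = MonitorWrapperEnv(BroadcastedParallelEnv([make_env() for _ in 1:8]))
model = ActorCriticLayer(observation_space(envs), action_space(envs))
alg   = PPO(n_steps=128, batch_size=64)
agent = Agent(model, alg)
train!(agent, envs, alg, 100_000)
```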

Agents

Drill.Agent Type

Unified Agent for all algorithms.

verbose levels:

  • 0: no output

  • 1: progress bar

  • 2: progress bar and stats

source

Drill.predict_actions Function
julia
predict_actions(layer::AbstractLayer, obs::AbstractArray, ps, st; deterministic::Bool=false) -> (actions, st)

Predict actions from batched observations.

Arguments

  • layer::AbstractLayer: The actor-critic layer

  • obs::AbstractArray: Batched observations (last dimension is batch)

  • ps: Layer parameters

  • st: Layer state

  • deterministic::Bool=false: Whether to use deterministic actions

Returns

  • actions: Vector/Array of actions (raw layer outputs, not processed for environment)

  • st: Updated layer state

Notes

  • Input observations must be batched (matrix/array format)

  • Output actions are raw layer outputs (e.g., 1-based for Discrete layers)

  • Use to_env() to convert for environment use

source
Drill.predict_values Function
julia
predict_values(layer::AbstractLayer, obs::AbstractArray, [actions::AbstractArray,] ps, st) -> (values, st)

Predict Q-values from batched observations and actions (for Q-Critic layers).

Arguments

  • layer::AbstractLayer: The actor-critic layer

  • obs::AbstractArray: Batched observations (last dimension is batch)

  • actions::AbstractArray: Batched actions (last dimension is batch) (only for Q-Critic layers)

  • ps: Layer parameters

  • st: Layer state

Returns

  • values: batched values (tuples of values for multiple Q-Critic networks)

  • st: Updated layer state

Notes

  • Input observations and actions must be batched (matrix/array format)

  • Actions should be in raw layer format (e.g., 1-based for Discrete)

source
Drill.steps_taken Function
julia
steps_taken(stats::AgentStats) -> Int

Number of environment steps recorded in stats (updated by internal training loops via add_step!).

source
julia
steps_taken(agent::Agent) -> Int

Total environment steps taken by agent during training (same count as steps_taken(agent.stats)).

source
Drill.evaluate_agent Function
julia
evaluate_agent(agent, env; kwargs...)

Evaluate a policy/agent for a specified number of episodes and return performance statistics.

Arguments

  • agent: The agent to evaluate (must implement predict method)

  • env: The environment to evaluate on (single env or parallel env)

Keyword Arguments

  • n_eval_episodes::Int = 10: Number of episodes to evaluate

  • deterministic::Bool = true: Whether to use deterministic actions

  • render::Bool = false: Whether to render the environment

  • callback::Union{Nothing, Function} = nothing: Optional callback function called after each step

  • reward_threshold::Union{Nothing, Real} = nothing: Minimum expected mean reward (throws error if not met)

  • return_episode_rewards::Bool = false: If true, returns individual episode rewards and lengths

  • warn::Bool = true: Whether to warn about missing Monitor wrapper

  • rng::AbstractRNG = Random.default_rng(): Random number generator for reproducible evaluation

Returns

  • If return_episode_rewards = false: (mean_reward::Float64, std_reward::Float64)

  • If return_episode_rewards = true: (episode_rewards::Vector{Float64}, episode_lengths::Vector{Int})

Notes

  • Episodes are distributed evenly across parallel environments to remove bias

  • If environment is wrapped with Monitor, episode statistics from Monitor are used

  • Otherwise, rewards and lengths are tracked manually during evaluation

  • For environments with reward/length modifying wrappers, consider using Monitor wrapper

Examples

julia
# Basic evaluation
mean_reward, std_reward = evaluate_agent(agent, env; n_eval_episodes=20)

# Get individual episode data
episode_rewards, episode_lengths = evaluate_agent(agent, env; 
    return_episode_rewards=true, deterministic=false)

# Evaluation with threshold check
mean_reward, std_reward = evaluate_agent(agent, env; 
    reward_threshold=100.0, n_eval_episodes=50)
source
Drill.AbstractActionAdapter Type
julia
AbstractActionAdapter

Maps between policy outputs and environment actions (see to_env, from_env); concrete types include ClampAdapter, TanhScaleAdapter, and DiscreteAdapter.

source
Drill.to_env Function
julia
to_env(adapter, policy_action, space::AbstractSpace)

Convert an action from the policy/model's action domain to the environment's action space. Called right before stepping the environment.

source
Drill.from_env Function
julia
from_env(adapter, env_action, space::AbstractSpace)

Optionally convert an environment action back to the policy/model's action domain. Useful for some off-policy training flows. Default: identity where appropriate.

source

Layers

Drill.ActorCriticLayer Function
julia
ActorCriticLayer(observation_space, action_space::Box; kwargs...) -> ContinuousActorCriticLayer
ActorCriticLayer(observation_space, action_space::Discrete; kwargs...) -> DiscreteActorCriticLayer

Unified constructor that forwards to ContinuousActorCriticLayer or DiscreteActorCriticLayer depending on the action space type.

source
Drill.ContinuousActorCriticLayer Type
julia
ContinuousActorCriticLayer

Actor–critic Lux model for continuous (Box) or discrete observations with continuous actions: feature extractor, stochastic actor (with learnable log-std where applicable), and value or Q heads according to critic_type.

Use ContinuousActorCriticLayer(observation_space, action_space::Box; kwargs...) to build a layer.

source
Drill.DiscreteActorCriticLayer Type
julia
DiscreteActorCriticLayer

Actor–critic Lux model for discrete actions: feature extractor, categorical policy head, and value head.

Use DiscreteActorCriticLayer(observation_space, action_space::Discrete; kwargs...).

source
Drill.SACLayer Function
julia
SACLayer(observation_space, action_space::Box; kwargs...)

Convenience constructor for SAC: builds a ContinuousActorCriticLayer with Q-critic heads (default critic_type = QCritic()), Gaussian actor on continuous actions, and optional shared features.

Keyword arguments match ContinuousActorCriticLayer where applicable (log_std_init, hidden_dims, activation, shared_features, critic_type).

source

Buffers

Drill.RolloutBuffer Type
julia
RolloutBuffer

On-policy rollout storage: stacked observations, actions, rewards, GAE advantages, returns, old log-probs and values for one PPO update.

Typically constructed via RolloutBuffer(observation_space, action_space, gae_lambda, gamma, n_steps, n_envs).

source
Drill.ReplayBuffer Type
julia
ReplayBuffer{T,O,OBS,AC}

A circular buffer for storing multiple trajectories of off-policy experience data, used for replay-based learning algorithms.

Truncation Logic

  • If terminated = true, then there should be no truncated_observation

  • If truncated = true, then there should be a truncated_observation

  • If terminated = false and truncated = false, the rollout stopped mid-episode, so there should also be a truncated_observation

source
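
These three bullets collapse to a single invariant, sketched here as a hypothetical helper (not part of the Drill API): a truncated_observation is stored exactly when the episode did not terminate.

```julia
# Invariant check for one stored transition (illustrative only).
valid_transition(terminated::Bool, has_truncated_observation::Bool) =
    has_truncated_observation == !terminated
```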

Wrappers

Drill.MultiThreadedParallelEnv Type
julia
MultiThreadedParallelEnv(envs::Vector)

Parallel environment that steps sub-environments concurrently with @threads (same observation/action spaces, homogeneous env type).

Use for CPU-bound envs when parallel rollout helps; compare BroadcastedParallelEnv.

source
Drill.BroadcastedParallelEnv Type
julia
BroadcastedParallelEnv(envs::Vector)

Vectorized parallel environment: act!, observe, etc. broadcast over envs on a single thread (same spaces, homogeneous type).

Prefer when threading overhead dominates or env stepping is already cheap.

source
Drill.NormalizeWrapperEnv Type
julia
NormalizeWrapperEnv

AbstractParallelEnvWrapper that optionally normalizes observations and/or rewards using running statistics (RunningMeanStd), with clipping. Used in training to stabilize value learning.

Toggle training vs inference behavior with set_training / is_training; sync stats across parallel copies with sync_normalization_stats! when needed.

source
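
A typical train/eval pairing might look like the sketch below; the single-argument NormalizeWrapperEnv constructor and make_env are assumptions for illustration:

```julia
train_env = NormalizeWrapperEnv(BroadcastedParallelEnv([make_env() for _ in 1:4]))
eval_env  = NormalizeWrapperEnv(BroadcastedParallelEnv([make_env()]))

# After (or during) training: share stats, then freeze them for evaluation.
sync_normalization_stats!(eval_env, train_env)
eval_env = set_training(eval_env, false)
```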
Drill.ScalingWrapperEnv Type
julia
ScalingWrapperEnv

Maps observation and action spaces to a normalized Box via an affine transform so policies see consistent bounds. Wraps a single AbstractEnv.

Construct with ScalingWrapperEnv(env) or ScalingWrapperEnv(env, orig_obs_box, orig_act_box).

source
Drill.MonitorWrapperEnv Type
julia
MonitorWrapperEnv(env, stats_window=100)

Wraps a parallel environment to track per-env episode returns and lengths in rolling buffers (EpisodeStats), exposing them via get_info for logging and evaluate_agent.

Use when you want stable episode metrics under vectorized resets.

source
Drill.EpisodeStats Type
julia
EpisodeStats{T}(stats_window)

Rolling buffers of recent finished-episode returns and lengths (used by MonitorWrapperEnv).

source
Drill.RunningMeanStd Type
julia
RunningMeanStd{T}

Tracks running mean and standard deviation using Welford's online algorithm. Similar to stable-baselines3's RunningMeanStd.

source
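
The batched (parallel) form of Welford's update can be sketched as follows; this mirrors the stable-baselines3 algorithm, not necessarily Drill's exact fields:

```julia
using Statistics: mean, var

mutable struct RMS
    mean::Float64
    var::Float64
    count::Float64
end

function update!(rms::RMS, batch::AbstractVector{<:Real})
    b_mean  = mean(batch)
    b_var   = var(batch; corrected=false)
    b_count = length(batch)
    delta = b_mean - rms.mean
    total = rms.count + b_count
    # Combine the two (mean, var, count) summaries without revisiting old data.
    m2 = rms.var * rms.count + b_var * b_count +
         delta^2 * rms.count * b_count / total
    rms.mean += delta * b_count / total
    rms.var   = m2 / total
    rms.count = total
    return rms
end
```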
Drill.set_training Function
julia
set_training(env, training::Bool)

Return an environment with training mode set when applicable (e.g. NormalizeWrapperEnv); default no-op for other envs.

source
Drill.is_training Function
julia
is_training(env) -> Bool

Whether env is in training mode (obs/reward normalization updates when wrapped with NormalizeWrapperEnv); default true for other envs.

source
Drill.sync_normalization_stats! Function
julia
sync_normalization_stats!(eval_env::NormalizeWrapperEnv, train_env::NormalizeWrapperEnv)

Copy running normalization statistics from train_env to eval_env so evaluation uses the same obs/reward scaling.

source

Deployment

Drill.extract_policy Function
julia
extract_policy(agent) -> NeuralPolicy

Create a lightweight deployment policy from a trained agent.

source
Drill.NeuralPolicy Type
julia
NeuralPolicy

Lightweight inference policy holding a trained layer, Lux parameters/states, the environment action space, and an AbstractActionAdapter. Built via extract_policy; callable on batched or single observations to produce environment actions.

source
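
Typical deployment usage, as a sketch:

```julia
policy = extract_policy(agent)   # lightweight NeuralPolicy
action = policy(observe(env))    # single observation -> environment action
act!(env, action)
```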
Drill.NormWrapperPolicy Type
julia
NormWrapperPolicy

Wraps a NeuralPolicy (or compatible policy) with observation normalization from a NormalizeWrapperEnv, matching training-time obs scaling at deployment.

Constructed by extract_policy(agent, normalize_env).

source

Logging

Drill.AbstractTrainingLogger Type
julia
AbstractTrainingLogger

Pluggable training log sink: implement set_step!, log_scalar!, log_metrics!, flush!, close! (and optionally increment_step!, log_hparams!). Use NoTrainingLogger to disable logging.

Concrete backends live in package extensions (TensorBoard, Wandb, DearDiary).

source
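
A custom backend only needs the methods listed above; here is a sketch of a console logger (PrintLogger is invented for the example):

```julia
import Drill: AbstractTrainingLogger, set_step!, log_scalar!, log_metrics!,
              flush!, close!

mutable struct PrintLogger <: AbstractTrainingLogger
    step::Int
end
PrintLogger() = PrintLogger(0)

set_step!(l::PrintLogger, step::Integer) = (l.step = step; nothing)
log_scalar!(l::PrintLogger, key::AbstractString, value::Real) =
    println("step $(l.step): $key = $value")
function log_metrics!(l::PrintLogger, kv::AbstractDict{<:AbstractString,<:Any})
    for (k, v) in kv
        log_scalar!(l, k, v)
    end
end
flush!(l::PrintLogger) = nothing
close!(l::PrintLogger) = nothing
```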
Drill.NoTrainingLogger Type
julia
NoTrainingLogger

No-op AbstractTrainingLogger; all logging methods are silent.

source
Drill.log_scalar! Function
julia
log_scalar!(logger::AbstractTrainingLogger, key::AbstractString, value::Real)

Log a single scalar metric under key.

source
Drill.log_metrics! Function
julia
log_metrics!(logger::AbstractTrainingLogger, kv::AbstractDict{<:AbstractString,<:Any})

Log multiple metrics at once from a string- or symbol-keyed dictionary, or a NamedTuple. Custom logger backends only need to implement the string-keyed dictionary method.

source
Drill.set_step! Function
julia
set_step!(logger::AbstractTrainingLogger, step::Integer)

Set the global step for subsequent metric logs.

source
Drill.increment_step! Function
julia
increment_step!(logger::AbstractTrainingLogger, delta::Integer)

Advance the logger step counter by delta (optional for backends that only use set_step!).

source
Drill.log_hparams! Function
julia
log_hparams!(logger::AbstractTrainingLogger, hparams::AbstractDict{<:AbstractString,<:Any}, metrics::AbstractVector{<:AbstractString})

Write hyperparameters and associate them with specified metrics for hyperparameter tuning.

source
Drill.flush! Function
julia
flush!(logger::AbstractTrainingLogger)

Ensure any buffered data is pushed to the backend. Implementations may no-op.

source
Drill.close! Function
julia
close!(logger::AbstractTrainingLogger)

Finalize the logger and release resources. Implementations may no-op.

source

Callbacks

Drill.AbstractCallback Type
julia
AbstractCallback

Hook type for training: implement any of on_training_start, on_rollout_start, on_rollout_end, on_step, on_training_end. Each receives (callback, locals::Dict) and must return true to continue or false to stop training.

locals is built with Base.@locals inside train! and contains loop variables (agent, env, step counters, etc.).

source
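
For example, a callback that stops training once a step budget is reached; StopAtStep is hypothetical, and the :agent key follows the Base.@locals convention noted above:

```julia
import Drill: AbstractCallback, on_step

struct StopAtStep <: AbstractCallback
    max_steps::Int
end

# Returning false stops training.
on_step(cb::StopAtStep, locals::Dict) = steps_taken(locals[:agent]) < cb.max_steps
```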
Drill.on_training_start Function
julia
on_training_start(callback, locals) -> Bool

Called once at the start of train!. Default: true (continue).

source
Drill.on_training_end Function
julia
on_training_end(callback, locals) -> Bool

Called when training finishes normally. Default: true.

source
Drill.on_rollout_start Function
julia
on_rollout_start(callback, locals) -> Bool

Called at the beginning of each rollout collection phase. Default: true.

source
Drill.on_rollout_end Function
julia
on_rollout_end(callback, locals) -> Bool

Called after rollout data is collected, before gradient updates. Default: true.

source
Drill.on_step Function
julia
on_step(callback, locals) -> Bool

Optional per-step hook when algorithms emit it (default implementation: true). Override in subtypes as needed.

source