API Reference
Environments
DrillInterface.AbstractEnv Type
AbstractEnv

Abstract base type for all reinforcement learning environments.

Subtypes must implement the following methods:

- reset!(env): Reset the environment
- act!(env, action): Take an action and return the reward
- observe(env): Get the current observation
- terminated(env): Check if the episode has terminated
- truncated(env): Check if the episode was truncated
- action_space(env): Get the action space
- observation_space(env): Get the observation space
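As a minimal sketch of what a subtype might look like, assuming only the interface above: the toy WalkEnv environment and its dynamics are invented for illustration, and a stub abstract type stands in so the sketch runs standalone (in real code the abstract type and method names come from DrillInterface).

```julia
# Stand-in so the sketch is self-contained; in practice subtype
# DrillInterface.AbstractEnv and extend the DrillInterface methods.
abstract type AbstractEnv end

# A toy 1-D random walk: action 1 steps left, action 2 steps right.
mutable struct WalkEnv <: AbstractEnv
    pos::Int
    steps::Int
end
WalkEnv() = WalkEnv(0, 0)

reset!(env::WalkEnv) = (env.pos = 0; env.steps = 0; nothing)

function act!(env::WalkEnv, action::Integer)
    env.pos += action == 1 ? -1 : 1
    env.steps += 1
    return Float64(env.pos)          # reward: current position
end

observe(env::WalkEnv) = Float32[env.pos]
terminated(env::WalkEnv) = abs(env.pos) >= 5   # terminal state at the walls
truncated(env::WalkEnv) = env.steps >= 100     # episode time limit

# action_space/observation_space would return space objects such as
# Discrete(2) and Box(Float32[-5], Float32[5]) from DrillInterface.
```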
DrillInterface.AbstractParallelEnv Type
AbstractParallelEnv <: AbstractEnv

Abstract type for vectorized/parallel environments that manage multiple environment instances.
Key Differences from AbstractEnv
| Method | Single Env | Parallel Env |
|---|---|---|
observe | Returns one observation | Returns vector of observations |
act! | Returns reward | Returns (rewards, terminateds, truncateds, infos) |
terminated | Returns Bool | Returns Vector{Bool} |
truncated | Returns Bool | Returns Vector{Bool} |
Auto-Reset Behavior
Parallel environments automatically reset individual sub-environments when they terminate or truncate. The terminal observation is stored in infos[i]["terminal_observation"] before reset.
DrillInterface.AbstractParallelEnvWrapper Type
AbstractParallelEnvWrapper{E}

Wraps a vectorized AbstractParallelEnv (e.g. for normalization or monitoring) while remaining an AbstractParallelEnv itself.
DrillInterface.reset! Function
reset!(env::AbstractEnv) -> Nothing

Reset the environment to its initial state.
Arguments
env::AbstractEnv: The environment to reset
Returns
Nothing
DrillInterface.act! Function
act!(env::AbstractEnv, action) -> reward

Take an action in the environment and return the reward.
Arguments
- env::AbstractEnv: The environment to act in
- action: The action to take (type depends on the environment's action space)
Returns
reward: Numerical reward from taking the action
DrillInterface.observe Function
observe(env::AbstractEnv) -> observation

Get the current observation from the environment.
Arguments
env::AbstractEnv: The environment to observe
Returns
observation: Current state observation (type/shape depends on environment's observation space)
DrillInterface.terminated Function
terminated(env::AbstractEnv) -> Bool

Check if the environment episode has terminated by reaching a terminal state.
Arguments
env::AbstractEnv: The environment to check
Returns
Bool: true if the episode has terminated, false otherwise
DrillInterface.truncated Function
truncated(env::AbstractEnv) -> Bool

Check if the environment episode has been truncated (e.g., a time limit was reached).
Arguments
env::AbstractEnv: The environment to check
Returns
Bool: true if the episode has been truncated, false otherwise
DrillInterface.action_space Function
action_space(env::AbstractEnv) -> AbstractSpace

Get the action space specification for the environment.
Arguments
env::AbstractEnv: The environment
Returns
AbstractSpace: The action space (e.g., Box, Discrete)
DrillInterface.observation_space Function
observation_space(env::AbstractEnv) -> AbstractSpace

Get the observation space specification for the environment.
Arguments
env::AbstractEnv: The environment
Returns
AbstractSpace: The observation space (e.g., Box, Discrete)
DrillInterface.get_info Function
get_info(env::AbstractEnv) -> Dict

Get additional environment information (metadata, debug info, etc.).
Arguments
env::AbstractEnv: The environment
Returns
Dict: Dictionary containing environment-specific information
DrillInterface.number_of_envs Function
number_of_envs(env::AbstractParallelEnv) -> Int

Get the number of parallel environments in a parallel environment wrapper.
Arguments
env::AbstractParallelEnv: The parallel environment
Returns
Int: Number of parallel environments
Spaces
DrillInterface.AbstractSpace Type
AbstractSpace

Abstract base type for all observation and action spaces in Drill.jl. Concrete subtypes include Box (continuous) and Discrete (finite actions).
DrillInterface.Box Type
Box{T <: Number} <: AbstractSpace

A continuous space with lower and upper bounds per dimension.
Fields
- low::Array{T}: Lower bounds for each dimension
- high::Array{T}: Upper bounds for each dimension
- shape::Tuple{Vararg{Int}}: Shape of the space
Example
# 2D box with different bounds per dimension
space = Box(Float32[-1, -2], Float32[1, 3])
# Uniform bounds
space = Box(-1.0f0, 1.0f0, (4,))
DrillInterface.Discrete Type
Discrete{T <: Integer} <: AbstractSpace

A discrete space representing a finite set of integer actions.
Fields
- n::T: Number of discrete actions
- start::T: Lowest action value
Example
space = Discrete(4) # Actions: 1, 2, 3, 4
space = Discrete(4, 0) # Actions: 0, 1, 2, 3
Algorithms
Drill.PPO Type
PPO{T <: AbstractFloat} <: OnPolicyAlgorithm

Proximal Policy Optimization algorithm.
Fields
- gamma: Discount factor (default: 0.99)
- gae_lambda: GAE lambda for advantage estimation (default: 0.95)
- clip_range: PPO clipping parameter (default: 0.2)
- ent_coef: Entropy coefficient (default: 0.0)
- vf_coef: Value function coefficient (default: 0.5)
- max_grad_norm: Maximum gradient norm for clipping (default: 0.5)
- n_steps: Steps per rollout before update (default: 2048)
- batch_size: Minibatch size (default: 64)
- epochs: Number of epochs per update (default: 10)
- learning_rate: Optimizer learning rate (default: 3e-4)
Example
ppo = PPO(gamma=0.99f0, n_steps=2048, epochs=10)
agent = Agent(model, ppo)
train!(agent, env, ppo, 100_000)
Drill.SAC Type
SAC{T <: AbstractFloat, E <: AbstractEntropyCoefficient} <: OffPolicyAlgorithm

Soft Actor-Critic algorithm with automatic entropy tuning.
Fields
- learning_rate: Optimizer learning rate (default: 3e-4)
- buffer_capacity: Replay buffer size (default: 1M)
- start_steps: Random exploration steps before training (default: 100)
- batch_size: Batch size for updates (default: 256)
- tau: Soft update coefficient for target networks (default: 0.005)
- gamma: Discount factor (default: 0.99)
- train_freq: Steps between gradient updates (default: 1)
- gradient_steps: Gradient steps per update, -1 for auto (default: 1)
- ent_coef: Entropy coefficient (AutoEntropyCoefficient or FixedEntropyCoefficient)
Example
sac = SAC(learning_rate=3f-4, buffer_capacity=1_000_000)
model = SACLayer(obs_space, act_space)
agent = Agent(model, sac)
train!(agent, env, sac, 500_000)
Drill.train! Function
train!(agent, env, alg::SAC, max_steps; kwargs...)
train!(agent, replay_buffer, env, alg::SAC, max_steps; callbacks=nothing, ad_type=AutoZygote())

Train a SAC agent on env. The four-argument form allocates a ReplayBuffer from the environment spaces with capacity alg.buffer_capacity.
The five-argument form reuses an existing replay_buffer (same observation/action spaces as env).
Keyword arguments
- callbacks: Optional vector of AbstractCallback hooks.
- ad_type: Lux AD backend for gradient computation (default AutoZygote()).
Returns (agent, replay_buffer, training_stats).
train!(agent, env, alg::PPO, max_steps; ad_type=AutoZygote(), callbacks=nothing)

Run PPO training on a parallel environment for up to max_steps environment steps (total across all sub-environments).
Rollouts use alg.n_steps steps per sub-environment per iteration. Training stops early if a callback returns false.
Keyword arguments
- ad_type: Lux AD backend for compute_gradients (default AutoZygote()).
- callbacks: Optional vector of AbstractCallback hooks; see on_training_start, on_rollout_start, etc.
Returns nothing (mutates agent in place). On early exit from callbacks, returns nothing without completing the full schedule.
Agents
Drill.Agent Type
Unified Agent for all algorithms.
verbose: 0 = no output, 1 = progress bar, 2 = progress bar and stats
Drill.predict_actions Function
predict_actions(layer::AbstractLayer, obs::AbstractArray, ps, st; deterministic::Bool=false) -> (actions, st)

Predict actions from batched observations.
Arguments
- layer::AbstractLayer: The actor-critic layer
- obs::AbstractArray: Batched observations (the last dimension is the batch)
- ps: Layer parameters
- st: Layer state
- deterministic::Bool=false: Whether to use deterministic actions
Returns
- actions: Vector/Array of actions (raw layer outputs, not processed for the environment)
- st: Updated layer state
Notes
- Input observations must be batched (matrix/array format)
- Output actions are raw layer outputs (e.g., 1-based for Discrete layers)
- Use to_env() to convert actions for environment use
Drill.predict_values Function
predict_values(layer::AbstractLayer, obs::AbstractArray, [actions::AbstractArray,] ps, st) -> (values, st)

Predict Q-values from batched observations and actions (for Q-Critic layers).
Arguments
- layer::AbstractLayer: The actor-critic layer
- obs::AbstractArray: Batched observations (the last dimension is the batch)
- actions::AbstractArray: Batched actions (the last dimension is the batch; only for Q-Critic layers)
- ps: Layer parameters
- st: Layer state
Returns
- values: Batched values (tuples of values for multiple Q-Critic networks)
- st: Updated layer state
Notes
- Input observations and actions must be batched (matrix/array format)
- Actions should be in raw layer format (e.g., 1-based for Discrete)
Drill.steps_taken Function
steps_taken(stats::AgentStats) -> Int

Number of environment steps recorded in stats (updated by internal training loops via add_step!).

steps_taken(agent::Agent) -> Int

Total environment steps taken by the agent during training (the same count as steps_taken(agent.stats)).
Drill.evaluate_agent Function
evaluate_agent(agent, env; kwargs...)

Evaluate a policy/agent for a specified number of episodes and return performance statistics.
Arguments
- agent: The agent to evaluate (must implement a predict method)
- env: The environment to evaluate on (single env or parallel env)
Keyword Arguments
- n_eval_episodes::Int = 10: Number of episodes to evaluate
- deterministic::Bool = true: Whether to use deterministic actions
- render::Bool = false: Whether to render the environment
- callback::Union{Nothing, Function} = nothing: Optional callback function called after each step
- reward_threshold::Union{Nothing, Real} = nothing: Minimum expected mean reward (throws an error if not met)
- return_episode_rewards::Bool = false: If true, returns individual episode rewards and lengths
- warn::Bool = true: Whether to warn about a missing Monitor wrapper
- rng::AbstractRNG = Random.default_rng(): Random number generator for reproducible evaluation
Returns
- If return_episode_rewards = false: (mean_reward::Float64, std_reward::Float64)
- If return_episode_rewards = true: (episode_rewards::Vector{Float64}, episode_lengths::Vector{Int})
Notes
- Episodes are distributed evenly across parallel environments to remove bias
- If the environment is wrapped with Monitor, episode statistics from Monitor are used
- Otherwise, rewards and lengths are tracked manually during evaluation
- For environments with reward/length-modifying wrappers, consider using the Monitor wrapper
Examples
# Basic evaluation
mean_reward, std_reward = evaluate_agent(agent, env; n_eval_episodes=20)
# Get individual episode data
episode_rewards, episode_lengths = evaluate_agent(agent, env;
return_episode_rewards=true, deterministic=false)
# Evaluation with threshold check
mean_reward, std_reward = evaluate_agent(agent, env;
reward_threshold=100.0, n_eval_episodes=50)
Drill.AbstractActionAdapter Type
AbstractActionAdapter

Maps between policy outputs and environment actions (see to_env and from_env); concrete types include ClampAdapter, TanhScaleAdapter, and DiscreteAdapter.
Drill.to_env Function
to_env(adapter, policy_action, space::AbstractSpace)

Convert an action from the policy/model's action domain to the environment's action space. Called right before stepping the environment.
Drill.from_env Function
from_env(adapter, env_action, space::AbstractSpace)

Optionally convert an environment action back to the policy/model's action domain. Useful for some off-policy training flows. Default: identity where appropriate.
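As an illustration of the adapter contract, here is a hedged sketch of a tanh-scaling adapter in the spirit of TanhScaleAdapter. The stub types and the exact transform are assumptions made so the sketch runs standalone; only the to_env/from_env signatures come from this reference.

```julia
# Stand-ins; in Drill these roles are played by Drill.AbstractActionAdapter
# and DrillInterface.Box.
abstract type AbstractActionAdapter end
struct BoxStub
    low::Vector{Float32}
    high::Vector{Float32}
end
struct TanhScale <: AbstractActionAdapter end

# Policy domain (unbounded, squashed by tanh) -> environment [low, high].
to_env(::TanhScale, policy_action, space::BoxStub) =
    space.low .+ (tanh.(policy_action) .+ 1f0) ./ 2f0 .* (space.high .- space.low)

# Environment action back to the policy domain (approximate inverse;
# clamped to keep atanh finite at the bounds).
from_env(::TanhScale, env_action, space::BoxStub) =
    atanh.(clamp.(2f0 .* (env_action .- space.low) ./ (space.high .- space.low) .- 1f0,
                  -0.999f0, 0.999f0))
```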
Layers
Drill.ActorCriticLayer Function
ActorCriticLayer(observation_space, action_space::Box; kwargs...) -> ContinuousActorCriticLayer
ActorCriticLayer(observation_space, action_space::Discrete; kwargs...) -> DiscreteActorCriticLayer

Unified constructor that forwards to ContinuousActorCriticLayer or DiscreteActorCriticLayer depending on the action space type.
Drill.ContinuousActorCriticLayer Type
ContinuousActorCriticLayer

Actor-critic Lux model for continuous (Box) or discrete observations with continuous actions: feature extractor, stochastic actor (with learnable log-std where applicable), and value or Q heads according to critic_type.
Use ContinuousActorCriticLayer(observation_space, action_space::Box; kwargs...) to build a layer.
Drill.DiscreteActorCriticLayer Type
DiscreteActorCriticLayer

Actor-critic Lux model for discrete actions: feature extractor, categorical policy head, and value head.
Use DiscreteActorCriticLayer(observation_space, action_space::Discrete; kwargs...).
Drill.SACLayer Function
SACLayer(observation_space, action_space::Box; kwargs...)

Convenience constructor for SAC: builds a ContinuousActorCriticLayer with Q-critic heads (default critic_type = QCritic()), a Gaussian actor on continuous actions, and optional shared features.
Keyword arguments match ContinuousActorCriticLayer where applicable (log_std_init, hidden_dims, activation, shared_features, critic_type).
Buffers
Drill.RolloutBuffer Type
RolloutBuffer

On-policy rollout storage: stacked observations, actions, rewards, GAE advantages, returns, old log-probs, and values for one PPO update.
Typically constructed via RolloutBuffer(observation_space, action_space, gae_lambda, gamma, n_steps, n_envs).
Drill.ReplayBuffer Type
ReplayBuffer{T,O,OBS,AC}

A circular buffer for storing multiple trajectories of off-policy experience data, used by replay-based learning algorithms.
Truncation Logic
- If terminated = true, there should be no truncated_observation
- If truncated = true, there should be a truncated_observation
- If terminated = false and truncated = false, the trajectory stopped in the middle of an episode, so there should be a truncated_observation
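The truncation logic above can be stated as a single predicate. This sketch uses an illustrative stand-in record (the field names mirror this reference, but the struct itself is not the package's internal layout):

```julia
# Illustrative transition record; not the package's actual storage format.
struct Transition
    terminated::Bool
    truncated::Bool
    truncated_observation::Union{Nothing, Vector{Float32}}
end

# A truncated_observation must be present exactly when the episode did
# NOT terminate (it was truncated, or the trajectory was cut mid-episode).
is_consistent(t::Transition) =
    t.terminated ? t.truncated_observation === nothing :
                   t.truncated_observation !== nothing
```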
Wrappers
Drill.MultiThreadedParallelEnv Type
MultiThreadedParallelEnv(envs::Vector)

Parallel environment that steps sub-environments concurrently with @threads (same observation/action spaces, homogeneous env type).
Use for CPU-bound envs when parallel rollout helps; compare BroadcastedParallelEnv.
Drill.BroadcastedParallelEnv Type
BroadcastedParallelEnv(envs::Vector)

Vectorized parallel environment: act!, observe, etc. broadcast over envs on a single thread (same spaces, homogeneous type).
Prefer when threading overhead dominates or env stepping is already cheap.
Drill.NormalizeWrapperEnv Type
NormalizeWrapperEnv

An AbstractParallelEnvWrapper that optionally normalizes observations and/or rewards using running statistics (RunningMeanStd), with clipping. Used in training to stabilize value learning.
Toggle training vs inference behavior with set_training / is_training; sync stats across parallel copies with sync_normalization_stats! when needed.
Drill.ScalingWrapperEnv Type
ScalingWrapperEnv

Maps the observation and action spaces of an AbstractEnv to a normalized Box.
Construct with ScalingWrapperEnv(env) or ScalingWrapperEnv(env, orig_obs_box, orig_act_box).
Drill.MonitorWrapperEnv Type
MonitorWrapperEnv(env, stats_window=100)

Wraps a parallel environment to track per-env episode returns and lengths in rolling buffers (EpisodeStats), exposing them via get_info for logging and evaluate_agent.
Use when you want stable episode metrics under vectorized resets.
Drill.EpisodeStats Type
EpisodeStats{T}(stats_window)

Rolling buffers of recent finished-episode returns and lengths (used by MonitorWrapperEnv).
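A hedged sketch of the rolling-window idea behind EpisodeStats; the type and methods below are invented for illustration, and stats_window (the window capacity) is the only parameter taken from this reference.

```julia
using Statistics: mean

# Keep only the last `capacity` finished-episode returns.
struct RollingWindow
    capacity::Int
    data::Vector{Float64}
end
RollingWindow(capacity::Int) = RollingWindow(capacity, Float64[])

function record!(w::RollingWindow, episode_return::Real)
    push!(w.data, episode_return)
    length(w.data) > w.capacity && popfirst!(w.data)  # drop the oldest
    return w
end

recent_mean(w::RollingWindow) = isempty(w.data) ? 0.0 : mean(w.data)
```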
Drill.RunningMeanStd Type
RunningMeanStd{T}

Tracks running mean and standard deviation using Welford's online algorithm. Similar to stable-baselines3's RunningMeanStd.
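For reference, the Welford update such a tracker performs looks roughly like this. A self-contained sketch; the package type's actual fields and update method may differ.

```julia
# Online mean/variance via Welford's algorithm.
mutable struct RunningStats
    mean::Float64
    m2::Float64    # running sum of squared deviations from the mean
    count::Int
end
RunningStats() = RunningStats(0.0, 0.0, 0)

function update!(rs::RunningStats, x::Real)
    rs.count += 1
    delta = x - rs.mean
    rs.mean += delta / rs.count
    rs.m2 += delta * (x - rs.mean)   # note: uses the updated mean
    return rs
end

variance(rs::RunningStats) = rs.count > 1 ? rs.m2 / rs.count : 0.0
```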
Drill.set_training Function
set_training(env, training::Bool)

Return an environment with training mode set when applicable (e.g. NormalizeWrapperEnv); a no-op by default for other envs.
Drill.is_training Function
is_training(env) -> Bool

Whether env is in training mode (observation/reward normalization statistics update when wrapped with NormalizeWrapperEnv); defaults to true for other envs.
Drill.sync_normalization_stats! Function
sync_normalization_stats!(eval_env::NormalizeWrapperEnv, train_env::NormalizeWrapperEnv)

Copy running normalization statistics from train_env to eval_env so evaluation uses the same observation/reward scaling.
Deployment
Drill.extract_policy Function
extract_policy(agent) -> NeuralPolicy

Create a lightweight deployment policy from a trained agent.
Drill.NeuralPolicy Type
NeuralPolicy

Lightweight inference policy holding a trained layer, Lux parameters/states, the environment action space, and an AbstractActionAdapter. Built via extract_policy; callable on batched or single observations to produce environment actions.
Drill.NormWrapperPolicy Type
NormWrapperPolicy

Wraps a NeuralPolicy (or compatible policy) with observation normalization from a NormalizeWrapperEnv, matching training-time observation scaling at deployment.
Constructed by extract_policy(agent, normalize_env).
Logging
Drill.AbstractTrainingLogger Type
AbstractTrainingLogger

Pluggable training log sink: implement set_step!, log_scalar!, log_metrics!, flush!, close! (and optionally increment_step!, log_hparams!). Use NoTrainingLogger to disable logging.
Concrete backends live in package extensions (TensorBoard, Wandb, DearDiary).
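A hedged sketch of a custom backend implementing the minimal surface above. The MemoryLogger type is invented for illustration, and a stub abstract type stands in for Drill's so the sketch runs standalone; only the method names and signatures come from this reference.

```julia
# Stand-in; in practice subtype Drill.AbstractTrainingLogger.
abstract type AbstractTrainingLogger end

# Illustrative backend that records (step, key, value) triples in memory.
mutable struct MemoryLogger <: AbstractTrainingLogger
    step::Int
    history::Vector{Tuple{Int,String,Float64}}
end
MemoryLogger() = MemoryLogger(0, Tuple{Int,String,Float64}[])

set_step!(lg::MemoryLogger, step::Integer) = (lg.step = step; nothing)

log_scalar!(lg::MemoryLogger, key::AbstractString, value::Real) =
    (push!(lg.history, (lg.step, String(key), Float64(value))); nothing)

# Only the string-keyed dictionary method is required of a backend.
function log_metrics!(lg::MemoryLogger, kv::AbstractDict{<:AbstractString,<:Any})
    for (k, v) in kv
        log_scalar!(lg, k, v)
    end
end

flush!(lg::MemoryLogger) = nothing   # nothing buffered; may no-op
close!(lg::MemoryLogger) = nothing   # no resources to release
```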
Drill.NoTrainingLogger Type
NoTrainingLogger

No-op AbstractTrainingLogger; all logging methods are silent.
Drill.log_scalar! Function
log_scalar!(logger::AbstractTrainingLogger, key::AbstractString, value::Real)

Log a single scalar metric under key.
Drill.log_metrics! Function
log_metrics!(logger::AbstractTrainingLogger, kv::AbstractDict{<:AbstractString,<:Any})

Log multiple metrics at once from a string- or symbol-keyed dictionary, or a NamedTuple. A custom logger backend only needs to implement the string-keyed dictionary method.
Drill.set_step! Function
set_step!(logger::AbstractTrainingLogger, step::Integer)

Set the global step for subsequent metric logs.
sourceDrill.increment_step! Function
increment_step!(logger::AbstractTrainingLogger, delta::Integer)

Advance the logger step counter by delta (optional for backends that only use set_step!).
Drill.log_hparams! Function
log_hparams!(logger::AbstractTrainingLogger, hparams::AbstractDict{<:AbstractString,<:Any}, metrics::AbstractVector{<:AbstractString})

Write hyperparameters and associate them with the specified metrics for hyperparameter tuning.
Drill.flush! Function
flush!(logger::AbstractTrainingLogger)

Ensure any buffered data is pushed to the backend. Implementations may no-op.
Drill.close! Function
close!(logger::AbstractTrainingLogger)

Finalize the logger and release resources. Implementations may no-op.
Callbacks
Drill.AbstractCallback Type
AbstractCallback

Hook type for training: implement any of on_training_start, on_rollout_start, on_rollout_end, on_step, on_training_end. Each receives (callback, locals::Dict) and must return true to continue or false to stop training.
locals is built with Base.@locals inside train! and contains loop variables (agent, env, step counters, etc.).
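For example, a hedged sketch of a budget-stopping callback. The StopAtStep type and the :steps key in locals are assumptions (the reference only says locals holds loop variables such as agent, env, and step counters); the hook signatures and the true/false contract come from this reference. A stub abstract type stands in so the sketch runs standalone.

```julia
# Stand-in; in practice subtype Drill.AbstractCallback.
abstract type AbstractCallback end

# Stops training once a hypothetical step budget is exhausted.
struct StopAtStep <: AbstractCallback
    max_steps::Int
end

# Hooks receive (callback, locals::Dict) and return true to continue.
on_training_start(cb::StopAtStep, locals) = true
on_rollout_end(cb::StopAtStep, locals) =
    get(locals, :steps, 0) < cb.max_steps   # false halts training early
```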
Drill.on_training_start Function
on_training_start(callback, locals) -> Bool

Called once at the start of train!. Default: true (continue).
Drill.on_training_end Function
on_training_end(callback, locals) -> Bool

Called when training finishes normally. Default: true.
Drill.on_rollout_start Function
on_rollout_start(callback, locals) -> Bool

Called at the beginning of each rollout collection phase. Default: true.
Drill.on_rollout_end Function
on_rollout_end(callback, locals) -> Bool

Called after rollout data is collected, before gradient updates. Default: true.
Drill.on_step Function
on_step(callback, locals) -> Bool

Optional per-step hook when algorithms emit it (default implementation: true). Override in subtypes as needed.