Trainer¶
BanditTrainer¶
class numpy_ml.bandits.trainer.BanditTrainer[source]¶
An object to facilitate multi-armed bandit training, comparison, and evaluation.
compare(policies, bandit, n_trials, n_duplicates, plot=True, seed=None, smooth_weight=0.999, out_dir=None)[source]¶
Compare the performance of multiple policies on the same bandit environment, generating a plot for each.
Parameters:
- policies (list of BanditPolicyBase instances) – The multi-armed bandit policies to compare.
- bandit (Bandit instance) – The environment to train the policies on.
- n_trials (int) – The number of trials per run.
- n_duplicates (int) – The number of times to evaluate each policy on the bandit environment. Larger values permit a better estimate of the variance in payoff / cumulative regret for each policy.
- plot (bool) – Whether to generate a plot of each policy's average reward and regret across the episodes. Default is True.
- seed (int) – The seed for the random number generator. Default is None.
- smooth_weight (float in [0, 1]) – The smoothing weight. Values closer to 0 result in less smoothing; values closer to 1 produce more aggressive smoothing. Default is 0.999.
- out_dir (str or None) – Plots will be saved to this directory if plot is True. If out_dir is None, plots will not be saved. Default is None.
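A minimal usage sketch for compare follows. The BernoulliBandit, EpsilonGreedy, and UCB1 names are drawn from the rest of the numpy_ml.bandits package; their exact constructor arguments are assumptions and may differ in your version.

from numpy_ml.bandits import BernoulliBandit
from numpy_ml.bandits.policies import EpsilonGreedy, UCB1
from numpy_ml.bandits.trainer import BanditTrainer

# A 3-armed Bernoulli bandit; the payoff_probs argument is an assumption
bandit = BernoulliBandit(payoff_probs=[0.1, 0.5, 0.8])
policies = [EpsilonGreedy(epsilon=0.1), UCB1(C=1)]

trainer = BanditTrainer()
trainer.compare(
    policies,
    bandit,
    n_trials=1000,    # pulls per duplicate run
    n_duplicates=5,   # independent runs per policy
    plot=False,       # skip plotting, e.g., in a headless environment
    seed=12345,
)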
train(policy, bandit, n_trials, n_duplicates, plot=True, axes=None, verbose=True, print_every=100, smooth_weight=0.999, out_dir=None)[source]¶
Train a MAB policy on a multi-armed bandit problem, logging training statistics along the way.
Parameters:
- policy (BanditPolicyBase instance) – The multi-armed bandit policy to train.
- bandit (Bandit instance) – The environment to run the policy on.
- n_trials (int) – The number of trials per run.
- n_duplicates (int) – The number of times to evaluate the policy on the bandit environment.
- plot (bool) – Whether to generate a plot of the policy's average reward and regret across the episodes. Default is True.
- axes (list of Axis instances or None) – If not None and plot is True, these are the axes that will be used to plot the cumulative reward and regret, respectively. Default is None.
- verbose (bool) – Whether to print run statistics during training. Default is True.
- print_every (int) – The number of episodes to run before printing loss values to stdout. This is ignored if verbose is False. Default is 100.
- smooth_weight (float in [0, 1]) – The smoothing weight. Values closer to 0 result in less smoothing; values closer to 1 produce more aggressive smoothing. Default is 0.999.
- out_dir (str or None) – Plots will be saved to this directory if plot is True. If out_dir is None, plots will not be saved. Default is None.
Returns:
- policy (BanditPolicyBase instance) – The policy trained during the last (i.e., most recent) duplicate run.
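A minimal sketch of training a single policy, using the same assumed BernoulliBandit / EpsilonGreedy setup as in the compare example above; the constructor arguments are assumptions.

from numpy_ml.bandits import BernoulliBandit
from numpy_ml.bandits.policies import EpsilonGreedy
from numpy_ml.bandits.trainer import BanditTrainer

bandit = BernoulliBandit(payoff_probs=[0.2, 0.6, 0.9])  # payoff_probs is an assumption
trainer = BanditTrainer()

# train returns the policy object from the final (most recent) duplicate run
policy = trainer.train(
    EpsilonGreedy(epsilon=0.05),
    bandit,
    n_trials=500,
    n_duplicates=3,
    plot=False,
    verbose=True,
    print_every=100,
)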
init_logs(policies)[source]¶
Initialize the episode logs.
Notes
Training logs are represented as a nested set of dictionaries with the following structure:
log[model_id][metric][trial_number][duplicate_number]
For example, logs['model1']['regret'][3][1] holds the regret value accrued on the 3rd trial of the 2nd duplicate run for model1.
Available fields are 'regret', 'cregret' (cumulative regret), 'reward', 'mse' (mean-squared error between estimated arm EVs and the true EVs), 'optimal_arm', 'selected_arm', and 'optimal_reward'.
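A short sketch of reading the nested log structure after training. Only the log[model_id][metric][trial_number][duplicate_number] layout is documented above; whether the trainer exposes these logs as a logs attribute is an assumption.

logs = trainer.logs  # hypothetical attribute holding the nested log dictionaries
model_id = next(iter(logs))

# Regret accrued on the 3rd trial of the 2nd duplicate run for this model
regret_3_2 = logs[model_id]["regret"][3][1]

# Average cumulative regret on the final trial, across all duplicate runs
final_trial = max(logs[model_id]["cregret"])
dups = logs[model_id]["cregret"][final_trial]
avg_final_cregret = sum(dups.values()) / len(dups)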