Q-learning (fonction de valeur des états-actions)
Présentation
Le Q-learning est une technique d'apprentissage par renforcement. Il utilise la fonction Q (fonction de valeur des états-actions) qui repose sur un tableau que l'on nomme la Q-table. Les index de cette Q-table représentent les différents état du système. Chaque état dispose d'un tableau d'actions où chacune d'elles possède une valeur correspondant à une estimation qualitative de cette l'action.
Démonstration
États
-
Collisions aux alentours de la tête du serpent:
(3 directions possibles = 2**3 = 8 états possibles).
| a = a1 * 4 + a2 * 2 + a3 = ?
-
Direction de la pomme:
(8 directions = 8 états possibles).
| b = ?
-
État:
(8 * 8 = 64 états possibles).
| a * 8 + b = ?
Exploitation
-
Action:
(max( Q-Table[ état ])).
| Q-Table[ ? ][ ?, ?, ? ] = ?
L'Exploitation
La décision optimale est prise à partir de la Q-table. L'agent va y sélectionner l'action possédant la plus grande valeur parmi toutes les actions disponibles pour l'état actuel
L'Apprentissage
Les récompenses étant distribuées après chaque action de l'agent, c'est à ce moment là que la valeur de l'action est mise à jour dans la Q-table à l'aide de la fonction Q.
q_table[état][action] = (1 - alpha) * q_table[état][action] + alpha * (récompense + gamma * q_table[état_suivant][argmax(q_table[état_suivant])])
Le facteur d'actualisation (gamma) modifie l'importance de l'influence de la meilleure action à l'état suivant sur l'action courrante. Il oscille entre 0 et 1. Plus la valeur se rapprochent de 1, plus l'influence sera forte.
Code
// LE MODULO (adapté pour jongler avec les orientations relatives).
var modulo = function(a, b)
{
return ((a % b) + b) % b;
};
// LES ÉTATS.
var State = function(){};
State.prototype.get_state = function(agent, apple, check_wallCollision, check_tailCollision)
{
// Coordonnées de la tête du serpent.
var x = agent.matrix[0][0];
var y = agent.matrix[0][1];
var st1 = this.get_surrounding(x, y, agent, check_wallCollision, check_tailCollision);
var st2 = this.get_appleDirection(x, y, agent.orientation, apple);
// Combiner les états (8*8 état possibles).
return parseInt(st1 * 8 + st2);
};
State.prototype.get_surrounding = function(x, y, agent, check_wallCollision, check_tailCollision)
{
// Présence de collisions aux alentours de la tête du serpent pour les 3 actions (2**3).
var surrounding = [0, 0, 0];
for (var i = surrounding.length - 1; i >= 0; i--)
{
var new_ori = modulo((agent.orientation + i - 1), 4);
var new_pos = [x + agent.moves[new_ori][0], y + agent.moves[new_ori][1]];
surrounding[i] = (check_wallCollision(new_pos) || check_tailCollision(new_pos));
}
return surrounding[0] * 4 + surrounding[1] * 2 + surrounding[2];
};
State.prototype.get_appleDirection = function(x, y, orientation, apple)
{
// Direction de la pomme en fonction de la position et de l'orientatation du serpent (8**1).
var state;
var ax = apple[0];
var ay = apple[1];
// Nord.
if (ax == x && ay < y)
{
state = modulo((0 + 2 * orientation), 8);
}
// Nord-Est.
else if (ax > x && ay < y)
{
state = modulo((7 + 2 * orientation), 8);
}
// Est.
else if (ax > x && ay == y)
{
state = modulo((6 + 2 * orientation), 8);
}
// Sud-Est.
else if (ax > x && ay > y)
{
state = modulo((5 + 2 * orientation), 8);
}
// Sud.
else if (ax == x && ay > y)
{
state = modulo((4 + 2 * orientation), 8);
}
// Sud-Ouest.
else if (ax < x && ay > y)
{
state = modulo((3 + 2 * orientation), 8);
}
// Ouest.
else if (ax < x && ay == y)
{
state = modulo((2 + 2 * orientation), 8);
}
// Nord-Ouest.
else if (ax < x && ay < y)
{
state = modulo((1 + 2 * orientation), 8);
}
return state;
};
// L'ENVIRONNEMENT.
var Env = function(agent, cols_length, rows_length)
{
this.agent = agent;
this.cols_length = cols_length;
this.rows_length = rows_length;
this.apple;
this.time;
};
Env.prototype.restart = function()
{
this.time = 0;
this.agent.reset(this.cols_length, this.rows_length);
this.reset_apple();
};
Env.prototype.step = function(action)
{
this.time += 1;
this.agent.update_orientation(action);
return this.update(this.agent.get_newPosition());
};
Env.prototype.update = function(new_pos)
{
if (!this.check_wallCollision(new_pos) && !this.check_tailCollision(new_pos) && this.time < (this.rows_length * this.cols_length))
{
// Mettre la position de la tête à jour.
this.agent.matrix.unshift(new_pos);
// Pomme mangée.
if (this.apple[0] == this.agent.matrix[0][0] && this.apple[1] == this.agent.matrix[0][1])
{
// Si la taille du serpent est égale au nombre de cases dans l'aire de jeu la partie est gagnée.
if (this.agent.matrix.length == this.rows_length * this.cols_length)
{
// Partie gagnée - grosse récompense positive.
return {reward: this.rows_length * this.cols_length, done: true};
}
this.time = 0;
// Trouver une nouvelle position aléatoire pour la pomme.
this.reset_apple();
// Pomme mangée - récompense positive.
return {reward: 1, done: false};
}
// Effacer l'ancienne position du bout de la queue.
this.agent.matrix.pop(this.agent, this.cols_length, this.rows_length);
// Mouvement du serpend - pas de récompense.
return {reward: 0, done: false};
}
// Partie perdue - récompense négative.
return {reward: -1, done: true};
};
Env.prototype.reset_apple = function()
{
var grid = [];
var agent = JSON.stringify(this.agent.matrix);
for (var c = 0; c < this.cols_length; c++)
{
for (var r = 0; r < this.rows_length; r++)
{
// Inclure uniquement les cases libres.
if (agent.indexOf('['+c+','+r+']') === -1)
{
grid.push([c, r]);
}
}
}
this.apple = grid[Math.floor(Math.random() * grid.length)];
};
Env.prototype.check_wallCollision = function(coordinates)
{
// Vérifier la collision entre des coordonnées et les murs exterieurs du plateau de jeu.
if ((coordinates[0] >= 0 && coordinates[0] < this.cols_length) && (coordinates[1] >= 0 && coordinates[1] < this.rows_length))
{
return false;
}
return true;
};
Env.prototype.check_tailCollision = function(coordinates)
{
//Vérifier la collision entre des coordonnées et les différentes parties de la queue du serpent.
for (var i = this.agent.matrix.length - 1; i > 0; i--)
{
var pos = this.agent.matrix[i];
if (coordinates[0] == pos[0] && coordinates[1] == pos[1])
{
return true;
}
}
return false;
};
// L'AGENT.
var Agent = function(states_length, actions_length, defaultValue, epsilon, alpha, gamma)
{
// Le tableau de valeur des états-actions.
this.q_table = this.init_qTable(states_length, actions_length, defaultValue);
// Taux d'exploration.
this.epsilon = epsilon;
// Taux d'apprentissage.
this.alpha = alpha;
// Facteur d'actualisation.
this.gamma = gamma;
// Déplacement: haut, droite, bas, gauche.
this.moves = [[0, -1], [1, 0], [0, 1], [-1, 0]];
// 0: haut, 1: droite, 2: bas, 3: gauche.
this.orientation;
// Les coordonnées des différents éléments du serpent.
this.matrix;
};
Agent.prototype.init_qTable = function(states_length, actions_length, defaultValue)
{
var q_table = [];
for (var s = 0; s < states_length; s++)
{
q_table[s] = [];
for (var a = 0; a < actions_length; a++)
{
q_table[s][a] = defaultValue;
}
}
return q_table;
};
Agent.prototype.reset = function(cols_length, rows_length)
{
this.orientation = 0;
// Position aléatoire.
this.matrix = [[Math.floor(Math.random() * cols_length), Math.floor(Math.random() * rows_length)]];
};
Agent.prototype.update_orientation = function(action)
{
this.orientation = modulo(this.orientation + action - 1, 4);
};
Agent.prototype.get_newPosition = function()
{
return [this.matrix[0][0] + this.moves[this.orientation][0], this.matrix[0][1] + this.moves[this.orientation][1]];
};
Agent.prototype.choose_action = function(state)
{
// Exploration:
if (Math.random() < this.epsilon)
{
return Math.floor(Math.random() * 3);
}
// Exploitation:
else
{
// Choisir l'action possédant la plus grande valeur.
return this.argmax(this.q_table[state]);
}
};
Agent.prototype.argmax = function(array)
{
var max = -Infinity;
var index;
for (var i = array.length - 1; i >= 0; i--)
{
if (array[i] > max)
{
max = array[i];
index = i;
}
}
return index;
};
Agent.prototype.update_qTable = function(state, action, next_state, reward)
{
// La fonction Q ou fonction de valeur des états-actions.
this.q_table[state][action] = (1 - this.alpha) * this.q_table[state][action] + this.alpha * (reward + this.gamma * this.q_table[next_state][this.argmax(this.q_table[next_state])]);
};
// LE Q-LEARNING.
var Qlearning = function()
{
this.cols_length = 10;
this.rows_length = 10;
this.state = new State();
};
Qlearning.prototype.train = function(epoch, batch_size)
{
var agent = new Agent(64, 3, 0, 1, 0.1, 0.1);
var env = new Env(agent, this.cols_length, this.rows_length);
// Le nombre d'entrainement avec une diminution progressive d'epsilon et d'alpha.
for (var e = 0; e < epoch; e++)
{
var scores = [];
// Le nombre d'entrainement avec les valeurs d'epsilon et d'alpha actuelles.
for (var b = 0; b < batch_size; b++)
{
env.restart();
var state = this.state.get_state(env.agent, env.apple, env.check_wallCollision.bind(env), env.check_tailCollision.bind(env));
var done = false;
while (done === false)
{
var action = env.agent.choose_action(state);
var gameInfos = env.step(action);
var next_state = this.state.get_state(env.agent, env.apple, env.check_wallCollision.bind(env), env.check_tailCollision.bind(env));
var reward = gameInfos.reward;
env.agent.update_qTable(state, action, next_state, reward);
done = gameInfos.done;
state = next_state;
}
scores.push(env.agent.matrix.length);
}
// Afficher les statistiques.
this.display_results(scores, e, env.agent);
// Diminuer le taux d'exploration.
env.agent.epsilon *= 0.95;
// Diminuer le taux d'apprentissage.
env.agent.alpha *= 0.99;
};
};
Qlearning.prototype.display_results = function(array, epoch, agent)
{
var output = 0;
var length = array.length;
for (var i = 0; i < length; i++)
{
output += array[i];
}
var average = Math.round(output / length);
console.log(epoch + " | moyenne: " + average + ", max: " + array[agent.argmax(array)] + ", epsilon: " + Math.round(agent.epsilon * 1000) / 1000 + ", alpha: " + Math.round(agent.alpha * 1000) / 1000 + ", gamma: " + agent.gamma);
};
var demo = new Qlearning();
demo.train(100, 100);