Deep-RL: DQN — Regression or Classification?
A basic explanation of why DQNs take a regression approach to what looks like a classification problem over action values.
This article assumes an understanding of classification and regression in deep learning, and a basic understanding of Reinforcement Learning.
This article addresses a common confusion when approaching DQNs: it seems we are doing regression on a problem that requires classification. What is really going on? By the end of this article, I hope it will be clear…
Setup: We assume a Deep-RL example of a DQN with 4 discrete available actions, and we want to select one of those actions depending on our estimated state.
When you first start working with DQNs (Deep Q-Networks), you are introduced to some challenging approaches to applying deep neural networks to Reinforcement Learning problems.
The first thing you come to realize is that you have no dataset to start with! In fact, it is the agent itself that produces the dataset with each consecutive action.
An online model (our case) starts without a dataset. Our agent creates the dataset from its own experience, in an “as we go” approach.
An offline model, on the other hand, comes with an existing dataset to train on, and it is what we use in classic Deep Learning approaches to regression and classification. Classification in particular requires an offline model, because we need not only the data itself but also the label attached to each example!
That can be challenging, right? Consider that your dataset starts out as random, unlabeled data which you initially know nothing about.
The very first data you have to train on comes from a random equiprobable policy, which at best is irrelevant and at worst is totally misleading.
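To make this concrete, here is a minimal Python sketch of an agent producing its own dataset with a random equiprobable policy. The corridor environment, `step`, `collect_random_transitions`, and the reward scheme are all hypothetical, chosen only for illustration. Each recorded transition is (state, action, reward, next_state, done), and nothing in it says which action was the correct one.

```python
import random
from typing import List, Tuple

# (state, action, reward, next_state, done)
Transition = Tuple[int, int, float, int, bool]

N_STATES = 6   # hypothetical corridor: states 0..5, state 5 is terminal
N_ACTIONS = 4  # 0 = left, 1 = right, 2 and 3 = stay put (illustrative only)


def step(state: int, action: int) -> Tuple[int, float, bool]:
    """Hypothetical toy dynamics: reward 1.0 only when the terminal state is reached."""
    if action == 0:
        next_state = max(state - 1, 0)
    elif action == 1:
        next_state = min(state + 1, N_STATES - 1)
    else:
        next_state = state
    done = next_state == N_STATES - 1
    reward = 1.0 if done else 0.0
    return next_state, reward, done


def collect_random_transitions(n_steps: int, seed: int = 0) -> List[Transition]:
    """Act with an equiprobable random policy and record every transition.

    Nothing in the result says which action was the right one: there are no labels.
    """
    rng = random.Random(seed)
    data: List[Transition] = []
    state = 0
    for _ in range(n_steps):
        action = rng.randrange(N_ACTIONS)  # each action has probability 1/4
        next_state, reward, done = step(state, action)
        data.append((state, action, reward, next_state, done))
        state = 0 if done else next_state  # start a new episode after the terminal state
    return data


for transition in collect_random_transitions(5):
    print(transition)
```

Because actions are drawn uniformly, most early transitions just wander around with zero reward, which is why this first data tells the agent so little.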
So how can you improve your own training, over data that you feed yourself, without any kind of indication of good or bad labelling in the first place?
You definitely cannot approach this as a classification problem, since you have no labels telling you the quality of the actions taken in your states, right?
However, your neural network should decide on the best possible action to take given your current state. That is, we should come up with the action with the highest probability (a “state” goes “in”, and a chosen “action” comes “out” of our neural network), which turns this into a classification problem where we classify over the most likely action to take, correct?
Example: If we have an action set of, let’s say, 4 possible actions, and for each estimated state we want our neural network to dictate which action to select, this now looks like a classification problem: we are dealing with a probability distribution over the actions and need to pick the topmost one.
So it seems that we are a bit lost here. You must have already realized that certain measures need to be taken to apply deep learning to RL.
To complicate things even more, imagine that you have a neural network weight configuration θ, and from this very configuration you want to improve to some new configuration θ′. How can you update your existing weights using predictions made by those same weights, without any labels to guide you? This is not how classification works!
It is getting confusing, right?
Don’t worry, hair pulling is normal at this stage, but I hope we will eventually come to understand how to approach Q-Learning the “deep way”.
First, let us talk about Q-Learning itself, so we can understand how Deep Q-Learning works.
I assume you are already familiar with the term Q-Table. If you are not, please do visit my article about Basic Understanding in Reinforcement Learning.
In this section, I will briefly talk about SARSAMAX or Q-Learning. Then we will look at the deep learning approach.
Much of the content in Basic Understanding in Reinforcement Learning covers the logic behind Monte-Carlo Methods.
In a nutshell, in Monte-Carlo Methods we play out complete episodes. After each episode, we go back and note the total reward collected from each state onwards, given the action that was taken there.
That means we need an entire episode to find the expected reward for each state and action, since we must reach the terminal state before we can add up the total reward from that state, all the way to the terminal state, under our policy.
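The Monte-Carlo bookkeeping can be sketched in a few lines of Python. This sketch assumes a discount factor `gamma`, which the text above does not mention; with `gamma = 1.0` it reduces to the plain total reward. The function name is illustrative. Note that it can only run once the episode's full list of rewards is known.

```python
from typing import List


def monte_carlo_returns(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """Return G_t for every step t of one finished episode.

    G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... up to the terminal step.
    With gamma = 1.0 this is simply the total reward from step t onwards.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    # Walk backwards so each G_t can reuse G_{t+1}
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(monte_carlo_returns([1.0, 0.0, 2.0], gamma=0.5))  # → [1.5, 1.0, 2.0]
```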
Temporal Difference comes to the rescue: we do not need a complete episode to evaluate the action-value function for our states. Instead, we can use the temporal difference to update the action-value function at each consecutive action step we take during the episode.
We call this SARSA (State, Action, Reward, State(next), Action(next)).
It works by looking at both the current and the next action: we compare the current state/action value with the reward we just received plus the (discounted) value of the next state/action pair. If that combined value is higher, our current state/action value should increase a bit; otherwise, it should decrease a bit.
This way, we don’t need to wait for an episode to conclude in order to populate our q-table entry for our existing state/action pair.
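A minimal tabular sketch of that Sarsa(0) step in Python with NumPy. The learning rate `alpha` and discount `gamma` are standard TD parameters that the text above does not name explicitly, and the values used here are arbitrary.

```python
import numpy as np


def sarsa_update(q: np.ndarray, s: int, a: int, r: float, s_next: int, a_next: int,
                 alpha: float = 0.1, gamma: float = 0.99, done: bool = False) -> None:
    """One Sarsa(0) step on a q-table of shape (n_states, n_actions), updated in place."""
    # Target: reward now + discounted value of the action we actually take next
    target = r if done else r + gamma * q[s_next, a_next]
    # Move the current estimate a fraction alpha towards the target
    q[s, a] += alpha * (target - q[s, a])


q = np.zeros((2, 4))
q[1, 2] = 1.0  # pretend we already value (state 1, action 2)
sarsa_update(q, s=0, a=0, r=0.0, s_next=1, a_next=2, alpha=0.5, gamma=1.0)
print(q[0, 0])  # → 0.5
```

Because the target uses the next action that was actually chosen, Sarsa has to wait for that choice before it can update.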
There are 3 Sarsa types: Sarsa (or Sarsa(0)), SarsaMax (or Q-Learning), and Expected Sarsa.
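The three variants differ only in the target y that the current estimate Q(s, a) is moved towards. In standard notation (s and a the current state and action, r the reward, s′ the next state, π the policy, α a learning rate and γ a discount factor, none of which are named in the text above):

```latex
% Update shared by all three variants
Q(s,a) \leftarrow Q(s,a) + \alpha \big( y - Q(s,a) \big)

% Sarsa (Sarsa(0)): use the next action a' that is actually taken
y_{\text{Sarsa}} = r + \gamma \, Q(s', a')

% SarsaMax (Q-Learning): use the best-scored next action
y_{\text{SarsaMax}} = r + \gamma \, \max_{a''} Q(s', a'')

% Expected Sarsa: average over next actions under the policy \pi
y_{\text{Expected}} = r + \gamma \sum_{a''} \pi(a'' \mid s') \, Q(s', a'')
```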
This article is not about Sarsa, so we will look directly into SarsaMax (or Q-Learning).
SarsaMax is the Sarsa variant in which we don’t need the next action to be taken before updating our current q-table state/action pair. Instead, we take as the next action the best-scored (greedy) action in our q-table for the next state, so we compare our current state/action value against the next state’s best-scored action.
Now we can directly update the estimate for our current state/action pair by simply looking up the best next state/action score in the q-table.
Updating the Q-Table using SarsaMax
We already mentioned it earlier, but we will make it clear now…
Using SarsaMax (Q-Learning), we update our q-table’s existing state/action pair, based on the highest scored state/action pair of the next state.
If the reward plus that best next score is higher than our current estimate, we increase our state/action value; otherwise, we decrease it.
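Here is the same idea as a minimal tabular sketch in Python with NumPy, assuming a learning rate `alpha` and discount `gamma` (illustrative values, not from the text). Unlike Sarsa, no next action is passed in; the update only takes the maximum over the next state's row of the q-table.

```python
import numpy as np


def sarsamax_update(q: np.ndarray, s: int, a: int, r: float, s_next: int,
                    alpha: float = 0.1, gamma: float = 0.99, done: bool = False) -> None:
    """One SarsaMax (Q-Learning) step on a q-table of shape (n_states, n_actions), in place."""
    # Target: reward now + discounted best score available from the next state
    target = r if done else r + gamma * np.max(q[s_next])
    # Positive error -> increase q[s, a]; negative error -> decrease it
    q[s, a] += alpha * (target - q[s, a])


q = np.zeros((2, 4))
q[1] = [0.2, 0.9, -0.5, 0.1]  # scores of the four actions in state 1
sarsamax_update(q, s=0, a=3, r=0.0, s_next=1, alpha=0.5, gamma=1.0)
print(q[0, 3])  # → 0.45
```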
So our error is measured as a distance. From a deep learning perspective, our Error Function is the distance between our current state/action value and this target (the reward plus the next state’s best value), so we can express it with something like the MSE (Mean Squared Error).
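Written out, with θ the network weights, γ a discount factor (an assumption, not named in the text) and N the number of transitions in a batch, this error function is an ordinary regression loss. For terminal transitions the target is just the reward.

```latex
% TD target for transition i, built from the reward and the best-scored next action
y_i = r_i + \gamma \max_{a'} Q(s'_i, a'; \theta)

% Mean Squared Error over a batch of N transitions (s_i, a_i, r_i, s'_i)
L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \big( y_i - Q(s_i, a_i; \theta) \big)^2
```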
If you come from a Deep Neural Network background, you already know that this is a regression approach, as MSE is a regression-type error function!
Deep Q Networks
Now imagine that we want to transfer this logic to Deep Learning. In fact, what we just saw above is how we approach temporal difference with SarsaMax via a deep neural network!
I will not go into detail about Deep Reinforcement Learning here; if you follow my introductory article on RL, you can find more about it there.
However, during DQN training, we use a regression approach to minimise the error between the expected reward our neural network currently predicts for a state/action pair and the target built from the received reward plus the network’s own estimate for the next state.
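To see that this really is plain regression, here is a minimal NumPy sketch. It assumes a *linear* Q-network (one weight matrix `W` and bias `b`, one real-valued output per action) instead of a deep network, so the gradient can be written by hand; the names `q_values` and `dqn_regression_step` are hypothetical. The target is computed with the same weights and treated as a constant, so gradients flow only through the predicted value of the action actually taken.

```python
import numpy as np


def q_values(W: np.ndarray, b: np.ndarray, states: np.ndarray) -> np.ndarray:
    """Linear Q-network: (batch, state_dim) -> (batch, n_actions) estimated total rewards."""
    return states @ W + b


def dqn_regression_step(W, b, states, actions, rewards, next_states, dones,
                        lr=0.01, gamma=0.99):
    """One gradient-descent step on the MSE between Q(s, a) and the TD target."""
    batch = np.arange(len(actions))
    # Targets: reward + discounted best next score (only the reward on terminal steps).
    # They are plain numbers here; no gradient is taken through them.
    next_best = q_values(W, b, next_states).max(axis=1)
    targets = rewards + gamma * next_best * (1.0 - dones)
    # Predictions for the actions that were actually taken
    preds = q_values(W, b, states)[batch, actions]
    errors = preds - targets
    loss = float(np.mean(errors ** 2))
    # d(loss)/d(outputs): nonzero only in the column of the taken action
    grad_out = np.zeros((len(actions), W.shape[1]))
    grad_out[batch, actions] = 2.0 * errors / len(actions)
    W = W - lr * states.T @ grad_out
    b = b - lr * grad_out.sum(axis=0)
    return W, b, loss


rng = np.random.default_rng(0)
states = rng.normal(size=(32, 3))       # 32 random 3-dimensional states
actions = rng.integers(0, 4, size=32)   # 4 discrete actions
rewards = rng.normal(size=32)
next_states = rng.normal(size=(32, 3))
dones = np.zeros(32)
W, b = np.zeros((3, 4)), np.zeros(4)
W, b, loss = dqn_regression_step(W, b, states, actions, rewards, next_states, dones)
print(loss)
```

Every piece here is standard regression machinery: real-valued outputs, a numeric target, and gradient descent on a squared error. No output is ever compared against a class label.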
Classification or Regression?
So we clearly train using a regression model, yet our problem is in fact classification!
In the end, we need to come up with a prediction of which action to select.
State goes in, Action comes out!
Since we select the topmost action, it is the equivalent of the topmost label in a classification problem.
Our actual Q-Network architecture could be a regression neural network that has several outputs in its final layer.
The important note here is that each of these outputs (one for each possible action) measures the same thing: estimated total reward!
So step back and think about it…
They all measure reward scores! Since multiple outputs measure the same quantity, the output with the highest reward is the topmost candidate.
So even though we output separate real-valued estimates of expected total reward (as we do with typical regression models), each value is a different amount in the same currency (our total estimated reward), which lets us read the output as a classification over actions.
You can even pass these outputs through a softmax if you like: softmax preserves ordering, so it won’t change the fact that we are looking for the highest score, and the output with the highest value is still your topmost choice!
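A small NumPy sketch of the selection step, using made-up Q-values for the 4-action setup: taking the argmax of the outputs picks the action, and passing them through a softmax first picks the same one. The softmax values are no longer estimated rewards, so in this sketch softmax is applied only at selection time, not inside the regression loss.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


# Hypothetical Q-network outputs for one state: estimated total reward per action
q = np.array([1.2, -0.3, 2.5, 0.7])

greedy = int(np.argmax(q))                # highest estimated reward -> chosen "label"
via_softmax = int(np.argmax(softmax(q)))  # softmax is monotonic, so the winner is unchanged

print(greedy, via_softmax)  # → 2 2
```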
So this is a regression neural network that we treat as a classifier, providing labelled predictions where each label is an action to take.