Debuggable Deep Learning
Singh: My name is Avesh [Singh], this is Mantas [Matelis]. We’re both software engineers at Cardiogram, as Mike said and we were working on a deep neural network to predict cardiovascular diseases using wearable devices like the Apple Watch. Along the way, we’ve learned a lot about how to build and debug deep neural networks, and we wanted to share some of that knowledge with you today. Personally, I find it helpful to look through the slides during the talk, and in this slide actually we’ve put some code, so, if you’d like to follow along, Mantas and I both just tweeted out the slides, you can find our Twitter handles right here, just look on our feeds and you can find these slides in PDF format.
We’ve titled this talk Debuggable Deep Learning, but is that an oxymoron? Deep learning is often seen as a black box: you take a piece of input data, execute several thousand matrix multiplications, and out comes a prediction. This prediction isn’t easily explained, which poses an obvious problem for machine learning practitioners. How do you construct and debug a model? In this talk, we’re going to walk you through some techniques that we’ve used to demystify the behavior of a DNN. We don’t promise explainability, but the point we want to drive home is that constructing a DNN architecture is not alchemy, it’s engineering.
This presentation is split into two parts. First, we’re going to talk about coming up with an initial architecture, to do this, you must understand your problem and your data. This is also going to introduce you to Cardiogram’s data set and model, which is going to be essential to understanding the second part of the talk which is on debugging techniques. After that, we’re going to talk you through these debugging methods that we’ve used to identify and fix problems in our DNN. Let’s get started.
Overview of Cardiogram Data
Cardiogram is a mobile app for iOS and Android. Who here has an Apple Watch, or Garmin, or Wear OS device? Great, please download our app, we want more users, more labels. A lot of us track our heart rates when we’re running or biking, but within your heart rate data, you can also see your REM cycles, or how your sleep is disrupted by alcohol. You can quantify the anger you feel when you’re stuck in traffic on the 101, or your anxiety during a job interview. Your heart says a lot about you, and we’re using this data to detect signs of disease.
About 500,000 people use the Cardiogram app daily, we’ve surveyed many of these users in order to come up with a data set of diagnoses. The conditions we’re most interested in are these chronic cardiovascular diseases, diabetes, sleep apnea, and high blood pressure. These conditions form our labels, we’re trying to predict whether a user has diabetes, sleep apnea, or high blood pressure. For each of these conditions, as you can see here, the number of positive labels is in the tens of thousands, and the total labels are in the hundreds of thousands. Our data set is restricted to Apple Watch users, we get both heart rate and step count information from the watch. The step count is intended to provide some context, so, a high heart rate is more expected if the step count is also high, but as a side effect, it also provides a measure of how active the user is.
We get a user’s heart rate and step count at various time intervals. These time intervals are not always consistent, so, if a user takes her watch off to go to sleep, then there’ll be no heart rate readings for eight hours. To account for these gaps, we encode a delta time channel, which is DT here, which stores the amount of time since the previous reading. Let’s say at the first time step, the user has a heart rate of 76 beats per minute, and then five seconds later, that rises to 78, and then after that we get a step count reading, and so on. This is the input to our model, it’s a 2D array, shown here.
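As a rough illustration of the delta time channel described above, the sketch below turns irregular readings into rows of heart rate, step count, and delta time. The tuple layout, the function name `encode`, and the choice to fill an absent channel with 0.0 are assumptions made for this example; the talk does not say how missing channels are represented.

```python
# Hypothetical raw readings: (timestamp_seconds, kind, value).
readings = [
    (0, "heart_rate", 76.0),
    (5, "heart_rate", 78.0),
    (9, "steps", 12.0),
    (28809, "heart_rate", 61.0),  # first reading after eight hours off the wrist
]

def encode(readings):
    """Return rows of [heart_rate, step_count, delta_time].

    A channel with no value at a given reading is filled with 0.0 (an assumption).
    """
    rows = []
    previous_time = None
    for timestamp, kind, value in sorted(readings):
        # Seconds since the previous reading; 0.0 for the very first reading.
        delta = 0.0 if previous_time is None else float(timestamp - previous_time)
        heart_rate = value if kind == "heart_rate" else 0.0
        steps = value if kind == "steps" else 0.0
        rows.append([heart_rate, steps, delta])
        previous_time = timestamp
    return rows

encoded = encode(readings)
for row in encoded:
    print(row)
```

The eight-hour gap shows up as a single large delta time value rather than as thousands of empty rows.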
You might be thinking, this data comes with a lot of challenges. The three that we enumerate here are: one, as I mentioned, readings are taken at an irregular sampling frequency. Number two, the data is also very low dimensional, it includes just the three input channels: heart rate, step count, and delta time. Finally, the data streams are arbitrarily long, some users have been using their watch for years, while others only have data for the past day. Next, I’ll pass it off to Mantas [Matelis], to discuss some solutions to these problems, and to present some ideas around model architecture.
Building an Architecture
Matelis: You understand your problem, you know the characteristics of the data, it’s time to build an architecture. If you’re using an existing well-researched domain like image recognition or speech recognition, use existing architectures, maybe even existing weights, the less that you build yourself, the less you have to debug, and the easier life will be. If you’re in more of an unresearched space, like we are, you have a lot more work to do, so the broad advice that I can give is, start simple and look for incremental gains.
I’m going to go over one of the architectures that we ourselves have built and debugged. You don’t need to understand this entirely, but a general picture of what we’re working with will help with some of the later slides. Our input at the bottom is the sensor data that Avesh mentioned a few slides ago, and our output at the top is the risk scores for the four conditions that we try to predict: atrial fibrillation, sleep apnea, diabetes, and high blood pressure.
At a high level, the architecture consists of feeding the time series of wearable data into several temporal convolution layers, several LSTM layers, a final convolution layer, and then the outputs that correspond to the four risk scores. There are of course alternative data representations that lead to alternative architectures that we can try. Instead of feeding in the raw data inputs directly, we can group the data into hour-long intervals, and manually generate features for things that we deem might be important for the neural network to successfully discover the patterns in our data: things like total step count, heart rate percentiles over the hour, and indicator variables, such as, is the user sleeping at the time? Do we think the user was working out?
Then from this data representation, we can use a much simpler CNN architecture, something a little more off-the-shelf, because in this representation, the data is much higher dimensional, but also the sampling is fixed. You get one of these per hour, whereas, in the other case, you get a heart rate reading maybe every five seconds, maybe every five minutes, it’s a lot more complicated. This does of course reduce the granularity of the data from being an individual observation of a heart rate or a step count, but it helps to mitigate these issues, and it allows us to use a much simpler architecture.
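The hour-long grouping described above might look something like the following sketch. The feature names, the use of median and max in place of a fuller set of percentiles, and the workout threshold are all assumptions for illustration, not the features Cardiogram actually uses.

```python
from collections import defaultdict
from statistics import median

# Hypothetical readings: (timestamp_seconds, heart_rate_bpm_or_None, steps_or_None).
readings = [
    (10, 62.0, None),
    (400, 64.0, None),
    (900, None, 30),
    (3700, 110.0, None),
    (3705, 118.0, None),
    (3710, None, 220),
]

def hourly_features(readings, workout_steps_threshold=200):
    """Group readings into hour buckets and compute a few hand-built features."""
    buckets = defaultdict(lambda: {"hr": [], "steps": 0})
    for timestamp, heart_rate, steps in readings:
        bucket = buckets[timestamp // 3600]
        if heart_rate is not None:
            bucket["hr"].append(heart_rate)
        if steps is not None:
            bucket["steps"] += steps
    features = {}
    for hour, bucket in sorted(buckets.items()):
        features[hour] = {
            "total_steps": bucket["steps"],
            "median_hr": median(bucket["hr"]) if bucket["hr"] else None,
            "max_hr": max(bucket["hr"]) if bucket["hr"] else None,
            # Crude indicator variable; the threshold is an assumption.
            "maybe_workout": bucket["steps"] >= workout_steps_threshold,
        }
    return features

features = hourly_features(readings)
for hour, values in features.items():
    print(hour, values)
```

Each hour now yields one fixed-size feature vector, which is what makes the sampling regular.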
This concept is closely related to the debuggability of DNNs, and their interpretability. Interpretability is just: given an input, what made the output be the output? There are a lot of complicated papers about how to make really complicated deep neural networks interpretable, but I’m going to approach it from a simpler angle here. If you can make your model architecture simpler, do so. One of our cases is, we have some disparate data sources, so we have the sensor data that we’ve been talking about, but we also have things like age, sex, and BMI. If we’re able to build separate models out of both of these pieces of information, and then combine them later on, it becomes a lot easier to debug the model. You can tell which part of the model is failing, you can see exactly how each model contributes to the accuracy of the final model, and it’s really easy to run ablation studies: take out parts of the model and see, well, how much worse do we do?
Debug Your Model
We have an architecture, we understand our data, it’s time to debug the model. There are two classes of breakages in DNNs. One is the really common “My model isn’t training, or my model doesn’t work at all”: maybe your loss is NaN, or your AUC is 0.5, something along those lines. There’s also a more insidious, “My model doesn’t do as well as I think it could”, and so we’re going to talk about both of these. First, you’re often not certain whether a model is broken or whether you have a hard problem. What’s the prior, what do we think? What kind of results can we expect? It’s hard to say, so try to find a baseline with a simple and obvious model that you trust.
In our case, for example, we could take those hourly interval feature vectors that we talked about a few slides ago, and throw the means and standard deviations of these into logistic regression (LR), and that would be a really simple baseline that we could compare the results of a DNN against. If it turned out that the DNN did worse than this baseline, well, of course that’s a pretty clear indication that the DNN is fundamentally broken, or something along the way didn’t go as planned, whereas, if the model does better, great, you did a good job applying deep learning. If the model does about the same as logistic regression, or linear regression, that means that there’s no value above and beyond the very simple feature engineering that you did, and so you should take a step back and think differently about the problem.
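A minimal sketch of such a baseline, assuming NumPy is available: each user’s hourly feature vectors are summarized by their per-feature mean and standard deviation, and a logistic regression is fitted with plain gradient descent. The data here is synthetic and the hand-written optimizer stands in for whatever library the speakers used, which the talk does not name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data: 200 users x 24 hourly vectors x 3 features.
num_users, num_hours, num_features = 200, 24, 3
labels = rng.integers(0, 2, size=num_users)
hourly = rng.normal(size=(num_users, num_hours, num_features))
hourly[:, :, 0] += labels[:, None] * 1.0  # make feature 0 informative

# Baseline features: per-user mean and standard deviation of each hourly feature.
summary = np.concatenate([hourly.mean(axis=1), hourly.std(axis=1)], axis=1)

def fit_logistic_regression(x, y, learning_rate=0.1, steps=500):
    """Plain batch gradient descent on the logistic loss."""
    weights = np.zeros(x.shape[1])
    bias = 0.0
    for _ in range(steps):
        probabilities = 1.0 / (1.0 + np.exp(-(x @ weights + bias)))
        error = probabilities - y
        weights -= learning_rate * (x.T @ error) / len(y)
        bias -= learning_rate * error.mean()
    return weights, bias

weights, bias = fit_logistic_regression(summary, labels)
predictions = 1.0 / (1.0 + np.exp(-(summary @ weights + bias)))
baseline_accuracy = ((predictions > 0.5) == labels).mean()
print(f"baseline accuracy: {baseline_accuracy:.2f}")
```

Whatever metric the DNN is judged by should be computed for this baseline on the same held-out split; the training-set accuracy printed here is only to show the pipeline end to end.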
Another really important way to make sure that your model isn’t entirely broken is to verify your input generation data pipeline. A lot of the time you have this complex series of functions that takes a series of CSVs and turns it into arrays that go into a .predict, or a .fit, or a .train, and it’s really easy to make a mistake along the way. By making sure that the input to the model is what you expect, you can eliminate a whole class of bugs.
Example errors that we’ve made here include training on the wrong label, filtering many more user weeks than we expected, leading to a lot less training data, leading to worse model performance, and then insidious things, like at some point we had different pieces of code that had different understandings of what the ordering of conditions were, so some piece of code thought it was atrial fibrillation, diabetes, sleep apnea, another thought it was diabetes, sleep apnea, atrial fibrillation, and that made the model look a lot worse than it actually was. It’s hard to catch these sorts of things because these conditions are correlated with each other. The easiest way is just to write unit tests and make sure that what you put in is what you’re expecting.
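One way to guard against the condition-ordering bug described above is to keep a single ordered list that every piece of code imports, and to unit test the round trip. The names `CONDITIONS`, `encode_labels`, and `decode_predictions` are hypothetical, chosen for this sketch.

```python
# Single source of truth for the label order (condition names taken from the talk).
CONDITIONS = ("atrial_fibrillation", "diabetes", "sleep_apnea", "high_blood_pressure")

def encode_labels(diagnoses):
    """Turn a set of diagnosed conditions into a 0/1 vector ordered by CONDITIONS."""
    return [1 if condition in diagnoses else 0 for condition in CONDITIONS]

def decode_predictions(scores):
    """Attach condition names to model outputs using the same CONDITIONS order."""
    if len(scores) != len(CONDITIONS):
        raise ValueError("expected one score per condition")
    return dict(zip(CONDITIONS, scores))

def test_label_order_round_trip():
    labels = encode_labels({"sleep_apnea"})
    assert labels == [0, 0, 1, 0]
    decoded = decode_predictions(labels)
    assert decoded["sleep_apnea"] == 1
    assert sum(decoded.values()) == 1

test_label_order_round_trip()
```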
Next thing to try to make sure your model isn’t entirely broken, is to make sure that you’re able to overfit on a small data set. Of course, overfit is normally a bad word, we don’t want to overfit, but, if you turn off regularization, dropout, batch norm, all that, you should be able to overfit on a small portion of your data set maybe a few percent, maybe a bit more. If you’re able to do this, you eliminate another class of errors, so there’s no normal loss curve for this, but in classification, you should expect to be able to get to around 0.99 AUC, or better.
If you can’t, there are a number of things that could be going wrong. Maybe your model architecture isn’t what you expect, maybe you’re doing some funny slicing in Keras, and you accidentally dropped most of your input and you’re training on very low dimensional data. Maybe your input pipeline is broken, and your unit tests didn’t catch it, or alternatively, your learning rate is just really far from what it should be. It could also be that your model just doesn’t have enough capacity, so make it wider, maybe make it deeper, that’s also an important sign.
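The overfitting check can be sketched with a tiny NumPy network and no regularization. Everything here (the random data, the layer sizes, the learning rate) is an assumption for illustration; the labels are random on purpose, so the only way for the loss to drop is memorization.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny training subset with random labels: only memorization can fit it.
inputs = rng.normal(size=(16, 8))
targets = rng.integers(0, 2, size=(16, 1)).astype(float)

hidden_units = 64
w1 = rng.normal(scale=0.3, size=(8, hidden_units))
b1 = np.zeros(hidden_units)
w2 = rng.normal(scale=0.3, size=(hidden_units, 1))
b2 = np.zeros(1)

def forward(x):
    hidden = np.tanh(x @ w1 + b1)
    probabilities = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))
    return hidden, probabilities

def log_loss(probabilities, y):
    clipped = np.clip(probabilities, 1e-9, 1 - 1e-9)
    return float(-(y * np.log(clipped) + (1 - y) * np.log(1 - clipped)).mean())

initial_loss = log_loss(forward(inputs)[1], targets)
learning_rate = 0.05
for _ in range(3000):
    hidden, probabilities = forward(inputs)
    # Gradient of the mean log loss with respect to the output logits.
    grad_logits = (probabilities - targets) / len(inputs)
    # Backpropagate through the output layer and the tanh non-linearity.
    grad_hidden = (grad_logits @ w2.T) * (1.0 - hidden ** 2)
    w2 -= learning_rate * hidden.T @ grad_logits
    b2 -= learning_rate * grad_logits.sum(axis=0)
    w1 -= learning_rate * inputs.T @ grad_hidden
    b1 -= learning_rate * grad_hidden.sum(axis=0)

final_loss = log_loss(forward(inputs)[1], targets)
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

If the training loss on a subset like this refuses to fall, the causes listed above (architecture, input pipeline, learning rate, capacity) are the first places to look.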
We’ve gone through some of the debugging techniques that help with a model that’s doing really poorly, but there are a lot more techniques to help figure out why a model isn’t doing as well as it could be doing, one of these is examining outputs. Telling a machine learning practitioner to examine their outputs is like telling someone to eat their vegetables. You do have to do it, it’s unfortunate, but you do. Here’s an example of an aggregate analysis that we ran on a DNN architecture, here, with this DNN, we initialized the LSTM state with some user metadata, like their age, sex, and BMI. And we wanted to understand the extent to which the model is just regressing over this metadata, and just using the metadata to compute its predictions, ignoring the sensor data that we’re providing it as well.
We graphed the DNN predictions alongside the logistic regression predictions. The DNN here takes in one week of user data at a time, in this graph, each dot is one week of user data. The answer to our initial question is no, the graph is not particularly linear. Clearly, the DNN is using extra signal above and beyond just the age, sex, and BMI, but there’s something else striking about this graph. It’s made up of vertical lines, and each line is actually formed by a single user with multiple weeks’ worth of data, so their LR prediction is unchanging because their age, sex, and BMI don’t change, but the DNN prediction varies over the weeks. Sometimes this ranges from 0.1 to 0.7, so, this is actually a really useful piece of information, it tells us that there’s some sort of improvement that can be made over and above just averaging our DNN predictions, which is what we were doing in the past. Perhaps our filtering isn’t strong enough and we were including weeks of user data where the user had worn their watch for a few hours, and the DNN wasn’t able to pick up enough of the signal. Alternatively, the DNN predictions were actually accurate, because these are sleep apnea predictions and sleep apnea can be sporadic. After a night of drinking, you probably have a lot more apnea events than otherwise.
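The vertical lines in that scatter plot correspond to the spread of one user’s weekly predictions. A small sketch of how that spread could be tabulated is below; the user IDs and values are made up.

```python
from collections import defaultdict

# Hypothetical (user_id, weekly_dnn_prediction) pairs.
weekly_predictions = [
    ("user_a", 0.10), ("user_a", 0.70), ("user_a", 0.35),
    ("user_b", 0.42), ("user_b", 0.45),
]

def prediction_spread(weekly_predictions):
    """Return {user_id: (min, max, mean)} over that user's weekly predictions."""
    by_user = defaultdict(list)
    for user_id, prediction in weekly_predictions:
        by_user[user_id].append(prediction)
    return {
        user_id: (min(values), max(values), sum(values) / len(values))
        for user_id, values in by_user.items()
    }

spread = prediction_spread(weekly_predictions)
for user_id, (low, high, mean) in spread.items():
    print(f"{user_id}: range {low:.2f} to {high:.2f}, mean {mean:.2f}")
```

Users with a wide range are the ones for whom a plain average of weekly predictions hides the most information.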
In addition to an aggregate analysis like that, it’s really useful to take a look at examples of wins and losses. Sort your tuning set by absolute error, and take a look at a handful of examples where the prediction is really far from the label, and look in both directions: look for false positives and look for false negatives. From these examples, come up with a hypothesis or a pattern of what’s causing these errors, and then take a look at the wins of the model and make sure this pattern doesn’t apply.
You want to find a pattern that explains why your DNN isn’t working in certain cases; next, you can stratify your input set into cases that exhibit the pattern and those that don’t, and take a look at the accuracy metrics here, you should find that the pattern does explain some sort of breakage in the model. This technique isn’t specific to deep learning per se, recently, we were debugging a logistic regression model on hand engineered features, and we discovered a few loss patterns here. First, users who work out a lot during the week, throw off our estimations of daytime heart rate standard deviation. Second, users who travel a lot don’t fit our assumptions about sleep time and wake time, and time zones.
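The two steps above, sorting by absolute error and then stratifying by a suspected pattern, can be sketched as follows. The rows and the `travels_often` flag are invented for illustration.

```python
# Hypothetical tuning-set rows: (example_id, label, prediction, travels_often).
tuning_set = [
    ("w1", 1, 0.15, True),
    ("w2", 0, 0.80, True),
    ("w3", 1, 0.90, False),
    ("w4", 0, 0.10, False),
    ("w5", 0, 0.55, True),
    ("w6", 1, 0.70, False),
]

def worst_examples(rows, count):
    """Rows with the largest absolute error first."""
    return sorted(rows, key=lambda row: abs(row[1] - row[2]), reverse=True)[:count]

def accuracy(rows, threshold=0.5):
    """Fraction of rows whose thresholded prediction matches the label."""
    correct = sum((prediction >= threshold) == bool(label) for _, label, prediction, _ in rows)
    return correct / len(rows)

worst = worst_examples(tuning_set, 2)
with_pattern = [row for row in tuning_set if row[3]]
without_pattern = [row for row in tuning_set if not row[3]]
print([row[0] for row in worst])
print(accuracy(with_pattern), accuracy(without_pattern))
```

A large accuracy gap between the two strata supports the hypothesis; a small gap means the pattern does not explain the losses.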
Predicting Synthetic Outputs
Next, I’m going to talk about one more technique that we call predicting synthetic outputs. As a first step in evaluating a model architecture to see if it’s suitable, we trained the DNN to predict a synthetic task using the heart rate and step count data, the task is just a deterministic function of the data. I’ll give you an example, we applied this when coming up with an architecture to predict sleep apnea. From existing literature, we knew that a feature standard deviation of daytime heart rate minus standard deviation of nighttime heart rate was particularly predictive of sleep apnea. We trained the DNN to predict this, but in order for a DNN to be able to predict this with low mean absolute error, it has to have at least a few properties.
First, it has to be able to distinguish day from night, this is not particularly obvious, in the past, we had the slide that showed this delta time channel that advanced in seconds. It’s possible that the DNN won’t know when daytime is, when nighttime is. The DNN also has to be able to remember data from several days in the past, this is a common problem with LSTMs, sometimes they can’t. This is a good sanity check to make sure that your architecture at least has the capability of learning what it has to learn.
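The synthetic target itself is cheap to compute directly from the data. A minimal sketch, assuming each reading carries an hour of day and that day runs from 08:00 to 22:00 (both assumptions; the talk does not define the boundary):

```python
from statistics import pstdev

# Hypothetical readings: (hour_of_day, heart_rate_bpm).
readings = [
    (2, 52.0), (3, 54.0), (4, 53.0),       # night
    (10, 70.0), (13, 95.0), (18, 80.0),    # day
]

def synthetic_label(readings, day_start=8, day_end=22):
    """std(daytime heart rate) minus std(nighttime heart rate).

    The day/night boundary hours are assumptions made for this sketch.
    """
    day = [hr for hour, hr in readings if day_start <= hour < day_end]
    night = [hr for hour, hr in readings if not day_start <= hour < day_end]
    return pstdev(day) - pstdev(night)

label = synthetic_label(readings)
print(round(label, 3))
```

The network never sees the hour of day directly, only delta times, which is exactly why this target tests whether it can tell day from night.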
In the past, we’ve also used this kind of synthetic task as a form of semi supervised training. As Avesh mentioned, we have over 500,000 daily users, so we have a lot more unlabeled heart rate data than we have labeled heart rate data. We can construct synthetic labels from the unlabeled data, train the network on these, and then use the learned weights, apart from the last layer, as the initial values of supervised training over the labels that we do have. Next, I’ll pass it off to Avesh [Singh], to talk about some more debugging techniques.
Singh: Let’s talk about a pretty simple idea, which is, visualizing your model’s activations. This is the architecture slide that Mantas presented a few minutes ago, and we’re going to be examining the outputs of the convolutional layer. Actually, there are three convolutional layers, so we’re going to be looking at the output of the last one. Oftentimes you’ll see CNNs convolving over images, in our architecture, the input is not an image, it’s a time series data. Our temporal convolutional layers are learning functions to apply to pieces of time series data.
Let’s start by understanding a single neuron here. Each individual neuron in this layer takes as input a time series of data, and it applies some function, which returns another time series. In this diagram is a convolutional neuron with width four. It applies the transformation shown here to its inputs: it multiplies them by a vector of learned weights W, adds a bias term B, and passes the result through a non-linearity F, then out comes H, the hidden output of this neuron. The activation function we use here is a rectified linear unit, or ReLU. I was talking about that earlier with Mike, and he joked that ReLU is basically a pretentious name for max, and that’s what it is. We’ve graphed the ReLU, aka the max function, here, so we’re going to be visualizing the output H.
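Written out in plain Python, a single width-four neuron of this kind computes H = F(W · X + B) over each window of the series, with F being the ReLU. The weights and bias below are made-up values, not learned ones, and no padding is applied, so the output is three steps shorter than the input.

```python
def relu(value):
    return max(0.0, value)

def conv_neuron(series, weights, bias):
    """Slide a window of len(weights) over the series with stride one.

    Each output is relu(w . x_window + b), so the result is another time series.
    """
    width = len(weights)
    outputs = []
    for start in range(len(series) - width + 1):
        window = series[start:start + width]
        pre_activation = sum(w * x for w, x in zip(weights, window)) + bias
        outputs.append(relu(pre_activation))
    return outputs

heart_rate = [76.0, 78.0, 80.0, 79.0, 90.0, 110.0]
# Hypothetical weights: this neuron responds to a rise across its window.
hidden = conv_neuron(heart_rate, weights=[-1.0, 0.0, 0.0, 1.0], bias=-5.0)
print(hidden)  # [0.0, 7.0, 25.0]
```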
We obtain the output from this neuron for every time step, since it applies a convolution with stride one. We have 128 such neurons in this layer, so, ultimately, we’re going to end up with a matrix that looks like this. The rows here are neurons, and each row shows the activation of a single neuron on each time step of data. What we’re hoping we’ll notice is that some cells capture semantic properties: cells that light up when a user is sleeping, or working out, or anxious. Or maybe we won’t, because, after all, the neurons form a distributed representation, and graphing each neuron’s output individually may be meaningless. I want to make sure this presentation is useful to your work, so I’m going to actually show some code here. As a warning, there is some code ahead; this code uses Keras and TensorFlow.
The code should be easy to follow along, even if you’re used to PyTorch; like everything Python, it’s very readable. It’s actually very simple: layer output function here is a Keras function, it takes the input of layer zero and produces the output of the selected layer. We run this function on the actual input data, and we get back layer output, which is the time series output of each neuron, and that’s it. We got this idea from a Google Brain paper that’s published at the Distill link below. If you’re interested in using this technique, I’d recommend that you take a look at that paper, it’s really cool and it’s very interactive, like all Distill papers.
We ran this code on the third convolutional layer of our model for one week of user data, and we visualized the results in this graph. The shades of blue here show the value of the activations, so smaller values are light blue, larger values are dark blue. Question for you guys, do you notice anything strange about these activations?
Participant 1: No change with time.
Singh: Exactly, they don’t change with time. We would call these dead neurons, they output the same value regardless of their input. Why is this happening? We thought that this might have something to do with our activation function. Remember, we’re using a ReLU activation shown on the left, if we take the derivative of the ReLU, we get a piecewise function shown on the right. Notice that when the input is less than zero, the derivative is zero, the values will not be updated in gradient descent. Perhaps B is very negative in this function, causing the input to F to always be less than zero. That doesn’t really sound right, because if that were the case, then every neuron in this layer would output zero.
What’s more likely happening is that one of the earlier convolutional layers is always outputting zero, so each neuron in this layer just takes on the value of its bias term, because X is just zero, and we can use TensorBoard to verify this is what’s happening. If we were to plot the values of the first convolutional layer prior to the activation, we’d see a histogram like this. Notice that the pre-activation outputs here are all very negative. After passing through the ReLU, they’ll be set to zero, and then after that W times X will be zero, and W times X plus B will just take on the value of the bias term.
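The mechanism described above can be reproduced in a few lines, assuming NumPy is available. This is a toy with made-up sizes and a width-one layer for brevity, not the speakers’ model: once the earlier layer’s pre-activations are all negative, the next layer’s output is the same at every time step, which a range-over-time check flags.

```python
import numpy as np

rng = np.random.default_rng(0)
num_steps, num_inputs, num_neurons = 50, 8, 4

def relu(values):
    return np.maximum(values, 0.0)

# Earlier layer whose pre-activations are all very negative, so its output is all zeros.
earlier_pre_activation = -10.0 - rng.random((num_steps, num_inputs))
earlier_output = relu(earlier_pre_activation)

# Next layer (width-one convolution for brevity): h = relu(x @ W + b).
weights = rng.normal(size=(num_inputs, num_neurons))
bias = np.array([0.3, -0.2, 1.1, 0.0])
activations = relu(earlier_output @ weights + bias)

# A neuron is "dead" here if its activation never changes across time steps.
dead = np.ptp(activations, axis=0) == 0.0
print(activations[0], dead)
```

Every row of `activations` equals relu(bias), so all four neurons are flagged, matching the flat stripes in the activation plot.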
Full disclosure, this histogram isn’t actually from our model. We didn’t need to use TensorBoard to debug this problem, because it turns out this is a very common problem that many of you have probably heard of, and there’s also a common solution to it. We can make sure that the gradient always has a non-zero value by using a leaky ReLU. This function has a value of X when X is greater than zero, like the ReLU, but has a small fraction of X when X is less than or equal to zero, so even when the input is very negative, the gradient is still propagated and the weights will still update.
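For comparison, here are the two derivatives side by side. The slope of 0.01 for negative inputs is an assumed value for this sketch; the talk does not say which fraction was used.

```python
def relu_grad(x):
    """Derivative of max(0, x): zero for every non-positive input."""
    return 1.0 if x > 0 else 0.0

def leaky_relu(x, slope=0.01):
    """x for positive inputs, a small fraction of x otherwise (slope is assumed)."""
    return x if x > 0 else slope * x

def leaky_relu_grad(x, slope=0.01):
    """Derivative of leaky_relu: never exactly zero, so updates keep flowing."""
    return 1.0 if x > 0 else slope

for x in (-5.0, 2.0):
    print(x, relu_grad(x), leaky_relu(x), leaky_relu_grad(x))
```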
We tried using the leaky ReLU for the convolutional layers, and, as you can see here, the activations for each cell now vary throughout time, but you’ll notice that there are chunks of time here when most cells output zero values. These actually correspond with the times when the user turned their watch into workout mode, which means that the Apple Watch is going to take a reading every five seconds, rather than every five minutes. This suggests our convolutional layers can’t really handle these variable time scales. One potential solution that we’ve thought of is to take advantage of the fact that the Apple Watch operates in two timescales, either every five minutes or every five seconds: we could process inputs of different time scales separately, and then merge the results prior to the final layer. We actually haven’t tried this yet, so I can’t tell you if this worked.
Instead, let’s talk about a different problem, this is a problem that our DNN suffered from. We call it amnesia, and I’m going to tell you how we created a metric to quantify the issue. Recall that our input is one week of user data. It consists of 4,096 heart rate or step count readings. It’s important that our DNN be able to track long term dependencies, for example, when we’re predicting diabetes, we care a lot about heart rate recovery time. This is the amount of time it takes to get back to your resting heart rate after a workout. In order to compute this, the DNN must be able to store and retrieve the time when you ended your workout. We want to answer this question, is the DNN able to learn long term dependencies, or does it have amnesia?
We can find an answer to this question using gradient analysis. Let’s examine the gradient of the output with respect to the input of each time step. In our architecture, a prediction is output at every time step and they’re later aggregated into a single score, so let’s look at the very last output. If the DNN has amnesia, we expect that the first time step has only a minuscule contribution to the final output, whereas the last time step should have a huge contribution. In other words, the gradient of the output with respect to time step 4,095 will be much greater than the gradient of the output with respect to the first time step. How much greater is much greater? This is really context specific. In our case, it would be fine but not ideal if the last time step is 10 times as important as the first, but if the last time step is a billion times more important, then we have a clear problem in our architecture.
This is the idea, the question is, how do we compute these gradients? Once again, warning, there’s some code ahead. This function is a bit more involved than the last function, but don’t be intimidated, I’ll walk through it line by line. We’re going to be writing a function, gradient output with respect to input, and it’s going to compute the gradient of the last time step with respect to each of the input time steps and each of the input channels, those are heart rate and step count.
Computing the gradient sounds complicated, but we can use TensorFlow to do the heavy lifting. TensorFlow has a built-in gradient function that’s of course used in weight updates. In backprop, we update the trainable parameters by using the gradient of the loss with respect to the learned parameters, and here we use the same function to take the gradient of the output with respect to the input.
This function takes a model as input, and we also need to provide some data, as the gradient is only a function until we actually run it on some data. The first thing we do is take the output at the last time step, which is 4,095. Also, let’s only look at the first output task, which is diabetes, just for simplicity. That means that output tensor here is going to be a vector with one entry per user.
We don’t actually care about the gradients per user, we care about the average gradient across users, so we just sum the value across all users to get output tensor sum. This is going to be a scalar, it’s the sum of the last output value for each user. Now let’s figure out how the input affects this value: let’s take the gradient of output tensor sum with respect to the inputs. Inputs here is a 3D tensor of shape num users by num time steps by num input channels. We’re differentiating a scalar with respect to a 3D tensor, so our result will be a 3D tensor.
Here we want to average over all users, so we take the mean across axis zero, resulting in a tensor of shape num time steps by num input channels. For example, the gradient tensor at index [10, 0] is going to be the derivative of the last output with respect to the 10th input heart rate, so we’re almost there. We just convert this to a Keras function and execute the function on the provided data, returning the resulting gradient. I hope that makes sense. If you’d like to take a closer look offline, we did tweet out the slides, and this is also a slightly abbreviated version of the code. You can find the full code at this tiny URL, which leads to a GitHub gist, take a look and feel free to steal it. We use this code to compute the gradients of our output with respect to each time step of our input, and we’ve graphed that here.
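The speakers’ TensorFlow code is not reproduced in this transcript. As a stand-in, the sketch below applies the same idea, the gradient of the last output with respect to every input time step, to a toy scalar tanh RNN using only the standard library. The model, its weights `w` and `u`, and the input series are all assumptions; the point is only to show what the amnesia ratio measures.

```python
import math

def forward(inputs, w, u):
    """Scalar tanh RNN: h_t = tanh(w * h_{t-1} + u * x_t), starting from h = 0."""
    hidden, states = 0.0, []
    for x in inputs:
        hidden = math.tanh(w * hidden + u * x)
        states.append(hidden)
    return states

def grad_last_output_wrt_inputs(inputs, w, u):
    """d h_T / d x_t for every t, by backpropagating through time."""
    states = forward(inputs, w, u)
    grads = [0.0] * len(inputs)
    upstream = 1.0  # d h_T / d h_t, starting at t = T
    for t in reversed(range(len(inputs))):
        local = 1.0 - states[t] ** 2      # derivative of tanh at step t
        grads[t] = upstream * local * u
        upstream = upstream * local * w   # pass the gradient back to h_{t-1}
    return grads

inputs = [math.sin(0.3 * t) for t in range(40)]
grads = grad_last_output_wrt_inputs(inputs, w=0.5, u=1.0)
amnesia_ratio = abs(grads[-1]) / abs(grads[0])
print(f"last/first gradient ratio: {amnesia_ratio:.3g}")
```

With a recurrent weight of 0.5, each step back in time shrinks the gradient by at least half, so the ratio grows exponentially with sequence length. That is the signature of amnesia described in the talk.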
One important note before we dive into this, is that our LSTM layers are bi-directional, meaning they receive their input in order and in reverse order. For this reason, we plot the gradient of the output at time step 2,048, the midpoint with respect to each time step. On the X axis here is the time step of the input we’re taking the gradient with respect to, and on the Y axis is the value of the gradient. For example, the gradient of the output at time step 2,048 with respect to time step about 2,048, is about 0.001. You’ll notice that the Y axis here is a log scale, so, if you were to compare the time step 2,048 with time step 2,500, you’ll see that the gradient has dropped by a factor of a million. The input at time steps far from 2,048 has pretty much zero effect on the output. We would say this architecture definitely suffers from amnesia so we’ve answered that question.
Actually, a few months after we produced this graph, we reran this analysis on a newer architecture, and we found that it no longer has amnesia. The inputs at each time step now have roughly the same impact on the output. How did we fix this? Well, we actually made a number of changes to the architecture during this time. The most likely fix is that we’re no longer running average pooling over the time series output of the LSTM layer, and instead we use the LSTM’s output at the last time step that has input data. This problem may actually have been fixed by some other change, so let this be a lesson to make incremental changes to your model and to measure their effects following the scientific method. This is especially important if you plan on giving a talk on debuggable deep learning.
To summarize, we walked you through the first steps in creating a DNN architecture: understanding your problem and your data. We talked about model debugging, examining your outputs, predicting synthetic outputs, amnesia, and visualizing activations. Before we take questions, I can’t help but include a plug, which you’ve heard in every talk I’m sure, which is that we are hiring. The ML team at Cardiogram only has three people right now, and we’re looking for ML engineers with some prior research experience, so, if you’re interested in using our data to build models to predict cardiovascular disease, please shoot us an email.
Questions and Answers
Participant 2: Regarding getting rid of the issue that your first input doesn’t matter for the last output, did you solve it with an hierarchical architecture, or with just going convolutional with all the temporal layers?
Singh: The architecture we ended up using for that had convolutional inputs, and then we had recurrent layers on top. What we think solved this is that the recurrent layers, instead of outputting at every time step, only output at the last time step that has input data. Remember, our inputs are 4,096 long, but we may not actually have 4,096 readings from the user, so if we have like 3,000 readings, then we just take the LSTM output at time step 3,000 as the model prediction. That’s a good question. It makes sense that when we’re using average pooling, the impact of the output at a particular time step will be very local to that time step, so this solution makes sense, but, like I said, we haven’t scientifically proven that that was the reason.
Participant 3: The [inaudible 00:29:10] plot. Can you explain that? The one you prepared models logistic against that.
Matelis: This is a scatter plot of logistic regression predictions and DNN predictions for each user week worth of data. Our hypothesis was that, it’s possible that the DNN isn’t actually using any of the heart rates and step counts we’ve given it. If that were the case, this would be a straight line, but it’s not a straight line. That means that the DNN is actually using the heart rates and step counts to make its prediction, but we also find that it’s made up of these vertical lines, and so, that means that each vertical line is one user’s worth of data, and that means that, for a single user that has the same logistic regression prediction, the DNN prediction per week can vary quite a bit. That means that there’s more investigation into how to combine these multiple weeks that may be anywhere from 0.1 to 0.7 into one global risk score that answers the question, how likely is it that we think that this person has sleep apnea?
Participant 4: I’m curious how you guys got the labels for the data.
Singh: I think we’re in a unique opportunity or unique situation, because we create the app as well. We have these 500,000 users who use Cardiogram for various reasons, like tracking their health, or enrolling in habits, and we sporadically ask them questions like, “Do you have diabetes? Do any of your family members have diabetes?” things like that, and we use those answers as the labels. We can actually verify these with a different data set that we’ve built partnering with UC San Francisco, where they have a more formal study where they will send out pages and pages of surveys to a bunch of patients, and we can use those labels as well.
Participant 6: I really like the point about overfitting on a small data set. In my experience that finds most of the problems in the network architecture and your backprop layers. Have there been instances where that strategy didn’t work?
Singh: I actually missed the first part of the question. Could you repeat the beginning?
Participant 6: Overfitting on a small data set to figure out the problems with the network is a very, I think, useful trick. I’m curious if there are any instances where that didn’t find some bugs, which you mentioned where you had to use another technique?
Singh: Yes. I think what we have found is that oftentimes overfitting is too easy a task, so a single layer LSTM can overfit on a hand-engineered feature. It’s a very basic unit test.
Participant 6: Overfitting on the entire network, but small dataset, like the network that you’re going to deploy, but on a very small dataset?
Singh: There are two aspects here, overfitting on a hand-engineered feature and overfitting the actual labels. I’m not sure there are any major detriments we’ve found to requiring that our models be able to overfit, as long as you apply regularization afterwards. Our process would be, you can have a large model without much regularization overfit on a smaller data set, and then apply some regularization, like decrease the width of the model or add L2, and then you can trade off the training and tuning accuracy on the full data set.