Skip to content

Commit 84be9c7

Browse files
committed
add final eval
1 parent bcc45a5 commit 84be9c7

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

homework04/part1_salary_prediction.ipynb

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
"source": [
77
"# Natural Language Processing with Deep Learning\n",
88
"\n",
9-
"__This is exactly the same notebook as in ../week10_textconv/. Feel free submit the seminar notebook, just make sure you read the assignments at the end.\n",
9+
"__This is exactly the same notebook as in `../week10_textconv/.`__\n",
10+
"\n",
11+
"__Feel free submit the seminar notebook, just make sure you read the assignments at the end.__\n",
1012
"\n",
1113
"Today we're gonna apply the newly learned DL tools for sequence processing to the task of predicting job salary.\n",
1214
"\n",
@@ -769,6 +771,33 @@
769771
" print('\\n\\n')"
770772
]
771773
},
774+
{
775+
"cell_type": "code",
776+
"execution_count": null,
777+
"metadata": {
778+
"collapsed": true
779+
},
780+
"outputs": [],
781+
"source": [
782+
"print(\"Final eval:\")\n",
783+
"for batch in iterate_minibatches(data_val, shuffle=False):\n",
784+
" title_ix = Variable(torch.LongTensor(batch[\"Title\"]), volatile=True)\n",
785+
" desc_ix = Variable(torch.LongTensor(batch[\"FullDescription\"]), volatile=True)\n",
786+
" cat_features = Variable(torch.FloatTensor(batch[\"Categorical\"]), volatile=True)\n",
787+
" reference = Variable(torch.FloatTensor(batch[target_column]), volatile=True)\n",
788+
"\n",
789+
" prediction = model(title_ix, desc_ix, cat_features)\n",
790+
" loss = compute_loss(reference, prediction)\n",
791+
"\n",
792+
" val_loss += loss.data.numpy()[0]\n",
793+
" val_mae += compute_mae(reference, prediction).data.numpy()[0]\n",
794+
" val_batches += 1\n",
795+
"\n",
796+
"print(\"\\tLoss:\\t%.5f\" % (val_loss / val_batches))\n",
797+
"print(\"\\tMAE:\\t%.5f\" % (val_mae / val_batches))\n",
798+
"print('\\n\\n')"
799+
]
800+
},
772801
{
773802
"cell_type": "markdown",
774803
"metadata": {},

0 commit comments

Comments
 (0)