Skip to content

Commit

Permalink
Draft notebook on text classification with an LSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
yvespeirsman committed Apr 9, 2019
1 parent 0ca491a commit c59ed2c
Showing 1 changed file with 241 additions and 0 deletions.
241 changes: 241 additions & 0 deletions drafts/Text classification with an LSTM in PyTorch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"TRAIN_PATH = \"data/text_classification/20newsgroups_train.tsv\"\n",
"TEST_PATH = \"data/text_classification/20newsgroups_test.tsv\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_20newsgroups\n",
"\n",
"train = fetch_20newsgroups(subset=\"train\")\n",
"label2idx = {label: idx for idx, label in enumerate(train.target_names)}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import sys\n",
"from torchtext.data import TabularDataset, Field, BucketIterator\n",
"\n",
"csv.field_size_limit(sys.maxsize)\n",
"\n",
"text = Field(sequential=True, tokenize=\"spacy\")\n",
"label = Field(sequential=False, use_vocab=False, preprocessing=lambda x: label2idx[x])\n",
"\n",
"train_data = TabularDataset(path=TRAIN_PATH, format='tsv', fields=[('label', label), ('text', text)])\n",
"test_data = TabularDataset(path=TEST_PATH, format='tsv', fields=[('label', label), ('text', text)])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': <torchtext.data.field.Field at 0x1a2265cc50>,\n",
" 'text': <torchtext.data.field.Field at 0x1a2265cbe0>}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.fields"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"VOCAB_SIZE = 30000\n",
"\n",
"text.build_vocab(train_data, max_size=VOCAB_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 32\n",
"train_iter = BucketIterator(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)\n",
"dev_iter = BucketIterator(dataset=test_data, batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class LSTMClassifier(nn.Module):\n",
"\n",
" def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size):\n",
" super(LSTMClassifier, self).__init__()\n",
" \n",
" # 1. Embedding Layer\n",
" self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
" \n",
" # 2. LSTM Layer\n",
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=2)\n",
" \n",
" # 3. Dense Layer\n",
" self.hidden2out = nn.Linear(hidden_dim, output_size)\n",
" \n",
" # Optional dropout layer\n",
" self.dropout_layer = nn.Dropout(p=0.4)\n",
"\n",
" def forward(self, batch_text):\n",
"\n",
" embeddings = self.embeddings(batch_text)\n",
" _, (ht, _) = self.lstm(embeddings)\n",
"\n",
" lstm_output = ht[-1]\n",
" lstm_output = self.dropout_layer(lstm_output)\n",
" final_output = self.hidden2out(lstm_output)\n",
" return final_output"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"from tqdm import tqdm_notebook as tqdm\n",
"\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"\n",
"def train(model, train_iter, dev_iter, batch_size, num_batches):\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters())\n",
"\n",
" max_epochs = 20\n",
" for epoch in range(max_epochs):\n",
"\n",
" total_loss = 0\n",
" predictions, correct = [], []\n",
" for batch in tqdm(train_iter, total=num_batches):\n",
" optimizer.zero_grad()\n",
"\n",
" pred = model(batch.text.to(device))\n",
" loss = criterion(pred, batch.label.to(device))\n",
" total_loss += loss.item()\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" _, pred_indices = torch.max(pred, 1)\n",
" predictions += list(pred_indices.cpu().numpy())\n",
" correct += list(batch.label.cpu().numpy())\n",
"\n",
" print(\"=== Epoch\", epoch, \"===\")\n",
" print(\"Total training loss:\", total_loss)\n",
" print(\"Training performance:\", precision_recall_fscore_support(correct, predictions))\n",
" \n",
" total_loss = 0\n",
" predictions, correct = [], []\n",
" for batch in dev_iter:\n",
"\n",
" pred = model(batch.text.to(device))\n",
" loss = criterion(pred, batch.label.to(device))\n",
" total_loss += loss.item()\n",
"\n",
" _, pred_indices = torch.max(pred, 1)\n",
" pred_indices = list(pred_indices.cpu().numpy())\n",
" predictions += pred_indices\n",
" correct += list(batch.label.cpu().numpy())\n",
"\n",
" print(\"Total development loss:\", total_loss)\n",
" print(\"Development performance:\", precision_recall_fscore_support(correct, predictions))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0b7c1d928ca4a3288aa66e5918d2e72",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=353), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"EMBEDDING_DIM = 300\n",
"HIDDEN_DIM = 256\n",
"NUM_CLASSES = len(label2idx)\n",
"num_batches = int(len(train_data) / BATCH_SIZE)\n",
"\n",
"classifier = LSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, VOCAB_SIZE+2, NUM_CLASSES) \n",
"\n",
"train(classifier.to(device), train_iter, dev_iter, BATCH_SIZE, num_batches)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit c59ed2c

Please sign in to comment.