forked from nlptown/nlp-notebooks
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Draft notebook on text classification with an LSTM
- Loading branch information
1 parent
0ca491a
commit c59ed2c
Showing
1 changed file
with
241 additions
and
0 deletions.
There are no files selected for viewing
241 changes: 241 additions & 0 deletions
241
drafts/Text classification with an LSTM in PyTorch.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"TRAIN_PATH = \"data/text_classification/20newsgroups_train.tsv\"\n", | ||
"TEST_PATH = \"data/text_classification/20newsgroups_test.tsv\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn.datasets import fetch_20newsgroups\n", | ||
"\n", | ||
"train = fetch_20newsgroups(subset=\"train\")\n", | ||
"label2idx = {label: idx for idx, label in enumerate(train.target_names)}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import csv\n", | ||
"import sys\n", | ||
"from torchtext.data import TabularDataset, Field, BucketIterator\n", | ||
"\n", | ||
"csv.field_size_limit(sys.maxsize)\n", | ||
"\n", | ||
"text = Field(sequential=True, tokenize=\"spacy\")\n", | ||
"label = Field(sequential=False, use_vocab=False, preprocessing=lambda x: label2idx[x])\n", | ||
"\n", | ||
"train_data = TabularDataset(path=TRAIN_PATH, format='tsv', fields=[('label', label), ('text', text)])\n", | ||
"test_data = TabularDataset(path=TEST_PATH, format='tsv', fields=[('label', label), ('text', text)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'label': <torchtext.data.field.Field at 0x1a2265cc50>,\n", | ||
" 'text': <torchtext.data.field.Field at 0x1a2265cbe0>}" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"train_data.fields" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"VOCAB_SIZE = 30000\n", | ||
"\n", | ||
"text.build_vocab(train_data, max_size=VOCAB_SIZE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"BATCH_SIZE = 32\n", | ||
"train_iter = BucketIterator(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)\n", | ||
"dev_iter = BucketIterator(dataset=test_data, batch_size=BATCH_SIZE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch.nn as nn\n", | ||
"\n", | ||
"class LSTMClassifier(nn.Module):\n", | ||
"\n", | ||
" def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size):\n", | ||
" super(LSTMClassifier, self).__init__()\n", | ||
" \n", | ||
" # 1. Embedding Layer\n", | ||
" self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", | ||
" \n", | ||
" # 2. LSTM Layer\n", | ||
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=2)\n", | ||
" \n", | ||
" # 3. Dense Layer\n", | ||
" self.hidden2out = nn.Linear(hidden_dim, output_size)\n", | ||
" \n", | ||
" # Optional dropout layer\n", | ||
" self.dropout_layer = nn.Dropout(p=0.4)\n", | ||
"\n", | ||
" def forward(self, batch_text):\n", | ||
"\n", | ||
" embeddings = self.embeddings(batch_text)\n", | ||
" _, (ht, _) = self.lstm(embeddings)\n", | ||
"\n", | ||
" lstm_output = ht[-1]\n", | ||
" lstm_output = self.dropout_layer(lstm_output)\n", | ||
" final_output = self.hidden2out(lstm_output)\n", | ||
" return final_output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.optim as optim\n", | ||
"from tqdm import tqdm_notebook as tqdm\n", | ||
"\n", | ||
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", | ||
"\n", | ||
"\n", | ||
"def train(model, train_iter, dev_iter, batch_size, num_batches):\n", | ||
" criterion = nn.CrossEntropyLoss()\n", | ||
" optimizer = optim.Adam(model.parameters())\n", | ||
"\n", | ||
" max_epochs = 20\n", | ||
" for epoch in range(max_epochs):\n", | ||
"\n", | ||
" total_loss = 0\n", | ||
" predictions, correct = [], []\n", | ||
" for batch in tqdm(train_iter, total=num_batches):\n", | ||
" optimizer.zero_grad()\n", | ||
"\n", | ||
" pred = model(batch.text.to(device))\n", | ||
" loss = criterion(pred, batch.label.to(device))\n", | ||
" total_loss += loss.item()\n", | ||
"\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" _, pred_indices = torch.max(pred, 1)\n", | ||
" predictions += list(pred_indices.cpu().numpy())\n", | ||
" correct += list(batch.label.cpu().numpy())\n", | ||
"\n", | ||
" print(\"=== Epoch\", epoch, \"===\")\n", | ||
" print(\"Total training loss:\", total_loss)\n", | ||
" print(\"Training performance:\", precision_recall_fscore_support(correct, predictions))\n", | ||
" \n", | ||
" total_loss = 0\n", | ||
" predictions, correct = [], []\n", | ||
" for batch in dev_iter:\n", | ||
"\n", | ||
" pred = model(batch.text.to(device))\n", | ||
" loss = criterion(pred, batch.label.to(device))\n", | ||
" total_loss += loss.item()\n", | ||
"\n", | ||
" _, pred_indices = torch.max(pred, 1)\n", | ||
" pred_indices = list(pred_indices.cpu().numpy())\n", | ||
" predictions += pred_indices\n", | ||
" correct += list(batch.label.cpu().numpy())\n", | ||
"\n", | ||
" print(\"Total development loss:\", total_loss)\n", | ||
" print(\"Development performance:\", precision_recall_fscore_support(correct, predictions))\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "e0b7c1d928ca4a3288aa66e5918d2e72", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"HBox(children=(IntProgress(value=0, max=353), HTML(value='')))" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"EMBEDDING_DIM = 300\n", | ||
"HIDDEN_DIM = 256\n", | ||
"NUM_CLASSES = len(label2idx)\n", | ||
"num_batches = int(len(train_data) / BATCH_SIZE)\n", | ||
"\n", | ||
"classifier = LSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, VOCAB_SIZE+2, NUM_CLASSES) \n", | ||
"\n", | ||
"train(classifier.to(device), train_iter, dev_iter, BATCH_SIZE, num_batches)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |