{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wJpXpmjEYC_T" }, "source": [ "## Building a GPT\n", "\n", "## Attribution\n", "\n", "Note that this bigram notebook is based very heavily on Andrej Karpathy's \"makemore\" aka NN-Zero-to-Hero code and videos. All credit goes to him.\n", "\n", "You can find his repo here: https://github.com/karpathy/ng-video-lecture as well as the makemore repo here: https://github.com/karpathy/makemore\n", "\n", "His video is extremely excellent and can be found here: https://www.youtube.com/watch?v=kCc8FmEb1nY\n", "\n", "Refer to his LICENSE file in this folder.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML\n", "display(HTML(\"\"))\n", "display(HTML(\"\"))\n", "display(HTML(\"\"))\n", "\n", "import torch\n", "torch.set_printoptions(linewidth=230)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "h5hjCcLDr2WC", "outputId": "ccc60f0c-fd78-4dbe-8598-0512d1036aad" }, "outputs": [], "source": [ "# Download the tiny shakespeare dataset\n", "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O6medjfRsLD9" }, "outputs": [], "source": [ "# read it in to inspect it\n", "with open('input.txt', 'r', encoding='utf-8') as f:\n", " text = f.read()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6xWI_VyAsN8F", "outputId": "ed819dd0-72e5-40a6-d2ed-928ff73bfda6" }, "outputs": [], "source": [ "print(\"length of dataset in characters: \", len(text))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2c5V0FvqseE0", "outputId": "25ca7adc-b8c0-42d1-b08c-e0863c5c314e" }, "outputs": [], "source": [ "# 
First 1000 characters\n", "print(text[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0e-Rbyr8sfM8", "outputId": "f34e94a9-5b44-4cf3-885b-986731929109" }, "outputs": [], "source": [ "# here are all the unique characters that occur in this text\n", "chars = sorted(list(set(text)))\n", "vocab_size = len(chars)\n", "print(''.join(chars))\n", "print(vocab_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Yw1LKNCgwjj1", "outputId": "86fcc21c-2cf7-40d9-cd7b-b5a253da4459" }, "outputs": [], "source": [ "# create a mapping from characters to integers\n", "stoi = { ch:i for i,ch in enumerate(chars) }\n", "itos = { i:ch for i,ch in enumerate(chars) }\n", "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", "\n", "print(encode(\"hii there\"))\n", "print(decode(encode(\"hii there\")))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YJb0OXPwzvqg", "outputId": "db7297cc-36a9-4fae-e941-e7bb9e0e91d1" }, "outputs": [], "source": [ "# let's now encode the entire text dataset and store it into a torch.Tensor\n", "import torch # we use PyTorch: https://pytorch.org\n", "data = torch.tensor(encode(text), dtype=torch.long)\n", "print(data.shape, data.dtype)\n", "print(data[:1000]) # the 1000 characters we looked at earlier will look like this to the GPT" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f_WIXqxz0lU5" }, "outputs": [], "source": [ "# Let's now split up the data into train and validation sets\n", "n = int(0.9*len(data)) # first 90% will be train, rest val\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TD5Bj8Y6IAD4", "outputId": "bf23c586-1d33-4af1-b63d-ce6f90b0a528" }, "outputs": [], "source": [ "block_size = 8\n", "train_data[:block_size+1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9HXDe8vGJCEn", "outputId": "588663aa-1de5-4ef7-aba0-4a96fe828353" }, "outputs": [], "source": [ "x = train_data[:block_size]\n", "y = train_data[1:block_size+1]\n", "for t in range(block_size):\n", " context = x[:t+1]\n", " target = y[t]\n", " print(f\"when input is {context} the target: {target}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Q3k1Czf7LuA9", "outputId": "4ea8e8a0-443c-49bb-b3bf-ba36e1712999" }, "outputs": [], "source": [ "torch.manual_seed(1337)\n", "batch_size = 4 # how many independent sequences will we process in parallel?\n", "block_size = 8 # what is the maximum context length for predictions?\n", "\n", "def get_batch(split):\n", " # generate a small batch of data of inputs x and targets y\n", " data = train_data if split == 'train' else val_data\n", " ix = torch.randint(len(data) - block_size, (batch_size,))\n", " x = torch.stack([data[i:i+block_size] for i in ix])\n", " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", " return x, y\n", "\n", "xb, yb = get_batch('train')\n", "print('inputs:')\n", "print(xb.shape)\n", "print(xb)\n", "print('targets:')\n", "print(yb.shape)\n", "print(yb)\n", "\n", "print('----')\n", "\n", "for b in range(batch_size): # batch dimension\n", " for t in range(block_size): # time dimension\n", " context = xb[b, :t+1]\n", " target = yb[b,t]\n", " print(f\"when input is {context.tolist()} the target: {target}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qpyyAeIzQjlO", "outputId": 
"a650f8dc-da81-400b-bc59-0a595487fdb9" }, "outputs": [], "source": [ "print(xb) # our input to the transformer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nql_1ER53oCf", "outputId": "5de90b1b-4603-428a-f571-fe4bd3c45436" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "torch.manual_seed(1337)\n", "\n", "class BigramLanguageModel(nn.Module):\n", "\n", " def __init__(self, vocab_size):\n", " super().__init__()\n", " # each token directly reads off the logits for the next token from a lookup table\n", " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n", "\n", " def forward(self, idx, targets=None):\n", "\n", " # idx and targets are both (B,T) tensor of integers\n", " logits = self.token_embedding_table(idx) # (B,T,C)\n", "\n", " if targets is None:\n", " loss = None\n", " else:\n", " B, T, C = logits.shape\n", " logits = logits.view(B*T, C)\n", " targets = targets.view(B*T)\n", " loss = F.cross_entropy(logits, targets)\n", "\n", " return logits, loss\n", "\n", " def generate(self, idx, max_new_tokens):\n", " # idx is (B, T) array of indices in the current context\n", " for _ in range(max_new_tokens):\n", " # get the predictions\n", " logits, loss = self(idx)\n", " # focus only on the last time step\n", " logits = logits[:, -1, :] # becomes (B, C)\n", " # apply softmax to get probabilities\n", " probs = F.softmax(logits, dim=-1) # (B, C)\n", " # sample from the distribution\n", " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", " # append sampled index to the running sequence\n", " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", " return idx\n", "\n", "m = BigramLanguageModel(vocab_size)\n", "logits, loss = m(xb, yb)\n", "print(logits.shape)\n", "print(loss)\n", "\n", "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "id": "eTyJ8qAaDdiF" }, "outputs": [], "source": [ "# create a PyTorch optimizer\n", "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Hs4kI8YdEkQj", "outputId": "42ded55c-2983-4d91-c528-675b2edfa849" }, "outputs": [], "source": [ "batch_size = 32\n", "for steps in range(100): # increase number of steps for good results...\n", "\n", " # sample a batch of data\n", " xb, yb = get_batch('train')\n", "\n", " # evaluate the loss\n", " logits, loss = m(xb, yb)\n", " optimizer.zero_grad(set_to_none=True)\n", " loss.backward()\n", " optimizer.step()\n", "\n", "print(loss.item())\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EcVIDWAZEtjN", "outputId": "0ad6f9d2-ad58-4498-a5f8-6f31407bb18b" }, "outputs": [], "source": [ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))" ] }, { "cell_type": "markdown", "metadata": { "id": "XinV8nmAnmKN" }, "source": [ "## The mathematical trick in self-attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tukiH-NbRBhA", "outputId": "d981f6d4-ac08-4ec2-8284-82f5fa1e0815" }, "outputs": [], "source": [ "# toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\n", "torch.manual_seed(42)\n", "a = torch.tril(torch.ones(3, 3))\n", "a = a / torch.sum(a, 1, keepdim=True)\n", "b = torch.randint(0,10,(3,2)).float()\n", "c = a @ b\n", "print('a=')\n", "print(a)\n", "print('--')\n", "print('b=')\n", "print(b)\n", "print('--')\n", "print('c=')\n", "print(c)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Hs_E24uRE8kr", "outputId": 
"8bf3ff5f-565e-48b8-de8e-7272706c8e12" }, "outputs": [], "source": [ "# consider the following toy example:\n", "\n", "torch.manual_seed(1337)\n", "B,T,C = 4,8,2 # batch, time, channels\n", "x = torch.randn(B,T,C)\n", "x.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "86NuXX0fn7ps" }, "outputs": [], "source": [ "# We want xbow[b,t] = mean_{i<=t} x[b,i]\n", "xbow = torch.zeros((B,T,C))\n", "for b in range(B):\n", " for t in range(T):\n", " xprev = x[b,:t+1] # (t+1,C)\n", " xbow[b,t] = torch.mean(xprev, 0)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yhdOAd6-wXkZ", "outputId": "eaf6ab61-dff1-4bb7-e623-47f692bad5f9" }, "outputs": [], "source": [ "# version 2: using matrix multiply for a weighted aggregation\n", "wei = torch.tril(torch.ones(T, T))\n", "wei = wei / wei.sum(1, keepdim=True)\n", "xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)\n", "torch.allclose(xbow, xbow2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wOURrfG-ysoL", "outputId": "080b500d-8110-4602-fcef-7d6f2ebfc6bc" }, "outputs": [], "source": [ "# version 3: use Softmax\n", "tril = torch.tril(torch.ones(T, T))\n", "wei = torch.zeros((T,T))\n", "wei = wei.masked_fill(tril == 0, float('-inf'))\n", "wei = F.softmax(wei, dim=-1)\n", "xbow3 = wei @ x\n", "torch.allclose(xbow, xbow3)\n" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EDarxEWIRMKq", "outputId": "07b587dd-a91c-4bb0-d7f1-e247cd5dacb5" }, "outputs": [], "source": [ "# version 4: self-attention!\n", "torch.manual_seed(1337)\n", "B,T,C = 4,8,32 # batch, time, channels\n", "x = torch.randn(B,T,C)\n", "\n", "# let's see a single Head perform self-attention\n", "# (defined unconditionally because `v = value(x)` below and later cells rely on them)\n", "head_size = 16\n", "key = nn.Linear(C, head_size, bias=False)\n", "query = nn.Linear(C, head_size, bias=False)\n", "value = nn.Linear(C, head_size, bias=False)\n", "k = key(x) # (B, T, 16)\n", "q = query(x) # (B, T, 16)\n", "wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n", "# to compare with the uniform averaging of version 3, replace the line above with:\n", "# wei = torch.zeros((T,T))\n", "\n", "tril = torch.tril(torch.ones(T, T))\n", "\n", "wei = wei.masked_fill(tril == 0, float('-inf'))\n", "wei = F.softmax(wei, dim=-1)\n", "\n", "v = value(x)\n", "out = wei @ v\n", "#out = wei @ x\n", "\n", "out.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vT1hdtzXCjgL", "outputId": "6d2c569b-7922-451f-9934-0fc564678d17" }, "outputs": [], "source": [ "wei" ] }, { "cell_type": "markdown", "metadata": { "id": "M5CvobiQ0pLr" }, "source": [ "Notes:\n", "- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n", "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n", "- Each example across the batch dimension is processed completely independently; examples never \"talk\" to each other\n", "- In an \"encoder\" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n", "- \"self-attention\" just means that the keys and values are produced from the same source as queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)\n", "- \"Scaled\" attention additionally multiplies `wei` by 1/sqrt(head_size). 
This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4SNbLq5z3oBw" }, "outputs": [], "source": [ "k = torch.randn(B,T,head_size)\n", "q = torch.randn(B,T,head_size)\n", "wei = q @ k.transpose(-2, -1) * head_size**-0.5" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Nl6I9n9IRTSo", "outputId": "0c5b9cd0-af8a-4564-fbad-41d844e54822" }, "outputs": [], "source": [ "k.var()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "T1tQx7oeRvtc", "outputId": "3541ca1a-7447-4ef7-835e-81824aebc1b5" }, "outputs": [], "source": [ "q.var()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MLb_odHU3iKM", "outputId": "a687a222-5a2c-4cdb-c1bf-17cd05b45b69" }, "outputs": [], "source": [ "wei.var()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JB82yzt44REI", "outputId": "f07da2f1-10bb-4a7a-bcaa-578587977d00" }, "outputs": [], "source": [ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Mpt8569BB9_f", "outputId": "5d8b910a-6192-44ba-ebb2-497d88e0b629" }, "outputs": [], "source": [ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2Num7sX9CKOH", "outputId": "929ceb78-a639-41d6-aac7-12997b5c93f0" }, "outputs": [], "source": [ "class LayerNorm1d: # (used to be BatchNorm1d)\n", "\n", " def 
__init__(self, dim, eps=1e-5, momentum=0.1):\n", " self.eps = eps\n", " self.gamma = torch.ones(dim)\n", " self.beta = torch.zeros(dim)\n", "\n", " def __call__(self, x):\n", " # calculate the forward pass\n", " xmean = x.mean(1, keepdim=True) # batch mean\n", " xvar = x.var(1, keepdim=True) # batch variance\n", " xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n", " self.out = self.gamma * xhat + self.beta\n", " return self.out\n", "\n", " def parameters(self):\n", " return [self.gamma, self.beta]\n", "\n", "torch.manual_seed(1337)\n", "module = LayerNorm1d(100)\n", "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n", "x = module(x)\n", "x.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "633T2cmnW1uk", "outputId": "7720fa58-0478-4e8a-86a7-502d4cce9443" }, "outputs": [], "source": [ "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LN9cK9BoXCYb", "outputId": "6368ece0-600e-417d-8a91-7c1e5d750ba8" }, "outputs": [], "source": [ "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dRJH6wM_XFfU" }, "outputs": [], "source": [ "# French to English translation example:\n", "\n", "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n", "# les réseaux de neurones sont géniaux! 
neural networks are awesome!\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ZcvKeBXoZFOY" }, "source": [ "### Params and pre-processing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "# hyperparameters\n", "batch_size = 16 # how many independent sequences will we process in parallel?\n", "block_size = 32 # what is the maximum context length for predictions?\n", "max_iters = 5000\n", "eval_interval = 100\n", "learning_rate = 1e-3\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "eval_iters = 200\n", "n_embd = 64\n", "n_head = 4\n", "n_layer = 4\n", "dropout = 0.0\n", "# ------------\n", "\n", "torch.manual_seed(1337)\n", "\n", "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", "with open('input.txt', 'r', encoding='utf-8') as f:\n", " text = f.read()\n", "\n", "# here are all the unique characters that occur in this text\n", "chars = sorted(list(set(text)))\n", "vocab_size = len(chars)\n", "# create a mapping from characters to integers\n", "stoi = { ch:i for i,ch in enumerate(chars) }\n", "itos = { i:ch for i,ch in enumerate(chars) }\n", "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", "\n", "# Train and test splits\n", "data = torch.tensor(encode(text), dtype=torch.long)\n", "n = int(0.9*len(data)) # first 90% will be train, rest val\n", "train_data = data[:n]\n", "val_data = data[n:]\n", "\n", "# data loading\n", "def get_batch(split):\n", " # generate a small batch of data of inputs x and targets y\n", " data = train_data if split == 'train' else val_data\n", " ix = torch.randint(len(data) - block_size, (batch_size,))\n", " x = torch.stack([data[i:i+block_size] for i in ix])\n", " y = 
torch.stack([data[i+1:i+block_size+1] for i in ix])\n", " x, y = x.to(device), y.to(device)\n", " return x, y\n", "\n", "@torch.no_grad()\n", "def estimate_loss():\n", " out = {}\n", " model.eval()\n", " for split in ['train', 'val']:\n", " losses = torch.zeros(eval_iters)\n", " for k in range(eval_iters):\n", " X, Y = get_batch(split)\n", " logits, loss = model(X, Y)\n", " losses[k] = loss.item()\n", " out[split] = losses.mean()\n", " model.train()\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Self-attention head" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Head(nn.Module):\n", " \"\"\" one head of causal self-attention: (B,T,n_embd) -> (B,T,head_size) \"\"\"\n", "\n", " def __init__(self, head_size):\n", " super().__init__()\n", " self.key = nn.Linear(n_embd, head_size, bias=False)\n", " self.query = nn.Linear(n_embd, head_size, bias=False)\n", " self.value = nn.Linear(n_embd, head_size, bias=False)\n", " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n", "\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " B,T,C = x.shape\n", " k = self.key(x) # (B,T,hs)\n", " q = self.query(x) # (B,T,hs)\n", " # compute attention scores (\"affinities\"), scaled by 1/sqrt(head_size) -- not 1/sqrt(C),\n", " # since C is n_embd here while q and k only have head_size channels\n", " wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n", " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n", " wei = F.softmax(wei, dim=-1) # (B, T, T)\n", " wei = self.dropout(wei)\n", " # perform the weighted aggregation of the values\n", " v = self.value(x) # (B,T,hs)\n", " out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our new bigram model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BigramLanguageModel(nn.Module):\n", "\n", " def __init__(self):\n", " super().__init__()\n", " # each token directly reads off the 
logits for the next token from a lookup table\n", " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n", " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", " self.lm_head = nn.Linear(n_embd, vocab_size)\n", "\n", " def forward(self, idx, targets=None):\n", " B, T = idx.shape\n", "\n", " # idx and targets are both (B,T) tensor of integers\n", " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", " x = tok_emb + pos_emb # (B,T,C)\n", " x = self.blocks(x) # (B,T,C)\n", " x = self.ln_f(x) # (B,T,C)\n", " logits = self.lm_head(x) # (B,T,vocab_size)\n", "\n", " if targets is None:\n", " loss = None\n", " else:\n", " B, T, C = logits.shape\n", " logits = logits.view(B*T, C)\n", " targets = targets.view(B*T)\n", " loss = F.cross_entropy(logits, targets)\n", "\n", " return logits, loss\n", "\n", " def generate(self, idx, max_new_tokens):\n", " # idx is (B, T) array of indices in the current context\n", " for _ in range(max_new_tokens):\n", " # crop idx to the last block_size tokens\n", " idx_cond = idx[:, -block_size:]\n", " # get the predictions\n", " logits, loss = self(idx_cond)\n", " # focus only on the last time step\n", " logits = logits[:, -1, :] # becomes (B, C)\n", " # apply softmax to get probabilities\n", " probs = F.softmax(logits, dim=-1) # (B, C)\n", " # sample from the distribution\n", " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", " # append sampled index to the running sequence\n", " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", " return idx" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hoelkOrFY8bN", "outputId": "961304cd-e379-40d4-dd56-8de0b91d2861" }, "outputs": [], "source": [ 
"\n", "class MultiHeadAttention(nn.Module):\n", " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", "\n", " def __init__(self, num_heads, head_size):\n", " super().__init__()\n", " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", " self.proj = nn.Linear(n_embd, n_embd)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " out = torch.cat([h(x) for h in self.heads], dim=-1)\n", " out = self.dropout(self.proj(out))\n", " return out\n", "\n", "class FeedFoward(nn.Module):\n", " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n", "\n", " def __init__(self, n_embd):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(n_embd, 4 * n_embd),\n", " nn.ReLU(),\n", " nn.Linear(4 * n_embd, n_embd),\n", " nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)\n", "\n", "class Block(nn.Module):\n", " \"\"\" Transformer block: communication followed by computation \"\"\"\n", "\n", " def __init__(self, n_embd, n_head):\n", " # n_embd: embedding dimension, n_head: the number of heads we'd like\n", " super().__init__()\n", " head_size = n_embd // n_head\n", " self.sa = MultiHeadAttention(n_head, head_size)\n", " self.ffwd = FeedFoward(n_embd)\n", " self.ln1 = nn.LayerNorm(n_embd)\n", " self.ln2 = nn.LayerNorm(n_embd)\n", "\n", " def forward(self, x):\n", " x = x + self.sa(self.ln1(x))\n", " x = x + self.ffwd(self.ln2(x))\n", " return x\n", "\n", "model = BigramLanguageModel()\n", "m = model.to(device)\n", "# print the number of parameters in the model\n", "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n", "\n", "# create a PyTorch optimizer\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", "\n", "for iter in range(max_iters):\n", "\n", " # every once in a while evaluate the loss on train and val sets\n", " if iter % eval_interval == 0 or iter == max_iters - 1:\n", " losses = estimate_loss()\n", " 
print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", "\n", " # sample a batch of data\n", " xb, yb = get_batch('train')\n", "\n", " # evaluate the loss\n", " logits, loss = model(xb, yb)\n", " optimizer.zero_grad(set_to_none=True)\n", " loss.backward()\n", " optimizer.step()\n", "\n", "# generate from the model\n", "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjjvMifYZf7x" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 1 }