{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lu5JZZ8qkgKD"
      },
      "source": [
        "# Introduction to deep learning\n",
        "* [Notebook (Exercise) ](https://colab.research.google.com/drive/1X5x9_qOIklhs5SqDRKxANbJY5Ey16B4K) for this section (open in another tab)\n",
        "\n",
        "* [Video](https://www.youtube.com/watch?v=Y0q0jnyfIp0) (21 min)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Flow of regression analysis by deep learning\n",
        "\n",
        "Let's follow the flow of deep learning, using the simplest machine learning problem, regression analysis (prediction of numerical values) as an example.\n",
        "\n",
        "The flow of machine learning is as follows, that is not limited to regression problems.\n",
        "\n",
 A ">
        "> A “model” in machine learning is a “function” that makes some predictions on input data. A function has parameters, and once the parameters are determined, predictions can be made. If appropriate parameters are obtained by learning with data, appropriate predictions can be expected for new data.\n",
        "\n",
        "1. Data acquisition and preprocessing\n",
        "\n",
 In preprocessing">
        "> In preprocessing, the input data and the correct answers (targets) are prepared for supervised learning.\n",
        "The data is then formatted to suit the model to be used, and the spread (scale) of each variable is standardized.\n",
        "We also split the data into training data and evaluation data so that we can evaluate the performance of the model.\n",
        "Alternatively, the data may be split into three groups: training data, evaluation data, and test data for assessing generalization.\n",
        "\n",
        "2. Create a neural network model\n",
        "\n",
 The model of regression">
        "> The model for regression analysis is a function $\\hat{y} = f(x, \\theta)$ that predicts an approximation $\\hat{y}$ of $y$ for an input $x$. Here $\\theta$ denotes the parameters that are updated by training. In addition to these trainable parameters, a neural network has various settings that define its structure, such as the number of layers, the size of each layer, the activation function, and the dropout ratio. Since these structural settings are not updated during training, they are called hyperparameters to distinguish them from the trainable parameters. In other words, a neural network has many hyperparameters that must be chosen in advance.\n",
        "\n",
        "3. Training and evaluation of neural network models\n",
        "\n",
 Machine learning">
        "> Machine learning, also called statistical machine learning, uses data to optimize model parameters so that predictions can be made for new data. The same is true for deep learning: the parameters are optimized on the training data so that the error between the model's predictions and the correct values becomes small. However, a model with high expressive power, such as a deep learning model, can fit the training data too closely, which lowers its predictive performance on new data (its generalization performance). This situation is called “overfitting”. To avoid overfitting, various techniques have been proposed for both the network structure and the training method.\n",
        ">\n",
The evaluation data">
        ">The evaluation data is used to estimate generalization performance and to tune the parameters that are not learned during training (the hyperparameters). The test data is used purely to evaluate generalization performance.\n",
        "\n",
        "4. Applying Neural Network Models to New Data\n",
        "\n",
        ">The model $f(x,\\theta)$ obtained from training is used to predict values for test data and new input data $x$.\n",
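The model $f">
        "\n",
        "The four steps above can be sketched in a few lines of NumPy. This is only an illustrative sketch under stated assumptions: the data is synthetic (made up here, not the housing data), the 120/40/40 split is arbitrary, and a linear least-squares fit stands in for the neural network $f(x, \\theta)$ so that the example stays short. The names `standardize`, `f`, and `mse` are hypothetical helpers, not part of any library.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Synthetic regression data (made up for illustration): y = 3*x0 - 2*x1 + noise\n",
        "X = rng.normal(loc=[50.0, 0.5], scale=[10.0, 0.1], size=(200, 2))\n",
        "y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=1.0, size=200)\n",
        "\n",
        "# 1. Preprocessing: shuffle, then split into training / evaluation / test sets\n",
        "idx = rng.permutation(len(X))\n",
        "train_idx, val_idx, test_idx = idx[:120], idx[120:160], idx[160:]\n",
        "\n",
        "# Standardize with statistics computed from the training data only\n",
        "mean, std = X[train_idx].mean(axis=0), X[train_idx].std(axis=0)\n",
        "\n",
        "def standardize(a):\n",
        "    return (a - mean) / std\n",
        "\n",
        "X_train, X_val, X_test = (standardize(X[i]) for i in (train_idx, val_idx, test_idx))\n",
        "y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]\n",
        "\n",
        "# 2. Model: a linear function f(x, theta) stands in for the neural network\n",
        "def f(x, theta):\n",
        "    return x @ theta[:-1] + theta[-1]\n",
        "\n",
        "# 3. Training: choose theta to minimize the squared error on the training data\n",
        "A = np.hstack([X_train, np.ones((len(X_train), 1))])\n",
        "theta, *_ = np.linalg.lstsq(A, y_train, rcond=None)\n",
        "\n",
        "def mse(x, t):\n",
        "    return float(np.mean((f(x, theta) - t) ** 2))\n",
        "\n",
        "print('train MSE:', mse(X_train, y_train))\n",
        "print('evaluation MSE:', mse(X_val, y_val))\n",
        "\n",
        "# 4. Application: the held-out test data plays the role of new data\n",
        "print('test MSE:', mse(X_test, y_test))\n",
        "```\n",
        "\n",
        "In this sketch the mean and standard deviation come from the training data alone, so no information from the evaluation or test data leaks into training. A training error far below the evaluation error would be a sign of overfitting.\n",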
        "\n"
      ],
      "metadata": {
        "id": "u00ZqhurKA5p"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YUHFj0FCf66G"
      },
      "source": [
        "**Deep learnibng frameworks**\n",
        "\n",
        "Deep learning frameworks offer building blocks for designing, training, and validating deep neural networks through a high-level programming interface.\n",
        "The most popular libraries for deep learning are TensorFlow and PyTorch. TensorFlow is suitable for practical use, especially when used with a partial library called Keras, because it is easy to describe and has a function to be loaded on terminal devices. PyTorch is highly customizable and is often used in research. This notebook uses TensorFlow-Keras, which is the easiest to implement, to train a neural network model for regression problems.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "khixiwtkJTVY"
      },
      "outputs": [],
      "source": [
        "# import libraries\n",
        "import numpy as np # numpy is a library for numerical analysis.\n",
        "import pandas as pd # pandas is library for data analysis\n",
        "import matplotlib.pyplot as plt # a library for plotting graphs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6WzaNHS7f57T"
      },
      "source": [
        "## Data acquisition and preprocessing\n",
        "\n",
        "Load a dataset about average house prices in Boston from [this csv file](https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 621
        },
        "id": "TaBRXQzZ90l1",
        "outputId": "df978c69-1e5e-4e71-caad-4fe5aea00d39"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2024-03-14 01:53:51--  https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 35735 (35K) [text/plain]\n",
            "Saving to: 窶錬ostonHousing.csv窶兔n",
            "\n",
            "\rBostonHousing.csv     0%[                    ]       0  --.-KB/s               \rBostonHousing.csv   100%[===================>]  34.90K  --.-KB/s    in 0.003s  \n",
            "\n",
            "2024-03-14 01:53:52 (10.1 MB/s) - 窶錬ostonHousing.csv窶� saved [35735/35735]\n",
            "\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "        crim    zn  indus  chas    nox     rm   age     dis  rad  tax  \\\n",
              "0    0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296   \n",
              "1    0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242   \n",
              "2    0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242   \n",
              "3    0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222   \n",
              "4    0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222   \n",
              "..       ...   ...    ...   ...    ...    ...   ...     ...  ...  ...   \n",
              "501  0.06263   0.0  11.93     0  0.573  6.593  69.1  2.4786    1  273   \n",
              "502  0.04527   0.0  11.93     0  0.573  6.120  76.7  2.2875    1  273   \n",
              "503  0.06076   0.0  11.93     0  0.573  6.976  91.0  2.1675    1  273   \n",
              "504  0.10959   0.0  11.93     0  0.573  6.794  89.3  2.3889    1  273   \n",
              "505  0.04741   0.0  11.93     0  0.573  6.030  80.8  2.5050    1  273   \n",
              "\n",
              "     ptratio       b  lstat  medv  \n",
              "0       15.3  396.90   4.98  24.0  \n",
              "1       17.8  396.90   9.14  21.6  \n",
              "2       17.8  392.83   4.03  34.7  \n",
              "3       18.7  394.63   2.94  33.4  \n",
              "4       18.7  396.90   5.33  36.2  \n",
              "..       ...     ...    ...   ...  \n",
              "501     21.0  391.99   9.67  22.4  \n",
              "502     21.0  396.90   9.08  20.6  \n",
              "503     21.0  396.90   5.64  23.9  \n",
              "504     21.0  393.45   6.48  22.0  \n",
              "505     21.0  396.90   7.88  11.9  \n",
              "\n",
              "[506 rows x 14 columns]"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-e6ae2fdd-41bf-4837-8409-d5afe9c3ceb5\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>crim</th>\n",
              "      <th>zn</th>\n",
              "      <th>indus</th>\n",
              "      <th>chas</th>\n",
              "      <th>nox</th>\n",
              "      <th>rm</th>\n",
              "      <th>age</th>\n",
              "      <th>dis</th>\n",
              "      <th>rad</th>\n",
              "      <th>tax</th>\n",
              "      <th>ptratio</th>\n",
              "      <th>b</th>\n",
              "      <th>lstat</th>\n",
              "      <th>medv</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.00632</td>\n",
              "      <td>18.0</td>\n",
              "      <td>2.31</td>\n",
              "      <td>0</td>\n",
              "      <td>0.538</td>\n",
              "      <td>6.575</td>\n",
              "      <td>65.2</td>\n",
              "      <td>4.0900</td>\n",
              "      <td>1</td>\n",
              "      <td>296</td>\n",
              "      <td>15.3</td>\n",
              "      <td>396.90</td>\n",
              "      <td>4.98</td>\n",
              "      <td>24.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.02731</td>\n",
              "      <td>0.0</td>\n",
              "      <td>7.07</td>\n",
              "      <td>0</td>\n",
              "      <td>0.469</td>\n",
              "      <td>6.421</td>\n",
              "      <td>78.9</td>\n",
              "      <td>4.9671</td>\n",
              "      <td>2</td>\n",
              "      <td>242</td>\n",
              "      <td>17.8</td>\n",
              "      <td>396.90</td>\n",
              "      <td>9.14</td>\n",
              "      <td>21.6</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.02729</td>\n",
              "      <td>0.0</td>\n",
              "      <td>7.07</td>\n",
              "      <td>0</td>\n",
              "      <td>0.469</td>\n",
              "      <td>7.185</td>\n",
              "      <td>61.1</td>\n",
              "      <td>4.9671</td>\n",
              "      <td>2</td>\n",
              "      <td>242</td>\n",
              "      <td>17.8</td>\n",
              "      <td>392.83</td>\n",
              "      <td>4.03</td>\n",
              "      <td>34.7</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.03237</td>\n",
              "      <td>0.0</td>\n",
              "      <td>2.18</td>\n",
              "      <td>0</td>\n",
              "      <td>0.458</td>\n",
              "      <td>6.998</td>\n",
              "      <td>45.8</td>\n",
              "      <td>6.0622</td>\n",
              "      <td>3</td>\n",
              "      <td>222</td>\n",
              "      <td>18.7</td>\n",
              "      <td>394.63</td>\n",
              "      <td>2.94</td>\n",
              "      <td>33.4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.06905</td>\n",
              "      <td>0.0</td>\n",
              "      <td>2.18</td>\n",
              "      <td>0</td>\n",
              "      <td>0.458</td>\n",
              "      <td>7.147</td>\n",
              "      <td>54.2</td>\n",
              "      <td>6.0622</td>\n",
              "      <td>3</td>\n",
              "      <td>222</td>\n",
              "      <td>18.7</td>\n",
              "      <td>396.90</td>\n",
              "      <td>5.33</td>\n",
              "      <td>36.2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>501</th>\n",
              "      <td>0.06263</td>\n",
              "      <td>0.0</td>\n",
              "      <td>11.93</td>\n",
              "      <td>0</td>\n",
              "      <td>0.573</td>\n",
              "      <td>6.593</td>\n",
              "      <td>69.1</td>\n",
              "      <td>2.4786</td>\n",
              "      <td>1</td>\n",
              "      <td>273</td>\n",
              "      <td>21.0</td>\n",
              "      <td>391.99</td>\n",
              "      <td>9.67</td>\n",
              "      <td>22.4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>502</th>\n",
              "      <td>0.04527</td>\n",
              "      <td>0.0</td>\n",
              "      <td>11.93</td>\n",
              "      <td>0</td>\n",
              "      <td>0.573</td>\n",
              "      <td>6.120</td>\n",
              "      <td>76.7</td>\n",
              "      <td>2.2875</td>\n",
              "      <td>1</td>\n",
              "      <td>273</td>\n",
              "      <td>21.0</td>\n",
              "      <td>396.90</td>\n",
              "      <td>9.08</td>\n",
              "      <td>20.6</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>503</th>\n",
              "      <td>0.06076</td>\n",
              "      <td>0.0</td>\n",
              "      <td>11.93</td>\n",
              "      <td>0</td>\n",
              "      <td>0.573</td>\n",
              "      <td>6.976</td>\n",
              "      <td>91.0</td>\n",
              "      <td>2.1675</td>\n",
              "      <td>1</td>\n",
              "      <td>273</td>\n",
              "      <td>21.0</td>\n",
              "      <td>396.90</td>\n",
              "      <td>5.64</td>\n",
              "      <td>23.9</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>504</th>\n",
              "      <td>0.10959</td>\n",
              "      <td>0.0</td>\n",
              "      <td>11.93</td>\n",
              "      <td>0</td>\n",
              "      <td>0.573</td>\n",
              "      <td>6.794</td>\n",
              "      <td>89.3</td>\n",
              "      <td>2.3889</td>\n",
              "      <td>1</td>\n",
              "      <td>273</td>\n",
              "      <td>21.0</td>\n",
              "      <td>393.45</td>\n",
              "      <td>6.48</td>\n",
              "      <td>22.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>505</th>\n",
              "      <td>0.04741</td>\n",
              "      <td>0.0</td>\n",
              "      <td>11.93</td>\n",
              "      <td>0</td>\n",
              "      <td>0.573</td>\n",
              "      <td>6.030</td>\n",
              "      <td>80.8</td>\n",
              "      <td>2.5050</td>\n",
              "      <td>1</td>\n",
              "      <td>273</td>\n",
              "      <td>21.0</td>\n",
              "      <td>396.90</td>\n",
              "      <td>7.88</td>\n",
              "      <td>11.9</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>506 rows テ� 14 columns</p>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e6ae2fdd-41bf-4837-8409-d5afe9c3ceb5')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-e6ae2fdd-41bf-4837-8409-d5afe9c3ceb5 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-e6ae2fdd-41bf-4837-8409-d5afe9c3ceb5');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-045a9a15-1ac0-496c-83d9-7818c6bd3d93\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-045a9a15-1ac0-496c-83d9-7818c6bd3d93')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-045a9a15-1ac0-496c-83d9-7818c6bd3d93 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "  <div id=\"id_ef419da1-1785-4542-aa05-220c671c833a\">\n",
              "    <style>\n",
              "      .colab-df-generate {\n",
              "        background-color: #E8F0FE;\n",
              "        border: none;\n",
              "        border-radius: 50%;\n",
              "        cursor: pointer;\n",
              "        display: none;\n",
              "        fill: #1967D2;\n",
              "        height: 32px;\n",
              "        padding: 0 0 0 0;\n",
              "        width: 32px;\n",
              "      }\n",
              "\n",
              "      .colab-df-generate:hover {\n",
              "        background-color: #E2EBFA;\n",
              "        box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "        fill: #174EA6;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate {\n",
              "        background-color: #3B4455;\n",
              "        fill: #D2E3FC;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate:hover {\n",
              "        background-color: #434B5C;\n",
              "        box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "        filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "        fill: #FFFFFF;\n",
              "      }\n",
              "    </style>\n",
              "    <button class=\"colab-df-generate\" onclick=\"generateWithVariable('df')\"\n",
              "            title=\"Generate code using this dataframe.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "    <script>\n",
              "      (() => {\n",
              "      const buttonEl =\n",
              "        document.querySelector('#id_ef419da1-1785-4542-aa05-220c671c833a button.colab-df-generate');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      buttonEl.onclick = () => {\n",
              "        google.colab.notebook.generateWithVariable('df');\n",
              "      }\n",
              "      })();\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "df",
              "summary": "{\n  \"name\": \"df\",\n  \"rows\": 506,\n  \"fields\": [\n    {\n      \"column\": \"crim\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 8.60154510533249,\n        \"min\": 0.00632,\n        \"max\": 88.9762,\n        \"num_unique_values\": 504,\n        \"samples\": [\n          0.09178,\n          0.05644,\n          0.10574\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"zn\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 23.32245299451514,\n        \"min\": 0.0,\n        \"max\": 100.0,\n        \"num_unique_values\": 26,\n        \"samples\": [\n          25.0,\n          30.0,\n          18.0\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"indus\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 6.860352940897585,\n        \"min\": 0.46,\n        \"max\": 27.74,\n        \"num_unique_values\": 76,\n        \"samples\": [\n          8.14,\n          1.47,\n          1.22\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"chas\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0,\n        \"min\": 0,\n        \"max\": 1,\n        \"num_unique_values\": 2,\n        \"samples\": [\n          1,\n          0\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"nox\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.11587767566755595,\n        \"min\": 0.385,\n        \"max\": 0.871,\n        \"num_unique_values\": 81,\n        \"samples\": [\n          0.401,\n          0.538\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"rm\",\n      
\"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.7026171434153233,\n        \"min\": 3.561,\n        \"max\": 8.78,\n        \"num_unique_values\": 446,\n        \"samples\": [\n          6.849,\n          4.88\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"age\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 28.148861406903617,\n        \"min\": 2.9,\n        \"max\": 100.0,\n        \"num_unique_values\": 356,\n        \"samples\": [\n          51.8,\n          33.8\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"dis\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 2.105710126627611,\n        \"min\": 1.1296,\n        \"max\": 12.1265,\n        \"num_unique_values\": 412,\n        \"samples\": [\n          2.2955,\n          4.2515\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"rad\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 8,\n        \"min\": 1,\n        \"max\": 24,\n        \"num_unique_values\": 9,\n        \"samples\": [\n          7,\n          2\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"tax\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 168,\n        \"min\": 187,\n        \"max\": 711,\n        \"num_unique_values\": 66,\n        \"samples\": [\n          370,\n          666\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"ptratio\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 2.1649455237144406,\n        \"min\": 12.6,\n        \"max\": 22.0,\n        \"num_unique_values\": 46,\n        \"samples\": [\n         
 19.6,\n          15.6\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"b\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 91.29486438415783,\n        \"min\": 0.32,\n        \"max\": 396.9,\n        \"num_unique_values\": 357,\n        \"samples\": [\n          396.24,\n          395.11\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"lstat\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 7.141061511348571,\n        \"min\": 1.73,\n        \"max\": 37.97,\n        \"num_unique_values\": 455,\n        \"samples\": [\n          6.15,\n          4.32\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"medv\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 9.197104087379818,\n        \"min\": 5.0,\n        \"max\": 50.0,\n        \"num_unique_values\": 229,\n        \"samples\": [\n          14.1,\n          22.5\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 2
        }
      ],
      "source": [
        "# Download the dataset: BostonHousing.csv is saved to the current working directory (/content/ on Colab)\n",
        "!wget 'https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv'\n",
        "\n",
        "# Read the dataset: load the CSV file into df, a pandas DataFrame.\n",
        "# header=0: row 0 contains the column names.\n",
        "# sep=',': the delimiter is a comma.\n",
        "df = pd.read_csv('/content/BostonHousing.csv', header=0, sep=',')\n",
        "\n",
        "# Display df\n",
        "df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2bgTDBakDajU"
      },
      "source": [
        "The data consists of 506 rows and 14 columns.\n",
        "\n",
        "Of these, medv (the median home price in each district) is used as the target value (correct answer). Of the remaining 13 columns, the categorical variables chas and rad and the ethically problematic b (a variable derived from the proportion of Black residents) are excluded, and the other 10 are used as explanatory variables.\n",
        "\n",
        "This gives a regression problem that predicts a one-dimensional value, the home price in each district, from 10-dimensional input data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "BNDEeoBcJaMt",
        "outputId": "7d325417-d0c8-4bf4-c7a6-d065b9f20261"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "      crim    zn  indus    nox     rm   age     dis  tax  ptratio  lstat\n",
              "0  0.00632  18.0   2.31  0.538  6.575  65.2  4.0900  296     15.3   4.98\n",
              "1  0.02731   0.0   7.07  0.469  6.421  78.9  4.9671  242     17.8   9.14\n",
              "2  0.02729   0.0   7.07  0.469  7.185  61.1  4.9671  242     17.8   4.03\n",
              "3  0.03237   0.0   2.18  0.458  6.998  45.8  6.0622  222     18.7   2.94\n",
              "4  0.06905   0.0   2.18  0.458  7.147  54.2  6.0622  222     18.7   5.33"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-0189ae43-f552-4872-8af0-15b20fe276cc\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>crim</th>\n",
              "      <th>zn</th>\n",
              "      <th>indus</th>\n",
              "      <th>nox</th>\n",
              "      <th>rm</th>\n",
              "      <th>age</th>\n",
              "      <th>dis</th>\n",
              "      <th>tax</th>\n",
              "      <th>ptratio</th>\n",
              "      <th>lstat</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.00632</td>\n",
              "      <td>18.0</td>\n",
              "      <td>2.31</td>\n",
              "      <td>0.538</td>\n",
              "      <td>6.575</td>\n",
              "      <td>65.2</td>\n",
              "      <td>4.0900</td>\n",
              "      <td>296</td>\n",
              "      <td>15.3</td>\n",
              "      <td>4.98</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.02731</td>\n",
              "      <td>0.0</td>\n",
              "      <td>7.07</td>\n",
              "      <td>0.469</td>\n",
              "      <td>6.421</td>\n",
              "      <td>78.9</td>\n",
              "      <td>4.9671</td>\n",
              "      <td>242</td>\n",
              "      <td>17.8</td>\n",
              "      <td>9.14</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.02729</td>\n",
              "      <td>0.0</td>\n",
              "      <td>7.07</td>\n",
              "      <td>0.469</td>\n",
              "      <td>7.185</td>\n",
              "      <td>61.1</td>\n",
              "      <td>4.9671</td>\n",
              "      <td>242</td>\n",
              "      <td>17.8</td>\n",
              "      <td>4.03</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.03237</td>\n",
              "      <td>0.0</td>\n",
              "      <td>2.18</td>\n",
              "      <td>0.458</td>\n",
              "      <td>6.998</td>\n",
              "      <td>45.8</td>\n",
              "      <td>6.0622</td>\n",
              "      <td>222</td>\n",
              "      <td>18.7</td>\n",
              "      <td>2.94</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.06905</td>\n",
              "      <td>0.0</td>\n",
              "      <td>2.18</td>\n",
              "      <td>0.458</td>\n",
              "      <td>7.147</td>\n",
              "      <td>54.2</td>\n",
              "      <td>6.0622</td>\n",
              "      <td>222</td>\n",
              "      <td>18.7</td>\n",
              "      <td>5.33</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "x",
              "summary": "{\n  \"name\": \"x\",\n  \"rows\": 506,\n  \"fields\": [\n    {\n      \"column\": \"crim\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 8.60154510533249,\n        \"min\": 0.00632,\n        \"max\": 88.9762,\n        \"num_unique_values\": 504,\n        \"samples\": [\n          0.09178,\n          0.05644,\n          0.10574\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"zn\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 23.32245299451514,\n        \"min\": 0.0,\n        \"max\": 100.0,\n        \"num_unique_values\": 26,\n        \"samples\": [\n          25.0,\n          30.0,\n          18.0\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"indus\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 6.860352940897585,\n        \"min\": 0.46,\n        \"max\": 27.74,\n        \"num_unique_values\": 76,\n        \"samples\": [\n          8.14,\n          1.47,\n          1.22\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"nox\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.11587767566755595,\n        \"min\": 0.385,\n        \"max\": 0.871,\n        \"num_unique_values\": 81,\n        \"samples\": [\n          0.401,\n          0.538,\n          0.52\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"rm\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.7026171434153233,\n        \"min\": 3.561,\n        \"max\": 8.78,\n        \"num_unique_values\": 446,\n        \"samples\": [\n          6.849,\n          4.88,\n          5.693\n        ],\n        \"semantic_type\": \"\",\n        \"description\": 
\"\"\n      }\n    },\n    {\n      \"column\": \"age\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 28.148861406903617,\n        \"min\": 2.9,\n        \"max\": 100.0,\n        \"num_unique_values\": 356,\n        \"samples\": [\n          51.8,\n          33.8,\n          70.3\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"dis\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 2.105710126627611,\n        \"min\": 1.1296,\n        \"max\": 12.1265,\n        \"num_unique_values\": 412,\n        \"samples\": [\n          2.2955,\n          4.2515,\n          3.2628\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"tax\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 168,\n        \"min\": 187,\n        \"max\": 711,\n        \"num_unique_values\": 66,\n        \"samples\": [\n          370,\n          666,\n          296\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"ptratio\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 2.1649455237144406,\n        \"min\": 12.6,\n        \"max\": 22.0,\n        \"num_unique_values\": 46,\n        \"samples\": [\n          19.6,\n          15.6,\n          14.4\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"lstat\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 7.141061511348571,\n        \"min\": 1.73,\n        \"max\": 37.97,\n        \"num_unique_values\": 455,\n        \"samples\": [\n          6.15,\n          4.32,\n          18.05\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "# Let x be df without the columns chas, rad, b, and medv; axis=1 means that columns (not rows) are dropped.\n",
        "x = df.drop(['chas', 'rad', 'b', 'medv'], axis=1)\n",
        "\n",
        "# Let y be the medv column of df\n",
        "y = df['medv']\n",
        "\n",
        "# Display the first 5 rows of x\n",
        "x.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ajnkT-4iIkU"
      },
      "source": [
        "The input data are standardized: each column is linearly transformed so that its mean is 0 and its variance is 1."
      ]
    },
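    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of this transformation (the symbols $N$, $x_{ij}$, $\\mu_j$, and $\\sigma_j$ are notation introduced here for illustration), let $x_{ij}$ be the value of column $j$ in row $i$ and $N$ the number of rows. Each value is replaced by\n",
        "\n",
        "$$x'_{ij} = \\frac{x_{ij} - \\mu_j}{\\sigma_j}, \\qquad \\mu_j = \\frac{1}{N}\\sum_{i=1}^{N} x_{ij}, \\qquad \\sigma_j = \\sqrt{\\frac{1}{N-1}\\sum_{i=1}^{N}\\left(x_{ij} - \\mu_j\\right)^2}$$\n",
        "\n",
        "The $N-1$ denominator assumes the pandas default for `std` (`ddof=1`), in which case each column of the result has mean 0 and sample variance 1."
      ]
    },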
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c255InMhKvy1"
      },
      "outputs": [],
      "source": [
        "mean = x.mean(axis=0)    # column-wise mean\n",
        "standard = x.std(axis=0) # standard deviation per column\n",
        "x = (x-mean)/standard    # standardized to mean 0, variance 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mdibvB2HiUkN"
      },
      "source": [
        "Of the data, 80% is randomly selected as training data, and the remaining 20% is used as test data to evaluate the training results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "rBTFQfTg6s3e",
        "outputId": "e18da050-4318-4594-a316-4371006bc571"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "         crim        zn     indus       nox        rm       age       dis  \\\n",
              "220 -0.378471 -0.487240 -0.719610 -0.411598  0.948405  0.707847 -0.443244   \n",
              "71  -0.401645 -0.487240 -0.047633 -1.222799 -0.460613 -1.814457  0.708672   \n",
              "240 -0.406931  0.799074 -0.904732 -1.093352  0.871549 -0.507122  1.206746   \n",
              "6   -0.409837  0.048724 -0.476182 -0.264892 -0.388027 -0.070159  0.838414   \n",
              "417  2.595705 -0.487240  1.014995  1.072726 -1.395688  0.729163 -1.019866   \n",
              "\n",
              "          tax   ptratio     lstat  \n",
              "220 -0.600682 -0.487557 -0.412132  \n",
              "71  -0.612548  0.343873 -0.388326  \n",
              "240 -0.642216 -0.857081 -0.178274  \n",
              "6   -0.576948 -1.503749 -0.031237  \n",
              "417  1.529413  0.805778  1.958664  "
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-65570388-dbc1-413f-97b5-3f86d5db8c3b\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>crim</th>\n",
              "      <th>zn</th>\n",
              "      <th>indus</th>\n",
              "      <th>nox</th>\n",
              "      <th>rm</th>\n",
              "      <th>age</th>\n",
              "      <th>dis</th>\n",
              "      <th>tax</th>\n",
              "      <th>ptratio</th>\n",
              "      <th>lstat</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>220</th>\n",
              "      <td>-0.378471</td>\n",
              "      <td>-0.487240</td>\n",
              "      <td>-0.719610</td>\n",
              "      <td>-0.411598</td>\n",
              "      <td>0.948405</td>\n",
              "      <td>0.707847</td>\n",
              "      <td>-0.443244</td>\n",
              "      <td>-0.600682</td>\n",
              "      <td>-0.487557</td>\n",
              "      <td>-0.412132</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>71</th>\n",
              "      <td>-0.401645</td>\n",
              "      <td>-0.487240</td>\n",
              "      <td>-0.047633</td>\n",
              "      <td>-1.222799</td>\n",
              "      <td>-0.460613</td>\n",
              "      <td>-1.814457</td>\n",
              "      <td>0.708672</td>\n",
              "      <td>-0.612548</td>\n",
              "      <td>0.343873</td>\n",
              "      <td>-0.388326</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>240</th>\n",
              "      <td>-0.406931</td>\n",
              "      <td>0.799074</td>\n",
              "      <td>-0.904732</td>\n",
              "      <td>-1.093352</td>\n",
              "      <td>0.871549</td>\n",
              "      <td>-0.507122</td>\n",
              "      <td>1.206746</td>\n",
              "      <td>-0.642216</td>\n",
              "      <td>-0.857081</td>\n",
              "      <td>-0.178274</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>-0.409837</td>\n",
              "      <td>0.048724</td>\n",
              "      <td>-0.476182</td>\n",
              "      <td>-0.264892</td>\n",
              "      <td>-0.388027</td>\n",
              "      <td>-0.070159</td>\n",
              "      <td>0.838414</td>\n",
              "      <td>-0.576948</td>\n",
              "      <td>-1.503749</td>\n",
              "      <td>-0.031237</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>417</th>\n",
              "      <td>2.595705</td>\n",
              "      <td>-0.487240</td>\n",
              "      <td>1.014995</td>\n",
              "      <td>1.072726</td>\n",
              "      <td>-1.395688</td>\n",
              "      <td>0.729163</td>\n",
              "      <td>-1.019866</td>\n",
              "      <td>1.529413</td>\n",
              "      <td>0.805778</td>\n",
              "      <td>1.958664</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-65570388-dbc1-413f-97b5-3f86d5db8c3b')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-65570388-dbc1-413f-97b5-3f86d5db8c3b button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-65570388-dbc1-413f-97b5-3f86d5db8c3b');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-32b770ed-dfbc-48e3-97b9-11cd3d7d4706\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-32b770ed-dfbc-48e3-97b9-11cd3d7d4706')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-32b770ed-dfbc-48e3-97b9-11cd3d7d4706 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "x_train",
              "summary": "{\n  \"name\": \"x_train\",\n  \"rows\": 404,\n  \"fields\": [\n    {\n      \"column\": \"crim\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.946102434909994,\n        \"min\": -0.41936692921321594,\n        \"max\": 9.924109610233579,\n        \"num_unique_values\": 402,\n        \"samples\": [\n          -0.4151095545727809,\n          -0.39995297532992174,\n          -0.41857753615454435\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"zn\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1.013482862307835,\n        \"min\": -0.4872401872268242,\n        \"max\": 3.800473460369229,\n        \"num_unique_values\": 26,\n        \"samples\": [\n          0.5846882246721891,\n          2.9429307308500183,\n          -0.4872401872268242\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"indus\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1.0193369448962804,\n        \"min\": -1.5563016579624498,\n        \"max\": 2.420170140940476,\n        \"num_unique_values\": 73,\n        \"samples\": [\n          1.0149946225598236,\n          -1.3289081093448338,\n          -0.769170143516856\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"nox\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.9941564183915629,\n        \"min\": -1.4644327158872215,\n        \"max\": 2.7296451960161567,\n        \"num_unique_values\": 80,\n        \"samples\": [\n          -0.5755643518418554,\n          -0.41159834294028286,\n          -1.0674623785465727\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"rm\",\n      \"properties\": {\n        \"dtype\": 
\"number\",\n        \"std\": 0.9887203612398842,\n        \"min\": -3.8764132257185957,\n        \"max\": 3.5515296431831973,\n        \"num_unique_values\": 362,\n        \"samples\": [\n          -0.42645470603705693,\n          1.3725335649519952,\n          0.12434312693190092\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"age\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.9985112731416391,\n        \"min\": -2.3331281587703483,\n        \"max\": 1.1163896954823942,\n        \"num_unique_values\": 297,\n        \"samples\": [\n          -1.8322197988840079,\n          -1.071975905155803,\n          0.15365093286396667\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"dis\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.9880671308858123,\n        \"min\": -1.244636027820426,\n        \"max\": 3.956602196521782,\n        \"num_unique_values\": 346,\n        \"samples\": [\n          -0.44590310692516566,\n          -0.17758507356657002,\n          -0.916005799352907\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"tax\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1.0137649567496823,\n        \"min\": -1.3126909925173598,\n        \"max\": 1.796416438923003,\n        \"num_unique_values\": 63,\n        \"samples\": [\n          -0.903285624636854,\n          -0.9922867915673987,\n          -0.6006816570730019\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"ptratio\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1.0137998294302908,\n        \"min\": -2.7047025122329584,\n        \"max\": 1.6372081257179822,\n        \"num_unique_values\": 44,\n        
\"samples\": [\n          1.6372081257179822,\n          0.5286351968794448,\n          0.2514919646698097\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"lstat\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1.022121240733615,\n        \"min\": -1.5296133808324988,\n        \"max\": 3.4066275329281117,\n        \"num_unique_values\": 369,\n        \"samples\": [\n          0.8439833138694119,\n          -1.0983049548925647,\n          -0.21468282252860854\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# Load sklearn's train_test_split function\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# Fix the random seed to 0 to get the same result every time\n",
        "np.random.seed(0)\n",
        "\n",
        "# Split data into input data for training, input data for test data, correct answer for training data, and correct answer for test data with train_test_split function\n",
        "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)\n",
        "\n",
        "# Display the first 5 rows of input data for training\n",
        "x_train.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lIMWrIJoa9Y8",
        "outputId": "44889032-6037-404a-d96f-4625c4ee311f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((404, 10), (404,))"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# Display the type as an array of input data and correct answer data for training: 404 rows of data\n",
        "x_train.shape, y_train.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wmBuyIdz2uzR"
      },
      "source": [
        "## Training Neural Networks with TensorFlow-Keras\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n-O6CYH1k4LW"
      },
      "source": [
        "### Creating a model\n",
        "\n",
        "Using a class called Sequential, we created the following model\n",
        "\n",
        "10 input dimensions --- 1000 dimensions --- 800 dimensions --- 100 dimensions --- 1 output dimension\n",
        "\n",
        "\n",
        "```\n",
        "model.add(Dense(1000, activation = 'relu'))\n",
        "```\n",
        "adds a layer with 1000 neurons and ReLU activation function.\n",
        "\n",
        "\n",
        "* A neural network transmits data from the input (leftmost) to a layer of neurons arranged vertically in a row and transforms them in sequence, as shown in the figure below.\n",
        "\n",
        "* A single neuron (a vertex in the graph) performs a \"linear sum + constant\" on the output $(x_1, x_2, \\cdots, x_n)$ of the vertices on the left.\n",
        "$$ a = w_1x_1 +w_2 x_2 + \\cdots+ w_n x_n + b$$\n",
        "and outputs the value $y =f(a)$ after applying the activation function $f$ to it.\n",
        "\n",
        "* The activation function is a nonlinear function, typically ReLU (Rectified Linear Unit $= \\max(x,0)$)\n",
        "is used.\n",
        "\n",
        "\n",
        "><img src=\"https://vitalflux.com/wp-content/uploads/2023/02/Sklearn-Neural-Network-MLPRegressor-Regression-Model--300x166.png\" width=400>\n",
        ">\n",
        ">Neural Networks: Figure Source https://vitalflux.com/wp-content/uploads/2023/02/Sklearn-Neural-Network-MLPRegressor-Regression-Model--300x166.png\n",
        "\n",
        "><img src=\"https://pytorch.org/docs/stable/_images/ReLU.png\"  width=400>\n",
        ">\n",
        "> ReLU: Figure Source https://pytorch.org/docs/stable/_images/ReLU.png"
      ]
    },
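    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The single-neuron computation described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not how Keras implements `Dense`: the input `x`, weights `w`, and bias `b` are made-up values, and `relu` and `neuron` are hypothetical helper names.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def relu(a):\n",
        "    # ReLU: max(a, 0), applied element-wise\n",
        "    return np.maximum(a, 0)\n",
        "\n",
        "def neuron(x, w, b):\n",
        "    # Linear sum plus constant: a = w_1 x_1 + ... + w_n x_n + b\n",
        "    a = np.dot(w, x) + b\n",
        "    # Apply the activation function to obtain the output y = f(a)\n",
        "    return relu(a)\n",
        "\n",
        "# Made-up example values (assumptions, not taken from the dataset)\n",
        "x = np.array([1.0, 2.0, 3.0])\n",
        "w = np.array([0.5, 1.0, -0.25])\n",
        "b = 0.25\n",
        "\n",
        "y = neuron(x, w, b)\n",
        "print(y)\n",
        "```\n",
        "\n",
        "With these values, $a = 0.5 + 2.0 - 0.75 + 0.25 = 2.0$, which is positive, so ReLU leaves it unchanged. A `Dense` layer with 1000 neurons performs this kind of computation 1000 times in parallel, each neuron with its own weights and bias."
      ]
    },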
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ExYe19IE2uzX"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "Generating a Neural Network Model\n",
        "The model consists of\n",
        "Input (10 dimensions) - 1000 dimensions - 800 dimensions - 100 dimensions - Predictions (1 dimension)\n",
        "Sequential() is a class of models that can be written without branching from the input\n",
        "Dense() is the all-connecting (affine) layer\n",
        "activation is the activation function, ReLU is used here\n",
        "'''\n",
        "\n",
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.layers import Dense\n",
        "\n",
        "# Create an instance of the Sequential class and name it as窶徇odel窶拿n",
        "model = Sequential()\n",
        "\n",
        "# Add a 1000 dimensional layer to model; activation function is ReLU\n",
        "model.add(Dense(1000, activation = 'relu'))\n",
        "\n",
        "# Add a 800 dimensional layer to model; activation function is ReLU\n",
        "model.add(Dense(800, activation = 'relu'))\n",
        "\n",
        "# Add a 100 dimensional layer to model; activation function is ReLU\n",
        "model.add(Dense(100, activation = 'relu'))\n",
        "\n",
        "# Add a 1-dimensional layer to model; no activation function is applied to the last layer for the regression problem\n",
        "model.add(Dense(1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kNE0EjZgmxOf"
      },
      "source": [
        "Compile\n",
        "\n",
        "Prepare for training by executing the compile method of the Sequential class. The optimization method is Adam, an improved version of SGD (Stochastic Gradient Descent), and the error function is the mean squared error."
      ]
    },
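    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The mean squared error named above is the average of the squared differences between the correct answers and the predictions. A minimal NumPy sketch follows, using made-up values; `mse` is a hypothetical helper for illustration, since Keras computes this loss internally during training.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def mse(y_true, y_pred):\n",
        "    # Average of the squared differences between targets and predictions\n",
        "    return np.mean((y_true - y_pred) ** 2)\n",
        "\n",
        "# Made-up example values (assumptions, not taken from the dataset)\n",
        "y_true = np.array([3.0, 5.0, 2.0])\n",
        "y_pred = np.array([2.0, 5.0, 4.0])\n",
        "\n",
        "# Errors are 1, 0, -2, so the squared errors are 1, 0, 4\n",
        "loss = mse(y_true, y_pred)\n",
        "print(loss)\n",
        "```\n",
        "\n",
        "Because the errors are squared, a single large error contributes much more to the loss than several small ones."
      ]
    },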
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7jUh9eMq2uzY"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "Compile the model\n",
        "Compiling the model prepares inverse propagation.\n",
        "Specify Adam as the optimization function\n",
        "Mean squared error is specified as the error function\n",
        "'''\n",
        "\n",
        "from tensorflow.keras.optimizers import Adam\n",
        "\n",
        "# Compiling with optimization function Adam, learning coefficient 1e-3 = 0.001, and mean squared error (average of the squares of the errors) as the loss function\n",
        "model.compile(Adam(learning_rate=1e-3), loss=\"mean_squared_error\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "44ZI5FKInD7m"
      },
      "source": [
        "### Learning\n",
        "\n",
        "Learning is performed using a method of the Sequential class called fit.\n",
        "\n",
        "The training data is used for 150 epochs (i.e., 150 rounds of the dataset), and then evaluated on test data. The batch size of 128 means that 128 data are input at a time and the parameters are updated.\n",
        "\n",
        "In this case, there are 404 data in the training data set, so $404/128=4$ (rounded up) updates are performed per epoch.\n",
        "There are 150 epochs, so the total number of updates is $4\\times 150=600$ times.\n"
      ]
    },
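    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The update count above can be checked with a short calculation. This sketch only reproduces the arithmetic; the variable names are chosen here for illustration, and the sample count is taken from `x_train.shape`.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "n_samples = 404    # number of training samples\n",
        "batch_size = 128   # samples per parameter update\n",
        "epochs = 150       # passes over the training data\n",
        "\n",
        "# The last batch is smaller than the others: 404 = 3 * 128 + 20\n",
        "updates_per_epoch = math.ceil(n_samples / batch_size)\n",
        "total_updates = updates_per_epoch * epochs\n",
        "\n",
        "print(updates_per_epoch, total_updates)\n",
        "```\n",
        "\n",
        "The per-epoch count agrees with the `4/4` steps shown for each epoch in the training log."
      ]
    },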
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iOPm3DHi2uzZ",
        "outputId": "0cc7c0ab-0873-454a-bb01-bcf3a2d55292"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/150\n",
            "4/4 [==============================] - 1s 98ms/step - loss: 542.2038 - val_loss: 378.2771\n",
            "Epoch 2/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 313.3249 - val_loss: 142.4200\n",
            "Epoch 3/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 120.3861 - val_loss: 167.2755\n",
            "Epoch 4/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 112.8099 - val_loss: 76.4208\n",
            "Epoch 5/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 42.2795 - val_loss: 71.8820\n",
            "Epoch 6/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 45.0763 - val_loss: 58.8306\n",
            "Epoch 7/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 29.7077 - val_loss: 47.6159\n",
            "Epoch 8/150\n",
            "4/4 [==============================] - 0s 78ms/step - loss: 26.8966 - val_loss: 44.8242\n",
            "Epoch 9/150\n",
            "4/4 [==============================] - 0s 43ms/step - loss: 24.5714 - val_loss: 36.3940\n",
            "Epoch 10/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 19.3088 - val_loss: 34.3080\n",
            "Epoch 11/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 17.4799 - val_loss: 31.7423\n",
            "Epoch 12/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 16.1699 - val_loss: 32.1581\n",
            "Epoch 13/150\n",
            "4/4 [==============================] - 0s 79ms/step - loss: 15.8482 - val_loss: 31.9483\n",
            "Epoch 14/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 15.0977 - val_loss: 30.9777\n",
            "Epoch 15/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 14.6053 - val_loss: 30.4021\n",
            "Epoch 16/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 13.8320 - val_loss: 30.5984\n",
            "Epoch 17/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 13.3431 - val_loss: 30.5575\n",
            "Epoch 18/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 12.9117 - val_loss: 29.3249\n",
            "Epoch 19/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 12.6846 - val_loss: 28.2872\n",
            "Epoch 20/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 12.1732 - val_loss: 27.7082\n",
            "Epoch 21/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 11.9274 - val_loss: 27.1071\n",
            "Epoch 22/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 11.6732 - val_loss: 26.6342\n",
            "Epoch 23/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 11.4140 - val_loss: 25.2945\n",
            "Epoch 24/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 11.2454 - val_loss: 25.4208\n",
            "Epoch 25/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 10.9239 - val_loss: 25.6217\n",
            "Epoch 26/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 10.7363 - val_loss: 25.1891\n",
            "Epoch 27/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 10.5209 - val_loss: 25.5208\n",
            "Epoch 28/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 10.1774 - val_loss: 24.6743\n",
            "Epoch 29/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 10.0736 - val_loss: 24.5310\n",
            "Epoch 30/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 10.0467 - val_loss: 24.2703\n",
            "Epoch 31/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 9.9615 - val_loss: 23.8944\n",
            "Epoch 32/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 9.5542 - val_loss: 24.5444\n",
            "Epoch 33/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 9.5783 - val_loss: 25.8316\n",
            "Epoch 34/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 9.7056 - val_loss: 24.9243\n",
            "Epoch 35/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 9.3160 - val_loss: 22.9954\n",
            "Epoch 36/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 9.2122 - val_loss: 22.8383\n",
            "Epoch 37/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 8.7401 - val_loss: 23.8713\n",
            "Epoch 38/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 8.5705 - val_loss: 22.4284\n",
            "Epoch 39/150\n",
            "4/4 [==============================] - 0s 52ms/step - loss: 8.3343 - val_loss: 21.0313\n",
            "Epoch 40/150\n",
            "4/4 [==============================] - 0s 54ms/step - loss: 8.7656 - val_loss: 21.2243\n",
            "Epoch 41/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 8.5711 - val_loss: 23.2545\n",
            "Epoch 42/150\n",
            "4/4 [==============================] - 0s 54ms/step - loss: 8.3963 - val_loss: 23.4132\n",
            "Epoch 43/150\n",
            "4/4 [==============================] - 0s 62ms/step - loss: 8.2590 - val_loss: 20.8971\n",
            "Epoch 44/150\n",
            "4/4 [==============================] - 0s 59ms/step - loss: 8.1156 - val_loss: 21.2252\n",
            "Epoch 45/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 7.7277 - val_loss: 22.2657\n",
            "Epoch 46/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 7.4283 - val_loss: 21.8822\n",
            "Epoch 47/150\n",
            "4/4 [==============================] - 0s 57ms/step - loss: 7.1957 - val_loss: 20.5706\n",
            "Epoch 48/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 7.2384 - val_loss: 20.1177\n",
            "Epoch 49/150\n",
            "4/4 [==============================] - 0s 60ms/step - loss: 7.0801 - val_loss: 20.8228\n",
            "Epoch 50/150\n",
            "4/4 [==============================] - 0s 52ms/step - loss: 6.8289 - val_loss: 20.3527\n",
            "Epoch 51/150\n",
            "4/4 [==============================] - 0s 56ms/step - loss: 7.1758 - val_loss: 19.0590\n",
            "Epoch 52/150\n",
            "4/4 [==============================] - 0s 45ms/step - loss: 7.0969 - val_loss: 20.2688\n",
            "Epoch 53/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 6.7781 - val_loss: 21.1951\n",
            "Epoch 54/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 6.7375 - val_loss: 19.5856\n",
            "Epoch 55/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 6.9143 - val_loss: 19.4323\n",
            "Epoch 56/150\n",
            "4/4 [==============================] - 0s 32ms/step - loss: 6.3245 - val_loss: 21.3883\n",
            "Epoch 57/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 6.4558 - val_loss: 20.5827\n",
            "Epoch 58/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 6.2914 - val_loss: 19.7638\n",
            "Epoch 59/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 6.0283 - val_loss: 19.5449\n",
            "Epoch 60/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 6.0108 - val_loss: 19.6069\n",
            "Epoch 61/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 5.6011 - val_loss: 19.2124\n",
            "Epoch 62/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 5.5410 - val_loss: 19.0138\n",
            "Epoch 63/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 5.6207 - val_loss: 19.3390\n",
            "Epoch 64/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 5.5212 - val_loss: 19.2685\n",
            "Epoch 65/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 5.2504 - val_loss: 19.3252\n",
            "Epoch 66/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 5.5253 - val_loss: 19.4586\n",
            "Epoch 67/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 5.3667 - val_loss: 20.0553\n",
            "Epoch 68/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 5.3671 - val_loss: 18.5738\n",
            "Epoch 69/150\n",
            "4/4 [==============================] - 0s 46ms/step - loss: 5.1529 - val_loss: 19.0400\n",
            "Epoch 70/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 4.9816 - val_loss: 19.5474\n",
            "Epoch 71/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 4.9404 - val_loss: 18.4542\n",
            "Epoch 72/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 4.8934 - val_loss: 18.6778\n",
            "Epoch 73/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 4.7336 - val_loss: 19.7403\n",
            "Epoch 74/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 4.7942 - val_loss: 18.5726\n",
            "Epoch 75/150\n",
            "4/4 [==============================] - 0s 42ms/step - loss: 4.6575 - val_loss: 18.3666\n",
            "Epoch 76/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 4.6568 - val_loss: 18.2663\n",
            "Epoch 77/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 4.7075 - val_loss: 17.9285\n",
            "Epoch 78/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 4.5847 - val_loss: 19.0227\n",
            "Epoch 79/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 4.6171 - val_loss: 18.5115\n",
            "Epoch 80/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 4.5494 - val_loss: 18.1896\n",
            "Epoch 81/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 4.3259 - val_loss: 18.4847\n",
            "Epoch 82/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 4.2287 - val_loss: 17.6227\n",
            "Epoch 83/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 4.2339 - val_loss: 17.5507\n",
            "Epoch 84/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 4.1158 - val_loss: 18.1849\n",
            "Epoch 85/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 4.2792 - val_loss: 18.7086\n",
            "Epoch 86/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 4.1098 - val_loss: 17.2935\n",
            "Epoch 87/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 4.7991 - val_loss: 17.8858\n",
            "Epoch 88/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 5.1245 - val_loss: 17.1625\n",
            "Epoch 89/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 4.4317 - val_loss: 16.4478\n",
            "Epoch 90/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 4.4366 - val_loss: 18.0499\n",
            "Epoch 91/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 4.4117 - val_loss: 16.6964\n",
            "Epoch 92/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 4.2698 - val_loss: 17.3918\n",
            "Epoch 93/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 3.9362 - val_loss: 18.7956\n",
            "Epoch 94/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 3.9248 - val_loss: 16.6190\n",
            "Epoch 95/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 4.0259 - val_loss: 16.4504\n",
            "Epoch 96/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.6345 - val_loss: 17.3056\n",
            "Epoch 97/150\n",
            "4/4 [==============================] - 0s 43ms/step - loss: 3.8261 - val_loss: 17.0214\n",
            "Epoch 98/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 3.8824 - val_loss: 17.5772\n",
            "Epoch 99/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.5354 - val_loss: 17.1024\n",
            "Epoch 100/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 3.6960 - val_loss: 16.8976\n",
            "Epoch 101/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 3.7528 - val_loss: 16.9268\n",
            "Epoch 102/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 3.8965 - val_loss: 17.5388\n",
            "Epoch 103/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 3.7162 - val_loss: 16.4088\n",
            "Epoch 104/150\n",
            "4/4 [==============================] - 0s 42ms/step - loss: 3.6625 - val_loss: 17.6933\n",
            "Epoch 105/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.7193 - val_loss: 18.8482\n",
            "Epoch 106/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.4145 - val_loss: 17.5481\n",
            "Epoch 107/150\n",
            "4/4 [==============================] - 0s 40ms/step - loss: 3.8765 - val_loss: 18.2193\n",
            "Epoch 108/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 3.9756 - val_loss: 18.5340\n",
            "Epoch 109/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 3.4812 - val_loss: 17.4787\n",
            "Epoch 110/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 4.0229 - val_loss: 18.1465\n",
            "Epoch 111/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 3.6270 - val_loss: 17.0778\n",
            "Epoch 112/150\n",
            "4/4 [==============================] - 0s 38ms/step - loss: 3.3341 - val_loss: 17.7077\n",
            "Epoch 113/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 3.3702 - val_loss: 18.1563\n",
            "Epoch 114/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.5171 - val_loss: 16.9974\n",
            "Epoch 115/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 3.1994 - val_loss: 17.3032\n",
            "Epoch 116/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.2157 - val_loss: 16.3297\n",
            "Epoch 117/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 3.0004 - val_loss: 16.7432\n",
            "Epoch 118/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.9415 - val_loss: 17.2599\n",
            "Epoch 119/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 3.0669 - val_loss: 17.1835\n",
            "Epoch 120/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 3.0717 - val_loss: 16.5335\n",
            "Epoch 121/150\n",
            "4/4 [==============================] - 0s 52ms/step - loss: 2.9440 - val_loss: 16.5038\n",
            "Epoch 122/150\n",
            "4/4 [==============================] - 0s 52ms/step - loss: 3.5278 - val_loss: 18.6104\n",
            "Epoch 123/150\n",
            "4/4 [==============================] - 0s 60ms/step - loss: 4.3809 - val_loss: 17.8559\n",
            "Epoch 124/150\n",
            "4/4 [==============================] - 0s 60ms/step - loss: 4.0733 - val_loss: 16.9863\n",
            "Epoch 125/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 3.0045 - val_loss: 17.4551\n",
            "Epoch 126/150\n",
            "4/4 [==============================] - 0s 50ms/step - loss: 2.9737 - val_loss: 16.8521\n",
            "Epoch 127/150\n",
            "4/4 [==============================] - 0s 53ms/step - loss: 2.7982 - val_loss: 17.0489\n",
            "Epoch 128/150\n",
            "4/4 [==============================] - 0s 53ms/step - loss: 2.7982 - val_loss: 16.7516\n",
            "Epoch 129/150\n",
            "4/4 [==============================] - 0s 60ms/step - loss: 2.7997 - val_loss: 16.7671\n",
            "Epoch 130/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 2.6560 - val_loss: 17.1813\n",
            "Epoch 131/150\n",
            "4/4 [==============================] - 0s 51ms/step - loss: 2.6543 - val_loss: 17.2898\n",
            "Epoch 132/150\n",
            "4/4 [==============================] - 0s 53ms/step - loss: 2.7647 - val_loss: 16.7646\n",
            "Epoch 133/150\n",
            "4/4 [==============================] - 0s 53ms/step - loss: 2.6827 - val_loss: 17.0182\n",
            "Epoch 134/150\n",
            "4/4 [==============================] - 0s 52ms/step - loss: 2.6362 - val_loss: 16.9628\n",
            "Epoch 135/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.7376 - val_loss: 17.7534\n",
            "Epoch 136/150\n",
            "4/4 [==============================] - 0s 33ms/step - loss: 2.6291 - val_loss: 17.1146\n",
            "Epoch 137/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 2.6103 - val_loss: 17.7249\n",
            "Epoch 138/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 2.9757 - val_loss: 17.8233\n",
            "Epoch 139/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.9980 - val_loss: 17.3445\n",
            "Epoch 140/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 2.6846 - val_loss: 16.8408\n",
            "Epoch 141/150\n",
            "4/4 [==============================] - 0s 42ms/step - loss: 2.6002 - val_loss: 17.2965\n",
            "Epoch 142/150\n",
            "4/4 [==============================] - 0s 43ms/step - loss: 2.6645 - val_loss: 16.4657\n",
            "Epoch 143/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.6646 - val_loss: 17.1878\n",
            "Epoch 144/150\n",
            "4/4 [==============================] - 0s 37ms/step - loss: 2.4603 - val_loss: 17.9888\n",
            "Epoch 145/150\n",
            "4/4 [==============================] - 0s 39ms/step - loss: 2.5077 - val_loss: 17.7425\n",
            "Epoch 146/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.7038 - val_loss: 16.7032\n",
            "Epoch 147/150\n",
            "4/4 [==============================] - 0s 36ms/step - loss: 2.7077 - val_loss: 16.4277\n",
            "Epoch 148/150\n",
            "4/4 [==============================] - 0s 41ms/step - loss: 2.3702 - val_loss: 17.2785\n",
            "Epoch 149/150\n",
            "4/4 [==============================] - 0s 34ms/step - loss: 2.4953 - val_loss: 17.1210\n",
            "Epoch 150/150\n",
            "4/4 [==============================] - 0s 35ms/step - loss: 2.5684 - val_loss: 17.4941\n"
          ]
        }
      ],
      "source": [
        "'''\n",
        "Train on the training data with `fit` and evaluate on the test data after each epoch\n",
        "\n",
        "batch_size: number of samples in a mini-batch\n",
        "epochs: number of passes over the whole training set (1 epoch = 1 full pass)\n",
        "verbose: display mode; 0 shows nothing, 1 shows the progress and the losses for each epoch\n",
        "validation_data: data for evaluation (here we use the test data as is, since we do not tune the hyperparameters)\n",
        "'''\n",
        "history = model.fit(x_train, y_train, batch_size=128, epochs=150, verbose=1,\n",
        "          validation_data=(x_test, y_test))\n"
      ]
    },
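    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `4/4` at the start of each log line is the number of mini-batches processed per epoch. As a rough sketch, it can be reproduced with plain Python. The training-set size of 404 is an assumption (it is not shown in this part of the notebook), and `steps_per_epoch` is a hypothetical helper, not a Keras function; `batch_size=128` and `epochs=150` are the values passed to `fit`.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def steps_per_epoch(n_samples: int, batch_size: int) -> int:\n",
        "    # The last mini-batch may be smaller than batch_size, so round up\n",
        "    return math.ceil(n_samples / batch_size)\n",
        "\n",
        "\n",
        "n_train = 404  # assumption: number of training samples\n",
        "batch_size = 128  # value passed to fit\n",
        "epochs = 150  # value passed to fit\n",
        "\n",
        "train_steps = steps_per_epoch(n_train, batch_size)\n",
        "total_updates = train_steps * epochs  # one parameter update per mini-batch\n",
        "\n",
        "print('mini-batches per epoch:', train_steps)\n",
        "print('parameter updates in total:', total_updates)\n",
        "```\n",
        "\n",
        "Under this assumption each epoch consists of three full mini-batches of 128 samples and one smaller final mini-batch."
      ]
    },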
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ecmx7-vl9hd4"
      },
      "source": [
        "`loss` and `val_loss` are the mean squared errors on the training data and the test data, respectively."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Prediction on test data\n",
        "\n",
        "The trained model can now be used to make predictions with the `predict` method. Let's apply `predict` to the test data."
      ],
      "metadata": {
        "id": "wd42BUU2f0AO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Predict for test data: return value is a NumPy array\n",
        "predict = model.predict(x_test)\n",
        "\n",
        "# shape of predict\n",
        "predict.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NTkY4n40oPaj",
        "outputId": "b95eebb7-2ac1-4edc-960f-6c5214de7bb8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "4/4 [==============================] - 0s 5ms/step\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(102, 1)"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "`predict` is now a two-dimensional array of shape (102, 1), whereas `y_test` is a one-dimensional array of shape (102,), so the shapes do not match.\n",
        "\n",
        "Let's compute the mean squared error after converting `predict` to a one-dimensional array with the NumPy method `flatten()`.\n"
      ],
      "metadata": {
        "id": "l8xJ-zZhoftb"
      }
    },
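    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why the shapes should match before the error is computed, here is a minimal sketch with synthetic data. The arrays `y_true` and `y_pred` are made-up stand-ins for `y_test` and the raw output of `predict`; they are not the notebook's data. Subtracting a (102,) array from a (102, 1) array does not raise an error: NumPy broadcasts both to (102, 102), and the resulting mean is no longer the mean squared error.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "y_true = rng.normal(size=102)  # shape (102,), stand-in for y_test\n",
        "y_pred = y_true.reshape(-1, 1) + 1.0  # shape (102, 1), every prediction is off by 1\n",
        "\n",
        "# Without flattening, broadcasting compares every prediction with every target\n",
        "diff_wrong = y_pred - y_true\n",
        "mse_wrong = np.mean(diff_wrong ** 2)\n",
        "\n",
        "# After flattening, each prediction is compared only with its own target\n",
        "diff_right = y_pred.flatten() - y_true\n",
        "mse_right = np.mean(diff_right ** 2)\n",
        "\n",
        "print('without flatten:', diff_wrong.shape, mse_wrong)\n",
        "print('with flatten:   ', diff_right.shape, mse_right)\n",
        "```\n",
        "\n",
        "Because every synthetic prediction is off by exactly 1, the flattened version gives a mean squared error of about 1, while the broadcast version gives a larger value."
      ]
    },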
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ovKh8tClHwNF",
        "outputId": "6839cc09-780e-4780-b073-300cfc99a471"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "mean of squared errors =  17.494140171134045\n"
          ]
        }
      ],
      "source": [
        "# Convert to a one-dimensional array with flatten()\n",
        "predict = predict.flatten()\n",
        "\n",
        "# Compute the mean squared error between predicted and correct values\n",
        "MSE = np.mean((predict - y_test)**2)\n",
        "\n",
        "print('mean of squared errors = ', MSE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 430
        },
        "id": "roPQsOd3HGER",
        "outputId": "7083302a-67ea-44a2-deba-63bda92fb921"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Visualization of results\n",
        "# Horizontal axis: correct values (true prices), vertical axis: predicted values\n",
        "\n",
        "plt.scatter(y_test, predict, c='r', marker='s') # scatter plot with red square markers.\n",
        "plt.plot([0,50],[0,50]) # draw the line y=x; points on this line are perfect predictions\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q2fig3WkUtZD"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}