Update notebooks for Google Colab compatability

This commit is contained in:
Jake Walker 2024-06-17 11:11:20 +01:00
parent b3e74b8c20
commit 3e562267fb
4 changed files with 32 additions and 54 deletions

View file

@ -64,30 +64,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"num_train_samples = 50000\n", "from keras.datasets import cifar10\n",
"\n", "(x_train, y_train), (x_test, y_test) = cifar10.load_data()"
"x_train = np.empty((num_train_samples, 3, 32, 32), dtype=\"uint8\")\n",
"y_train = np.empty((num_train_samples,), dtype=\"uint8\")\n",
"\n",
"for i in range(1, 6):\n",
" file_path = os.path.join(\"cifar-10-batches-py\", f\"data_batch_{i}\")\n",
" (\n",
" x_train[(i - 1) * 10000 : i * 10000, :, :, :],\n",
" y_train[(i - 1) * 10000 : i * 10000],\n",
" ) = load_batch(file_path)\n",
"\n",
"file_path = os.path.join(\"cifar-10-batches-py\", \"test_batch\")\n",
"x_test, y_test = load_batch(file_path)\n",
"\n",
"y_train = np.reshape(y_train, (len(y_train), 1))\n",
"y_test = np.reshape(y_test, (len(y_test), 1))\n",
"\n",
"if backend.image_data_format() == \"channels_last\":\n",
" x_train = x_train.transpose(0, 2, 3, 1)\n",
" x_test = x_test.transpose(0, 2, 3, 1)\n",
"\n",
"x_test = x_test.astype(x_train.dtype)\n",
"y_test = y_test.astype(y_train.dtype)"
] ]
}, },
{ {
@ -151,8 +129,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_train_one_hot = keras.src.utils.numerical_utils.to_categorical(y_train, 10)\n", "y_train_one_hot = keras.utils.to_categorical(y_train, 10)\n",
"y_test_one_hot = keras.src.utils.numerical_utils.to_categorical(y_test, 10)" "y_test_one_hot = keras.utils.to_categorical(y_test, 10)"
] ]
}, },
{ {
@ -278,6 +256,15 @@
"Let's try and feed a picture of a cat to the model, and see what it thinks... As a reminder, the model hasn't been trained on pictures of cats." "Let's try and feed a picture of a cat to the model, and see what it thinks... As a reminder, the model hasn't been trained on pictures of cats."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wget -O cat.jpg https://git.subspace.solutions/cads/ai-lesson-resources/media/branch/main/cat.jpg"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,

View file

@ -20,6 +20,16 @@
"## Import the packages 📦" "## Import the packages 📦"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!apt-get update && apt-get install -y build-essential cmake swig\n",
"!pip install stable-baselines3\\[extra\\]==2.3.2 gymnasium\\[box2d\\]==0.29.1"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -407,6 +417,15 @@
"\n", "\n",
"Is moon landing too boring for you? Try to **change the environment**, why not use MountainCar-v0, CartPole-v1 or CarRacing-v0? Check how they work [using the gym documentation](https://www.gymlibrary.dev/) and have fun 🎉." "Is moon landing too boring for you? Try to **change the environment**, why not use MountainCar-v0, CartPole-v1 or CarRacing-v0? Check how they work [using the gym documentation](https://www.gymlibrary.dev/) and have fun 🎉."
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wget -O ppo-LunarLander-v2-good.zip https://git.subspace.solutions/cads/ai-lesson-resources/media/branch/main/ppo-LunarLander-v2-good.zip"
]
} }
], ],
"metadata": { "metadata": {

View file

@ -1,24 +0,0 @@
FROM docker.io/library/ubuntu:22.04
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update \
&& apt-get install -y --no-install-recommends apt-utils build-essential g++ curl cmake zlib1g-dev libjpeg-dev xvfb xorg-dev libboost-all-dev libsdl2-dev swig python3 python3-dev python3-future python3-pip python3-setuptools python3-wheel python3-tk libatlas-base-dev cython3 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python3 -m pip install --upgrade pip \
&& python3 -m pip install jupyterlab keras==3.3.3 matplotlib==3.9.0 numpy==1.26.4 tensorflow==2.16.1 scikit-image==0.22.0 \
&& python3 -m pip install "gymnasium[box2d]==0.29.1" "stable-baselines3[extra]==2.3.2"
RUN apt-get update && apt-get install -y wget
WORKDIR /work
COPY . /work
RUN /work/download-data.sh \
&& rm /work/*_solutions.ipynb
ENV DEBIAN_FRONTEND teletype
CMD xvfb-run -s "-screen 0 1400x900x24" \
/usr/local/bin/jupyter lab --port 8888 --ip=0.0.0.0 --allow-root

View file

@ -1,4 +0,0 @@
#!/bin/bash
wget -O cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar -xvf cifar-10-python.tar.gz
rm cifar-10-python.tar.gz