{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example usage\n",
    "\n",
    "To use `fundaml` in a project, first install (or upgrade) it into the environment of the running kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade fundaml"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `fundaml` was already imported in this kernel, **restart the kernel** after upgrading: Python caches imported modules, so `import fundaml` would otherwise keep reporting the previously loaded version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fundaml\n",
    "\n",
    "fundaml.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `eda` submodule\n",
    "\n",
    "`print_basic_eda` prints a quick overview of a DataFrame: its shape, an `info()`-style column summary, null/NA counts, `describe()` statistics and duplicate rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------\n",
      " Shape: (3, 2)\n",
      "----------------------------------------\n",
      "\n",
      "\n",
      "RangeIndex: 3 entries, 0 to 2\n",
      "Data columns (total 2 columns):\n",
      " # Column Non-Null Count Dtype\n",
      "--- ------ -------------- -----\n",
      " 0 a 3 non-null int64\n",
      " 1 b 3 non-null int64\n",
      "dtypes: int64(2)\n",
      "memory usage: 176.0 bytes\n",
      "----------------------------------------\n",
      " Null:\n",
      " a 0\n",
      "b 0\n",
      "dtype: int64\n",
      "----------------------------------------\n",
      " Describe:\n",
      " a b\n",
      "count 3.0 3.0\n",
      "mean 2.0 5.0\n",
      "std 1.0 1.0\n",
      "min 1.0 4.0\n",
      "25% 1.5 4.5\n",
      "50% 2.0 5.0\n",
      "75% 2.5 5.5\n",
      "max 3.0 6.0\n",
      "----------------------------------------\n",
      " NA:\n",
      " a 0\n",
      "b 0\n",
      "dtype: int64\n",
      "----------------------------------------\n",
      " Duplicate:\n",
      " 0\n",
      "----------------------------------------\n",
      " Duplicated:\n",
      " 0 False\n",
      "1 False\n",
      "2 False\n",
      "dtype: bool\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from fundaml.eda import print_basic_eda\n",
    "\n",
    "df = pd.DataFrame({\"a\": [1, 2, 3], \"b\": [4, 5, 6]})\n",
    "print_basic_eda(df)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a Neural Network Classifier\n",
    "\n",
    "Train `SampleNNClassifier` on FashionMNIST with `NNTrainer`, scoring accuracy on both the training and test sets.\n",
    "\n",
    "All tunable values live in the configuration cell below. `EPOCHS` is an upper bound: the trainer may stop early (it prints `Early stopping!` when it does)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "from fundaml.models import SampleNNClassifier\n",
    "from fundaml.scores import score_accuracy\n",
    "from fundaml.trainers import NNTrainer, get_available_devices\n",
    "\n",
    "LEARNING_RATE = 1e-3\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 5  # upper bound; training may stop early\n",
    "WEIGHT_DECAY = 0.01\n",
    "UPDATE_EVERY_N_BATCHES = 100  # how often the trainer prints a progress line\n",
    "SEED = 0\n",
    "DEVICE = \"cpu\"  # get_available_devices() lists accelerators (e.g. mps) that could be used instead"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load FashionMNIST\n",
    "\n",
    "The dataset is downloaded to `data/` on the first run and reused afterwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_data = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")\n",
    "\n",
    "test_data = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")\n",
    "\n",
    "train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE)\n",
    "test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the model and train\n",
    "\n",
    "AdamW with weight decay; the trainer reports loss and accuracy every `UPDATE_EVERY_N_BATCHES` batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Available devices on this machine: [device(type='mps')]\n",
      "Using device: cpu\n",
      "In training [ 64/60000] epoch: 1 loss: 2.301378 accuracy 0.125\n",
      "In training [ 6464/60000] epoch: 1 loss: 0.548937 accuracy 0.796875\n",
      "In training [12864/60000] epoch: 1 loss: 0.403776 accuracy 0.859375\n",
      "In training [19264/60000] epoch: 1 loss: 0.509262 accuracy 0.828125\n",
      "In training [25664/60000] epoch: 1 loss: 0.466876 accuracy 0.796875\n",
      "In training [32064/60000] epoch: 1 loss: 0.446234 accuracy 0.8125\n",
      "In training [38464/60000] epoch: 1 loss: 0.365968 accuracy 0.84375\n",
      "In training [44864/60000] epoch: 1 loss: 0.513393 accuracy 0.8125\n",
      "In training [51264/60000] epoch: 1 loss: 0.461992 accuracy 0.84375\n",
      "In training [57664/60000] epoch: 1 loss: 0.519337 accuracy 0.828125\n",
      "In training [30016/60000] epoch: 1 loss: 0.351558 accuracy 0.8125\n",
      "In testing [ 64/10000] epoch: 1 loss: 0.316642 accuracy 0.84375\n",
      "In testing [ 6464/10000] epoch: 1 loss: 0.386479 accuracy 0.84375\n",
      "In training [ 64/60000] epoch: 2 loss: 0.262214 accuracy 0.90625\n",
      "In training [ 6464/60000] epoch: 2 loss: 0.364398 accuracy 0.890625\n",
      "In training [12864/60000] epoch: 2 loss: 0.269503 accuracy 0.859375\n",
      "In training [19264/60000] epoch: 2 loss: 0.375651 accuracy 0.828125\n",
      "In training [25664/60000] epoch: 2 loss: 0.428343 accuracy 0.796875\n",
      "In training [32064/60000] epoch: 2 loss: 0.391344 accuracy 0.828125\n",
      "In training [38464/60000] epoch: 2 loss: 0.325320 accuracy 0.875\n",
      "In training [44864/60000] epoch: 2 loss: 0.457571 accuracy 0.8125\n",
      "In training [51264/60000] epoch: 2 loss: 0.388944 accuracy 0.84375\n",
      "In training [57664/60000] epoch: 2 loss: 0.450031 accuracy 0.796875\n",
      "In training [30016/60000] epoch: 2 loss: 0.271165 accuracy 0.8125\n",
      "In testing [ 64/10000] epoch: 2 loss: 0.295752 accuracy 0.875\n",
      "In testing [ 6464/10000] epoch: 2 loss: 0.342953 accuracy 0.890625\n",
      "In training [ 64/60000] epoch: 3 loss: 0.235921 accuracy 0.9375\n",
      "In training [ 6464/60000] epoch: 3 loss: 0.326183 accuracy 0.890625\n",
      "In training [12864/60000] epoch: 3 loss: 0.226515 accuracy 0.875\n",
      "In training [19264/60000] epoch: 3 loss: 0.328245 accuracy 0.84375\n",
      "In training [25664/60000] epoch: 3 loss: 0.383195 accuracy 0.859375\n",
      "In training [32064/60000] epoch: 3 loss: 0.315123 accuracy 0.84375\n",
      "In training [38464/60000] epoch: 3 loss: 0.289108 accuracy 0.859375\n",
      "In training [44864/60000] epoch: 3 loss: 0.456471 accuracy 0.859375\n",
      "In training [51264/60000] epoch: 3 loss: 0.348210 accuracy 0.859375\n",
      "In training [57664/60000] epoch: 3 loss: 0.356107 accuracy 0.796875\n",
      "In training [30016/60000] epoch: 3 loss: 0.262893 accuracy 0.875\n",
      "In testing [ 64/10000] epoch: 3 loss: 0.306127 accuracy 0.84375\n",
      "In testing [ 6464/10000] epoch: 3 loss: 0.325759 accuracy 0.875\n",
      "Early stopping!\n"
     ]
    }
   ],
   "source": [
    "# Seed right before building the model so its random initialisation is reproducible on re-run.\n",
    "torch.manual_seed(SEED)\n",
    "model = SampleNNClassifier(short_name=\"Sample_NN_Classifier\")\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.AdamW(\n",
    "    model.parameters(),\n",
    "    lr=LEARNING_RATE,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    "    betas=(0.9, 0.99),\n",
    "    eps=1e-8,\n",
    ")\n",
    "\n",
    "print(f\"Available devices on this machine: {get_available_devices()}\")\n",
    "\n",
    "trainer = NNTrainer()\n",
    "trainer.with_model(model).with_optimizer(optimizer).with_loss_function(loss_fn)\n",
    "trainer.with_scoring_functions({\"accuracy\": score_accuracy}).with_device(DEVICE)\n",
    "scores = trainer.train_test_loop(\n",
    "    train_dataloader,\n",
    "    test_dataloader,\n",
    "    update_every_n_batches=UPDATE_EVERY_N_BATCHES,\n",
    "    epochs=EPOCHS,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mltz_base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "7cfa4bdf23d59d966cb348cbf6b359c162aa4b758a5fdbf685462279e0c1f14f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}