{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Time series classification with interpreTS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we ensure that the additional libraries used in this tutorial (sktime and scikit-learn) are installed. The interpreTS package itself is not installed by this cell and must already be available in the environment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install sktime scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we show how to use interpreTS for time series classification." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import interpreTS as it\n", "from sktime.datasets import load_arrow_head\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import classification_report, accuracy_score" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train set size: (42168, 1) (168,)\n", "Test set size: (10793, 1) (43,)\n" ] }, { "data": { "text/html": [ "
\n", " | \n", " | dim_0 | \n", "
---|---|---|
0 | \n", "0 | \n", "-1.963009 | \n", "
1 | \n", "-1.957825 | \n", "|
2 | \n", "-1.956145 | \n", "|
3 | \n", "-1.938289 | \n", "|
4 | \n", "-1.896657 | \n", "
\n", " | absolute_energy_level_0 | \n", "absolute_energy_level_1 | \n", "absolute_energy_dim_0 | \n", "binarize_mean_level_0 | \n", "binarize_mean_level_1 | \n", "binarize_mean_dim_0 | \n", "dominant_level_0 | \n", "dominant_level_1 | \n", "dominant_dim_0 | \n", "entropy_level_0 | \n", "... | \n", "trough_dim_0 | \n", "variance_level_0 | \n", "variance_level_1 | \n", "variance_dim_0 | \n", "seasonality_strength_level_0 | \n", "seasonality_strength_level_1 | \n", "seasonality_strength_dim_0 | \n", "trend_strength_level_0 | \n", "trend_strength_level_1 | \n", "trend_strength_dim_0 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "225900 | \n", "5239625 | \n", "250.000000 | \n", "0.0 | \n", "0.501992 | \n", "0.498008 | \n", "30.0 | \n", "225.0 | \n", "-0.705068 | \n", "0.0 | \n", "... | \n", "-2.168225 | \n", "0.0 | \n", "5250.0 | \n", "0.996016 | \n", "0.0 | \n", "0.976096 | \n", "0.952867 | \n", "0.0 | \n", "1.0 | \n", "1.495348e-04 | \n", "
1 | \n", "7512179 | \n", "5239625 | \n", "250.000001 | \n", "0.0 | \n", "0.501992 | \n", "0.553785 | \n", "173.0 | \n", "225.0 | \n", "0.906215 | \n", "0.0 | \n", "... | \n", "-1.628334 | \n", "0.0 | \n", "5250.0 | \n", "0.996016 | \n", "0.0 | \n", "0.976096 | \n", "0.973132 | \n", "0.0 | \n", "1.0 | \n", "1.773299e-05 | \n", "
2 | \n", "4919600 | \n", "5239625 | \n", "249.999999 | \n", "0.0 | \n", "0.501992 | \n", "0.517928 | \n", "140.0 | \n", "225.0 | \n", "1.005348 | \n", "0.0 | \n", "... | \n", "-1.981786 | \n", "0.0 | \n", "5250.0 | \n", "0.996016 | \n", "0.0 | \n", "0.976096 | \n", "0.962001 | \n", "0.0 | \n", "1.0 | \n", "3.883215e-07 | \n", "
3 | \n", "1411875 | \n", "5239625 | \n", "250.000000 | \n", "0.0 | \n", "0.501992 | \n", "0.521912 | \n", "75.0 | \n", "225.0 | \n", "0.141633 | \n", "0.0 | \n", "... | \n", "-2.048952 | \n", "0.0 | \n", "5250.0 | \n", "0.996016 | \n", "0.0 | \n", "0.976096 | \n", "0.955270 | \n", "0.0 | \n", "1.0 | \n", "1.608024e-04 | \n", "
4 | \n", "903600 | \n", "5239625 | \n", "250.000000 | \n", "0.0 | \n", "0.501992 | \n", "0.537849 | \n", "60.0 | \n", "225.0 | \n", "-0.837348 | \n", "0.0 | \n", "... | \n", "-1.886216 | \n", "0.0 | \n", "5250.0 | \n", "0.996016 | \n", "0.0 | \n", "0.976096 | \n", "0.964042 | \n", "0.0 | \n", "1.0 | \n", "1.305618e-04 | \n", "
5 rows × 57 columns
\n", "