{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Data classification with interpreTS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we ensure that the required libraries are installed" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install sktime scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we show how you can use interpreTS for data classification." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import interpreTS as it\n", "from sktime.datasets import load_arrow_head\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import classification_report, accuracy_score" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train set size: (42168, 1) (168,)\n", "Test set size: (10793, 1) (43,)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dim_0
00-1.963009
1-1.957825
2-1.956145
3-1.938289
4-1.896657
\n", "
" ], "text/plain": [ " dim_0\n", "0 0 -1.963009\n", " 1 -1.957825\n", " 2 -1.956145\n", " 3 -1.938289\n", " 4 -1.896657" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#prepare data\n", "X, y = load_arrow_head(return_type=\"pd-multiindex\")\n", "instance_ids = np.unique(X.index.get_level_values(0))\n", "train_ids, test_ids = train_test_split(instance_ids, test_size=0.2, random_state=42)\n", "\n", "X_train = X.loc[train_ids]\n", "X_test = X.loc[test_ids]\n", "train_indices = [np.where(instance_ids == id_)[0][0] for id_ in train_ids]\n", "test_indices = [np.where(instance_ids == id_)[0][0] for id_ in test_ids]\n", "\n", "y_train = y[train_indices]\n", "y_test = y[test_indices]\n", "\n", "print(\"Train set size:\", X_train.shape, y_train.shape)\n", "print(\"Test set size:\", X_test.shape, y_test.shape)\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
absolute_energy_level_0absolute_energy_level_1absolute_energy_dim_0binarize_mean_level_0binarize_mean_level_1binarize_mean_dim_0dominant_level_0dominant_level_1dominant_dim_0entropy_level_0...trough_dim_0variance_level_0variance_level_1variance_dim_0seasonality_strength_level_0seasonality_strength_level_1seasonality_strength_dim_0trend_strength_level_0trend_strength_level_1trend_strength_dim_0
02259005239625250.0000000.00.5019920.49800830.0225.0-0.7050680.0...-2.1682250.05250.00.9960160.00.9760960.9528670.01.01.495348e-04
175121795239625250.0000010.00.5019920.553785173.0225.00.9062150.0...-1.6283340.05250.00.9960160.00.9760960.9731320.01.01.773299e-05
249196005239625249.9999990.00.5019920.517928140.0225.01.0053480.0...-1.9817860.05250.00.9960160.00.9760960.9620010.01.03.883215e-07
314118755239625250.0000000.00.5019920.52191275.0225.00.1416330.0...-2.0489520.05250.00.9960160.00.9760960.9552700.01.01.608024e-04
49036005239625250.0000000.00.5019920.53784960.0225.0-0.8373480.0...-1.8862160.05250.00.9960160.00.9760960.9640420.01.01.305618e-04
\n", "

5 rows × 57 columns

\n", "
" ], "text/plain": [ " absolute_energy_level_0 absolute_energy_level_1 absolute_energy_dim_0 \\\n", "0 225900 5239625 250.000000 \n", "1 7512179 5239625 250.000001 \n", "2 4919600 5239625 249.999999 \n", "3 1411875 5239625 250.000000 \n", "4 903600 5239625 250.000000 \n", "\n", " binarize_mean_level_0 binarize_mean_level_1 binarize_mean_dim_0 \\\n", "0 0.0 0.501992 0.498008 \n", "1 0.0 0.501992 0.553785 \n", "2 0.0 0.501992 0.517928 \n", "3 0.0 0.501992 0.521912 \n", "4 0.0 0.501992 0.537849 \n", "\n", " dominant_level_0 dominant_level_1 dominant_dim_0 entropy_level_0 ... \\\n", "0 30.0 225.0 -0.705068 0.0 ... \n", "1 173.0 225.0 0.906215 0.0 ... \n", "2 140.0 225.0 1.005348 0.0 ... \n", "3 75.0 225.0 0.141633 0.0 ... \n", "4 60.0 225.0 -0.837348 0.0 ... \n", "\n", " trough_dim_0 variance_level_0 variance_level_1 variance_dim_0 \\\n", "0 -2.168225 0.0 5250.0 0.996016 \n", "1 -1.628334 0.0 5250.0 0.996016 \n", "2 -1.981786 0.0 5250.0 0.996016 \n", "3 -2.048952 0.0 5250.0 0.996016 \n", "4 -1.886216 0.0 5250.0 0.996016 \n", "\n", " seasonality_strength_level_0 seasonality_strength_level_1 \\\n", "0 0.0 0.976096 \n", "1 0.0 0.976096 \n", "2 0.0 0.976096 \n", "3 0.0 0.976096 \n", "4 0.0 0.976096 \n", "\n", " seasonality_strength_dim_0 trend_strength_level_0 trend_strength_level_1 \\\n", "0 0.952867 0.0 1.0 \n", "1 0.973132 0.0 1.0 \n", "2 0.962001 0.0 1.0 \n", "3 0.955270 0.0 1.0 \n", "4 0.964042 0.0 1.0 \n", "\n", " trend_strength_dim_0 \n", "0 1.495348e-04 \n", "1 1.773299e-05 \n", "2 3.883215e-07 \n", "3 1.608024e-04 \n", "4 1.305618e-04 \n", "\n", "[5 rows x 57 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create a feature extractor\n", "t = it.FeatureExtractor(window_size=251, stride=251, features=\"for-ml\")\n", "X_train_ts = t.extract_features(X_train)\n", "X_test_ts = t.extract_features(X_test)\n", "X_test_ts.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 1.0\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 17\n", " 1 1.00 1.00 1.00 13\n", " 2 1.00 1.00 1.00 13\n", "\n", " accuracy 1.00 43\n", " macro avg 1.00 1.00 1.00 43\n", "weighted avg 1.00 1.00 1.00 43\n", "\n" ] } ], "source": [ "# Initialize the classifier\n", "clf = RandomForestClassifier(random_state=42)\n", "\n", "# Train the classifier\n", "clf.fit(X_train_ts, y_train)\n", "y_pred = clf.predict(X_test_ts)\n", "\n", "print(\"Accuracy:\", accuracy_score(y_test, y_pred))\n", "print(\"Classification Report:\\n\", classification_report(y_test, y_pred))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 2 }