{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression with interpreTS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we show how you can use interpreTS for regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:interpreTS:scikit-learn is not installed. Please install it to use interpreTS.\n"
     ]
    }
   ],
   "source": [
    "import urllib.request as urllib2\n",
    "from io import BytesIO\n",
    "from zipfile import ZipFile\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import interpreTS as it"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading in the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [ "timestamp\n", "0 days 01:11:00 1\n", "0 days 01:24:00 1\n", "5 days 00:27:00 1\n", "Name: count, dtype: int64" ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zip_url: str = \"https://archive.ics.uci.edu/ml/machine-learning-databases/00235/household_power_consumption.zip\"\n",
    "zipped_file_name: str = \"household_power_consumption.txt\"\n",
    "\n",
    "# Download the archive into memory; the context manager closes the HTTP\n",
    "# response as soon as the payload has been read.\n",
    "with urllib2.urlopen(zip_url) as response:\n",
    "    zip_bytes = BytesIO(response.read())\n",
    "\n",
    "# Date and Time are read as plain strings so the timestamp can be built below\n",
    "# with an explicit format; \"nan\" and \"?\" are treated as missing values.\n",
    "with ZipFile(zip_bytes) as archive, archive.open(zipped_file_name) as raw_file:\n",
    "    df_raw: pd.DataFrame = pd.read_csv(\n",
    "        raw_file,\n",
    "        sep=\";\",\n",
    "        low_memory=False,\n",
    "        na_values=[\"nan\", \"?\"],\n",
    "        dtype={\"Date\": str, \"Time\": str},\n",
    "    )\n",
    "\n",
    "# The dates are day-first (dd/mm/yyyy), so the format is spelled out instead of\n",
    "# relying on the deprecated `infer_datetime_format` / month-first default.\n",
    "df_raw[\"timestamp\"] = pd.to_datetime(\n",
    "    df_raw[\"Date\"] + \" \" + df_raw[\"Time\"], format=\"%d/%m/%Y %H:%M:%S\"\n",
    ")\n",
    "\n",
    "# Index by timestamp, store the measurements as float32 and drop incomplete rows.\n",
    "df_power_consumption: pd.DataFrame = (\n",
    "    df_raw.drop(columns=[\"Date\", \"Time\"])\n",
    "    .set_index(\"timestamp\")\n",
    "    .astype(\"float32\")\n",
    "    .dropna()\n",
    ")\n",
    "\n",
    "# Show a few of the gaps between consecutive timestamps\n",
    "df_power_consumption.index.to_series().diff().value_counts().sample(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of Global_active_power over a 15-minute time-based window ending at each row.\n",
    "# Calling .mean() directly avoids the FutureWarning raised by aggregate(np.nanmean).\n",
    "df_power_consumption[\"avg_15min_GAP\"] = (\n",
    "    df_power_consumption[\"Global_active_power\"].rolling(\"15min\").mean()\n",
    ")"
   ]
  },
  { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Global_active_powerGlobal_reactive_powerVoltageGlobal_intensitySub_metering_1Sub_metering_2Sub_metering_3avg_15min_GAPyearmonth
adjusted_monthtimestamp
02006-12-16 17:24:004.2160.418234.83999618.40.01.017.04.216000200612
2006-12-16 17:25:005.3600.436233.63000523.00.01.016.04.788000200612
2006-12-16 17:26:005.3740.498233.28999323.00.02.017.04.983333200612
2006-12-16 17:27:005.3880.502233.74000523.00.01.017.05.084500200612
2006-12-16 17:28:003.6660.528235.67999315.80.01.017.04.800800200612
....................................
382010-02-05 04:24:000.3400.076245.9799961.40.01.00.00.34400020102
2010-02-05 04:25:000.3440.076245.8899991.60.01.01.00.34426720102
2010-02-05 04:26:000.3440.074245.6600041.60.01.01.00.34373320102
2010-02-05 04:27:000.3440.076246.1900021.60.01.00.00.34400020102
2010-02-05 04:28:000.4200.162246.7400052.00.01.01.00.34933320102
\n", "

1639424 rows × 10 columns

\n", "
" ], "text/plain": [ " Global_active_power \\\n", "adjusted_month timestamp \n", "0 2006-12-16 17:24:00 4.216 \n", " 2006-12-16 17:25:00 5.360 \n", " 2006-12-16 17:26:00 5.374 \n", " 2006-12-16 17:27:00 5.388 \n", " 2006-12-16 17:28:00 3.666 \n", "... ... \n", "38 2010-02-05 04:24:00 0.340 \n", " 2010-02-05 04:25:00 0.344 \n", " 2010-02-05 04:26:00 0.344 \n", " 2010-02-05 04:27:00 0.344 \n", " 2010-02-05 04:28:00 0.420 \n", "\n", " Global_reactive_power Voltage \\\n", "adjusted_month timestamp \n", "0 2006-12-16 17:24:00 0.418 234.839996 \n", " 2006-12-16 17:25:00 0.436 233.630005 \n", " 2006-12-16 17:26:00 0.498 233.289993 \n", " 2006-12-16 17:27:00 0.502 233.740005 \n", " 2006-12-16 17:28:00 0.528 235.679993 \n", "... ... ... \n", "38 2010-02-05 04:24:00 0.076 245.979996 \n", " 2010-02-05 04:25:00 0.076 245.889999 \n", " 2010-02-05 04:26:00 0.074 245.660004 \n", " 2010-02-05 04:27:00 0.076 246.190002 \n", " 2010-02-05 04:28:00 0.162 246.740005 \n", "\n", " Global_intensity Sub_metering_1 \\\n", "adjusted_month timestamp \n", "0 2006-12-16 17:24:00 18.4 0.0 \n", " 2006-12-16 17:25:00 23.0 0.0 \n", " 2006-12-16 17:26:00 23.0 0.0 \n", " 2006-12-16 17:27:00 23.0 0.0 \n", " 2006-12-16 17:28:00 15.8 0.0 \n", "... ... ... \n", "38 2010-02-05 04:24:00 1.4 0.0 \n", " 2010-02-05 04:25:00 1.6 0.0 \n", " 2010-02-05 04:26:00 1.6 0.0 \n", " 2010-02-05 04:27:00 1.6 0.0 \n", " 2010-02-05 04:28:00 2.0 0.0 \n", "\n", " Sub_metering_2 Sub_metering_3 \\\n", "adjusted_month timestamp \n", "0 2006-12-16 17:24:00 1.0 17.0 \n", " 2006-12-16 17:25:00 1.0 16.0 \n", " 2006-12-16 17:26:00 2.0 17.0 \n", " 2006-12-16 17:27:00 1.0 17.0 \n", " 2006-12-16 17:28:00 1.0 17.0 \n", "... ... ... 
\n", "38 2010-02-05 04:24:00 1.0 0.0 \n", " 2010-02-05 04:25:00 1.0 1.0 \n", " 2010-02-05 04:26:00 1.0 1.0 \n", " 2010-02-05 04:27:00 1.0 0.0 \n", " 2010-02-05 04:28:00 1.0 1.0 \n", "\n", " avg_15min_GAP year month \n", "adjusted_month timestamp \n", "0 2006-12-16 17:24:00 4.216000 2006 12 \n", " 2006-12-16 17:25:00 4.788000 2006 12 \n", " 2006-12-16 17:26:00 4.983333 2006 12 \n", " 2006-12-16 17:27:00 5.084500 2006 12 \n", " 2006-12-16 17:28:00 4.800800 2006 12 \n", "... ... ... ... \n", "38 2010-02-05 04:24:00 0.344000 2010 2 \n", " 2010-02-05 04:25:00 0.344267 2010 2 \n", " 2010-02-05 04:26:00 0.343733 2010 2 \n", " 2010-02-05 04:27:00 0.344000 2010 2 \n", " 2010-02-05 04:28:00 0.349333 2010 2 \n", "\n", "[1639424 rows x 10 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "train_columns = [f\"Sub_metering_{i}\" for i in range(1, 4)] + [\"timestamp\"]\n",
    "target_col = \"avg_15min_GAP\"\n",
    "\n",
    "# The percentage of data used for testing\n",
    "test_pct = 0.2\n",
    "day_margin = 3\n",
    "\n",
    "# add the timestamp col\n",
    "df_power_consumption[\"timestamp\"] = df_power_consumption.index\n",
    "\n",
    "# Ensure timestamp is in datetime format\n",
    "df_power_consumption[\"timestamp\"] = pd.to_datetime(df_power_consumption[\"timestamp\"])\n",
    "\n",
    "# Add 'year' and 'month' columns\n",
    "df_power_consumption[\"year\"] = df_power_consumption[\"timestamp\"].dt.year\n",
    "df_power_consumption[\"month\"] = df_power_consumption[\"timestamp\"].dt.month\n",
    "\n",
    "# Add 'adjusted_month': months counted from December 2006\n",
    "# (Dec 2006 -> 0, Jan 2007 -> 1, ...), used below as the group id\n",
    "df_power_consumption[\"adjusted_month\"] = (\n",
    "    (df_power_consumption[\"year\"] - 2007) * 12 + df_power_consumption[\"month\"]\n",
    ")\n",
    "\n",
    "# Temporal split: train on everything except the last `test_pct` of the rows;\n",
    "# the test set starts `day_margin` days after the last training timestamp,\n",
    "# leaving a gap between the two sets.\n",
    "n_test_rows = int(len(df_power_consumption) * test_pct)\n",
    "df_train = df_power_consumption[:-n_test_rows].copy()\n",
    "df_test = df_power_consumption[df_train.index[-1] + pd.Timedelta(days=day_margin):].copy()\n",
    "\n",
    "# Index both splits by (adjusted_month, timestamp) and sort them\n",
    "df_train = df_train.set_index([\"adjusted_month\", \"timestamp\"]).sort_index()\n",
    "df_test = df_test.set_index([\"adjusted_month\", \"timestamp\"]).sort_index()\n",
    "\n",
    "# Display the training data (pandas abbreviates the middle rows)\n",
    "df_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "((1639424,), (1639424, 9), (405537,), (405537, 9))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Reset the index to remove the MultiIndex\n",
    "df_test_reshaped = df_test.reset_index()\n",
    "df_train_reshaped = df_train.reset_index()\n",
    "\n",
    "# Per-row targets\n",
    "train_y = df_train_reshaped[target_col]\n",
    "test_y = df_test_reshaped[target_col]\n",
    "\n",
    "# Monthly targets: one value per adjusted_month, the mean of the target column.\n",
    "# Select the target *before* grouping -- averaging the whole frame would return\n",
    "# every column and turn the regression into an unintended multi-output problem.\n",
    "train_y_monthly = df_train[target_col].groupby(level=\"adjusted_month\").mean()\n",
    "test_y_monthly = df_test[target_col].groupby(level=\"adjusted_month\").mean()\n",
    "\n",
    "# Remove the target and the raw time columns from the frames passed on as inputs\n",
    "dropped_columns = [\"timestamp\", \"month\", target_col]\n",
    "df_test_reshaped = df_test_reshaped.drop(columns=dropped_columns)\n",
    "df_train_reshaped = df_train_reshaped.drop(columns=dropped_columns)\n",
    "train_y.shape, df_train_reshaped.shape, test_y.shape, df_test_reshaped.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feature extraction with interpreTS"
   ]
  },
  { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
length_Global_active_powerlength_Global_reactive_powerlength_Voltagelength_Global_intensitylength_Sub_metering_1length_Sub_metering_2length_Sub_metering_3length_yearmean_Global_active_powermean_Global_reactive_power...spikeness_Sub_metering_3spikeness_yearseasonality_strength_Global_active_powerseasonality_strength_Global_reactive_powerseasonality_strength_Voltageseasonality_strength_Global_intensityseasonality_strength_Sub_metering_1seasonality_strength_Sub_metering_2seasonality_strength_Sub_metering_3seasonality_strength_year
021992219922199221992219922199221992219921.9012950.131386...0.3215710.00.8834800.8521070.9588300.8912630.8205140.9304660.9802020.0
144638446384463844638446384463844638446381.5460340.132676...0.3258680.00.9018370.8613840.9547320.9027530.8231610.9315760.9807580.0
240318403184031840318403184031840318403181.4010840.113637...0.4886570.00.9407500.8669330.9409740.9419040.7861040.9261320.9827620.0
344639446394463944639446394463944639446391.3186270.114747...0.5366070.00.9451530.8718130.9432820.9443590.8019390.9302330.9795830.0
439477394773947739477394773947739477394770.8911890.118778...1.0071430.00.9294970.8746440.9636380.9278530.8024030.8966150.9789740.0
\n", "

5 rows × 56 columns

\n", "
" ], "text/plain": [ " length_Global_active_power length_Global_reactive_power length_Voltage \\\n", "0 21992 21992 21992 \n", "1 44638 44638 44638 \n", "2 40318 40318 40318 \n", "3 44639 44639 44639 \n", "4 39477 39477 39477 \n", "\n", " length_Global_intensity length_Sub_metering_1 length_Sub_metering_2 \\\n", "0 21992 21992 21992 \n", "1 44638 44638 44638 \n", "2 40318 40318 40318 \n", "3 44639 44639 44639 \n", "4 39477 39477 39477 \n", "\n", " length_Sub_metering_3 length_year mean_Global_active_power \\\n", "0 21992 21992 1.901295 \n", "1 44638 44638 1.546034 \n", "2 40318 40318 1.401084 \n", "3 44639 44639 1.318627 \n", "4 39477 39477 0.891189 \n", "\n", " mean_Global_reactive_power ... spikeness_Sub_metering_3 spikeness_year \\\n", "0 0.131386 ... 0.321571 0.0 \n", "1 0.132676 ... 0.325868 0.0 \n", "2 0.113637 ... 0.488657 0.0 \n", "3 0.114747 ... 0.536607 0.0 \n", "4 0.118778 ... 1.007143 0.0 \n", "\n", " seasonality_strength_Global_active_power \\\n", "0 0.883480 \n", "1 0.901837 \n", "2 0.940750 \n", "3 0.945153 \n", "4 0.929497 \n", "\n", " seasonality_strength_Global_reactive_power seasonality_strength_Voltage \\\n", "0 0.852107 0.958830 \n", "1 0.861384 0.954732 \n", "2 0.866933 0.940974 \n", "3 0.871813 0.943282 \n", "4 0.874644 0.963638 \n", "\n", " seasonality_strength_Global_intensity seasonality_strength_Sub_metering_1 \\\n", "0 0.891263 0.820514 \n", "1 0.902753 0.823161 \n", "2 0.941904 0.786104 \n", "3 0.944359 0.801939 \n", "4 0.927853 0.802403 \n", "\n", " seasonality_strength_Sub_metering_2 seasonality_strength_Sub_metering_3 \\\n", "0 0.930466 0.980202 \n", "1 0.931576 0.980758 \n", "2 0.926132 0.982762 \n", "3 0.930233 0.979583 \n", "4 0.896615 0.978974 \n", "\n", " seasonality_strength_year \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "\n", "[5 rows x 56 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "extractor = 
it.FeatureExtractor(id_column=\"adjusted_month\")\n",
    "\n",
    "# Run the same extractor on the training split first, then on the test split\n",
    "features_train, features_test = (\n",
    "    extractor.extract_features(split)\n",
    "    for split in (df_train_reshaped, df_test_reshaped)\n",
    ")\n",
    "features_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using interpreTS for regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost as xgb\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.8527\n"
     ]
    }
   ],
   "source": [
    "# Fit a gradient-boosted regressor on the extracted training features\n",
    "gb_regressor = xgb.XGBRegressor(random_state=42)\n",
    "gb_regressor.fit(features_train, train_y_monthly)\n",
    "\n",
    "# Predict on the held-out features\n",
    "y_pred = gb_regressor.predict(features_test)\n",
    "\n",
    "# RMSE is the square root of the mean squared error\n",
    "mse = mean_squared_error(test_y_monthly, y_pred)\n",
    "rmse = np.sqrt(mse)\n",
    "print(f\"RMSE: {rmse:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}