| import copy
|
| import datetime
|
| import ephem
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| import pandas as pd
|
| from sklearn.datasets import fetch_california_housing
|
| from sklearn.preprocessing import MinMaxScaler, StandardScaler
|
| from sklearn.model_selection import train_test_split
|
| import torch
|
| from torch import nn, optim, tensor
|
| import tqdm
|
| from typing import List, Tuple
|
|
|
# Get some IMDB data
y = pd.read_csv("YTrain.csv")
X = pd.read_csv("XTrain.csv")

# Clean/Massage the data #

# Re-label `id` column
y.rename(columns={'Unnamed: 0': 'id'}, inplace=True)

# We only want a few features. Index with X[columns] rather than
# pd.DataFrame(X, columns=columns): the constructor silently adds an all-NaN
# column for any name missing from the CSV, whereas indexing raises KeyError.
# .copy() so the assignment below writes to an independent frame.
columns = ['popularity', 'release_date', 'revenue', 'runtime']
X = X[columns].copy()

# Convert release date to Unix time (seconds). errors='coerce' turns
# unparseable dates into NaT, and subtracting the epoch keeps NaT as NaN
# (so those rows are dropped below) instead of relying on an integer cast of NaT.
release = pd.to_datetime(X['release_date'], errors='coerce')
unix_epoch = pd.Timestamp('1970-01-01', tz=release.dt.tz)
X['release_date'] = (release - unix_epoch) / pd.Timedelta(seconds=1)

# Convert to numpy arrays
features = X.to_numpy(dtype=np.float64)
target = y['vote_average'].to_numpy(dtype=np.float64)

# Drop every row whose features or target are NaN/inf. StandardScaler ignores
# NaN when fitting and keeps it when transforming, so a single NaN would reach
# the network, turn the loss (and, after backward(), the weights) into NaN,
# and `mse < best_mse` would never be true -- leaving best_mse at inf and
# best_weights at None. X and y share one mask so their rows stay paired
# (assumes row i of XTrain.csv corresponds to row i of YTrain.csv -- TODO confirm).
keep = np.isfinite(features).all(axis=1) & np.isfinite(target)
print(f"Dropping {np.count_nonzero(~keep)} of {len(keep)} rows with missing values")
X = features[keep]
y = target[keep]
|
|
|
# Get some train-test splits for model evaluation: 70% train / 30% test,
# shuffled (no fixed random_state, so the split differs between runs).
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, train_size=0.7, shuffle=True
)

# Standardize features with statistics learned from the training rows only;
# the test rows are transformed with those same statistics.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Convert to float32 tensors; targets are reshaped into (N, 1) columns to line
# up with the model's single output unit.
X_train, X_test = (
    torch.tensor(arr, dtype=torch.float32) for arr in (X_train, X_test)
)
y_train, y_test = (
    torch.tensor(arr, dtype=torch.float32).reshape(-1, 1)
    for arr in (y_train, y_test)
)
|
|
|
# Define the model: a small MLP taking the 4 input features through hidden
# layers of width 12, 8 and 4 (ReLU after each) down to a single output.
layer_sizes = [4, 12, 8, 4, 1]
layers = []
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
# Drop the trailing ReLU so the output layer stays linear.
model = nn.Sequential(*layers[:-1])

# Loss Function and Optimizer: mean squared error minimized with Adam.
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
|
|
n_epochs = 100  # number of epochs to run
batch_size = 10  # size of each batch
batch_start = torch.arange(0, len(X_train), batch_size)

# Hold the best model
best_mse = np.inf  # init to infinity
best_weights = None
history = []  # test-set MSE after each epoch

for epoch in range(n_epochs):
    model.train()
    # disable=True keeps the progress bar silent; set it to False to see
    # the per-batch MSE postfix.
    with tqdm.tqdm(batch_start, unit="batch", mininterval=0, disable=True) as bar:
        bar.set_description(f"Epoch {epoch}")
        for start in bar:
            # take a batch
            X_batch = X_train[start:start + batch_size]
            y_batch = y_train[start:start + batch_size]
            # forward pass
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            # update weights
            optimizer.step()
            # report progress
            bar.set_postfix(mse=float(loss))

    # Evaluate on the held-out split at the end of each epoch. no_grad()
    # avoids building an autograd graph that is never backpropagated.
    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
        mse = float(loss_fn(y_pred, y_test))
    history.append(mse)

    # A NaN MSE never compares less than best_mse, so without this check the
    # loop would finish with best_mse == inf and best_weights None.
    if not np.isfinite(mse):
        raise FloatingPointError(
            f"Test MSE is {mse} at epoch {epoch}; check X and y for NaN/inf "
            "values (StandardScaler passes NaN through unchanged)."
        )
    if mse < best_mse:
        best_mse = mse
        best_weights = copy.deepcopy(model.state_dict())
|
|
|
# Restore the weights from the epoch with the lowest test MSE and report it.
# best_weights is still None if no epoch produced a test MSE below inf --
# e.g. NaN in the inputs made every MSE NaN, and `NaN < inf` is always False.
# Fail with an actionable message rather than load_state_dict(None)'s error.
if best_weights is None:
    raise RuntimeError(
        f"No epoch produced a finite test MSE (best_mse={best_mse}); "
        "check X and y for NaN/inf values before training."
    )
model.load_state_dict(best_weights)
print("MSE: %.2f" % best_mse)
print("RMSE: %.2f" % np.sqrt(best_mse))
plt.plot(history)  # test MSE per epoch
plt.show()

# Test out inference on (up to) 5 raw test rows, scaling them with the scaler
# fitted on the training split, and compare with the held-out targets.
model.eval()
with torch.no_grad():
    for i in range(min(5, len(X_test_raw))):
        X_sample = scaler.transform(X_test_raw[i:i + 1])
        X_sample = torch.tensor(X_sample, dtype=torch.float32)
        y_pred = model(X_sample)
        print(f"{X_test_raw[i]} -> {y_pred[0].numpy()} (expected {y_test[i].numpy()})")
|
|
|
| # EOF #
|
|
|
# ----------------------------------------------------------------------
# The data in X looks like this (a DataFrame, after release_date was
# converted to Unix seconds). Kept as comments so the file still parses.
# ----------------------------------------------------------------------
#       popularity  release_date   revenue  runtime
# 0         10.142  1.563926e+09         0    120.0
# 1         28.346  1.161994e+09  50710400    141.0
# 2         13.098  1.558656e+09         0    100.0
# 3         42.672  1.532650e+09  28646544     85.0
# 4         20.728  7.337088e+08  13609396     89.0
# ...          ...           ...       ...      ...
# 6995      13.863  1.607299e+09         0     96.0
# 6996      17.171  9.018432e+08  65705772    121.0
# 6997      11.337  1.321315e+09         0     18.0
# 6998      99.513  1.539907e+09  34934009    133.0
# 6999      13.196  1.518566e+09         0     91.0
#
# 7000 rows × 4 columns
|