$ cat train.py
import marimo

__generated_with = "0.14.7"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    import torch
    import torch.nn as nn
    from transformers import ElectraModel

    class SpamUserClassifier(nn.Module):
        def __init__(self, pretrained_model_name="beomi/KcELECTRA-base"):
            super().__init__()
            self.encoder = ElectraModel.from_pretrained(pretrained_model_name)

            # classification head
            self.dense1 = nn.Linear(768, 256)
            self.relu = nn.ReLU()
            self.dropout1 = nn.Dropout(0.3)
            self.dense2 = nn.Linear(256, 64)
            self.dropout2 = nn.Dropout(0.2)
            self.output_layer = nn.Linear(64, 2)
            self.softmax = nn.Softmax(dim=1)

        def forward(self, input_ids, attention_mask=None, token_type_ids=None):
            # extract the KcELECTRA [CLS] token representation
            outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            cls_output = outputs.last_hidden_state[:, 0, :]  # [batch, 768]

            # classification head
            x = self.dense1(cls_output)
            x = self.relu(x)
            x = self.dropout1(x)
            x = self.dense2(x)
            x = self.relu(x)
            x = self.dropout2(x)
            logits = self.output_layer(x)
            probs = self.softmax(logits)
            return probs

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SpamUserClassifier().to(device)
    return device, model, torch


@app.cell
def _():
    from datasets import load_dataset

    dataset_name = "misilelab/youtube-bot-comments"
    train_dataset = load_dataset(dataset_name, split="train").with_format("polars")[:]
    valid_dataset = load_dataset(dataset_name, split="validation").with_format("polars")[:]
    test_dataset = load_dataset(dataset_name, split="test").with_format("polars")[:]
    return test_dataset, train_dataset, valid_dataset


@app.cell
def _(device, mo, model, test_dataset, torch, train_dataset, valid_dataset):
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader, Dataset
    import torch.nn.functional as F
    from torch.optim import AdamW
    import altair as alt
    import polars as pl

    # prepare tokenizer
    tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")

    # dataset wrapper
    class YTBotDataset(Dataset):
        def __init__(self, ds, tokenizer, max_length=128):
            self.texts = ds["content"].to_list()
            self.labels = [int(x) for x in ds["is_bot_comment"].to_list()]
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            text = self.texts[idx]
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            item = {k: v.squeeze(0) for k, v in encoding.items()}
            item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
            return item

    # create datasets and loaders
    train_ds = YTBotDataset(train_dataset, tokenizer)
    valid_ds = YTBotDataset(valid_dataset, tokenizer)
    test_ds = YTBotDataset(test_dataset, tokenizer)

    batch_size = 128
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    # optimizer
    optimizer = AdamW(model.parameters(), lr=2e-5)

    # training setup
    num_epochs = 100
    patience = 5
    best_valid_acc = 0.0
    no_improve_epochs = 0

    # Initialize training history
    training_history = {
        'epochs': [],
        'train_losses': [],
        'valid_losses': []
    }

    # create a top-level progress bar for all epochs
    for epoch in (progress_bar := mo.status.progress_bar(range(1, num_epochs + 1), show_eta=True, show_rate=True)):
        # training
        progress_bar.completion_title = f"epoch {epoch}"
        model.train()
        running_loss = 0.0
        for i, batch in enumerate(mo.status.progress_bar(
            train_loader, subtitle=f"Training Epoch {epoch}", show_eta=True, show_rate=True, remove_on_exit=True
        )):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            probs = model(input_ids=input_ids, attention_mask=attention_mask)
            # NLL on log-probabilities; equivalent to cross-entropy on the logits
            loss = F.nll_loss(torch.log(probs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        # validation
        model.eval()
        correct, total = 0, 0
        valid_running_loss = 0.0
        for i, batch in enumerate(mo.status.progress_bar(
            valid_loader, subtitle=f"Validating Epoch {epoch}", show_eta=True, show_rate=True, remove_on_exit=True
        )):
            with torch.no_grad():
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                probs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = F.nll_loss(torch.log(probs), labels)
                valid_running_loss += loss.item()

                preds = probs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        valid_acc = correct / total
        avg_valid_loss = valid_running_loss / len(valid_loader)

        # Store training history
        training_history['epochs'].append(epoch)
        training_history['train_losses'].append(avg_train_loss)
        training_history['valid_losses'].append(avg_valid_loss)

        # early stopping check
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                break

    # final test evaluation
    model.eval()
    correct, total = 0, 0
    for i, batch in enumerate(mo.status.progress_bar(
        test_loader, title="Testing", show_eta=True, show_rate=True
    )):
        with torch.no_grad():
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            probs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = probs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    test_acc = correct / total
    print(test_acc)

    # Create and display final training chart
    if training_history['epochs']:
        epochs = training_history['epochs']
        train_losses = training_history['train_losses']
        valid_losses = training_history['valid_losses']

        _df = pl.DataFrame({
            'epoch': epochs * 2,
            'loss': train_losses + valid_losses,
            'type': ['Train Loss'] * len(train_losses) + ['Validation Loss'] * len(valid_losses)
        })

        final_chart = alt.Chart(_df).mark_line(point=True).encode(
            x=alt.X('epoch:Q', title='Epoch'),
            y=alt.Y('loss:Q', title='Loss'),
            color=alt.Color('type:N', scale=alt.Scale(domain=['Train Loss', 'Validation Loss'], range=['firebrick', 'royalblue'])),
            tooltip=['epoch:Q', 'loss:Q', 'type:N']
        ).properties(
            title='Training and Validation Loss Over Epochs',
            width=700,
            height=400
        ).interactive()
    return (final_chart,)


@app.cell
def _(final_chart):
    final_chart
    return


if __name__ == "__main__":
    app.run()