New paste Repaste Download
$ cat train.py
import marimo
# Version string marimo records for the release that generated this notebook file.
__generated_with = "0.14.7"
# The notebook application; every `@app.cell` function below registers a cell on it.
app = marimo.App(width="medium")
@app.cell
def _():
    # Expose the marimo module as `mo` to the other cells (used there for progress bars).
    import marimo as mo
    return (mo,)
@app.cell
def _():
    import torch
    import torch.nn as nn
    from transformers import ElectraModel

    class SpamUserClassifier(nn.Module):
        """Binary bot/spam comment classifier built on a pretrained ELECTRA encoder.

        The final hidden state of the first ([CLS]) token is fed through an
        MLP head: hidden_size -> 256 -> 64 -> 2, with ReLU and dropout.
        """

        def __init__(self, pretrained_model_name="beomi/kcelectra-base"):
            super().__init__()
            # NOTE(review): the training cell's tokenizer loads
            # "beomi/KcELECTRA-base" (different casing); confirm both names
            # resolve to the same Hub checkpoint.
            self.encoder = ElectraModel.from_pretrained(pretrained_model_name)
            # Take the encoder width from its config instead of hard-coding
            # 768, so checkpoints with a different hidden size also fit.
            hidden_size = self.encoder.config.hidden_size
            # classification head
            self.dense1 = nn.Linear(hidden_size, 256)
            self.relu = nn.ReLU()
            self.dropout1 = nn.Dropout(0.3)
            self.dense2 = nn.Linear(256, 64)
            self.dropout2 = nn.Dropout(0.2)
            self.output_layer = nn.Linear(64, 2)
            self.softmax = nn.Softmax(dim=1)

        def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                    *, return_logits=False):
            """Classify a batch of token sequences.

            Args:
                input_ids: token ids passed straight to the ELECTRA encoder.
                attention_mask: optional encoder attention mask.
                token_type_ids: optional encoder segment ids.
                return_logits: when True, return the raw pre-softmax scores
                    instead of probabilities. Logits can be paired with
                    ``F.cross_entropy``, which never evaluates log(0).

            Returns:
                A [batch, 2] tensor: softmax probabilities over dim 1 by
                default, or logits when ``return_logits`` is True.
            """
            outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # Hidden state of the first ([CLS]) token: [batch, hidden_size].
            cls_output = outputs.last_hidden_state[:, 0, :]
            # classification head
            x = self.dense1(cls_output)
            x = self.relu(x)
            x = self.dropout1(x)
            x = self.dense2(x)
            x = self.relu(x)
            x = self.dropout2(x)
            logits = self.output_layer(x)
            if return_logits:
                return logits
            return self.softmax(logits)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SpamUserClassifier().to(device)
    return device, model, torch
@app.cell
def _():
    from datasets import load_dataset

    # Hub dataset of YouTube comments; the training cell reads its "content"
    # and "is_bot_comment" columns.
    dataset_name = "misilelab/youtube-bot-comments"

    def _split_as_polars(split):
        # Load one split, switch it to the polars output format, and slice
        # out every row (`[:]`) so callers get the whole split at once.
        return load_dataset(dataset_name, split=split).with_format("polars")[:]

    # Splits are loaded in the order train -> validation -> test.
    train_dataset, valid_dataset, test_dataset = (
        _split_as_polars(split_name)
        for split_name in ("train", "validation", "test")
    )
    return test_dataset, train_dataset, valid_dataset
@app.cell
def _(device, mo, model, test_dataset, torch, train_dataset, valid_dataset):
    # Fine-tune `model` on the bot-comment data with accuracy-based early
    # stopping, print the test accuracy of the best-validation checkpoint,
    # and build `final_chart` (per-epoch train/validation loss) for display.
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader, Dataset
    import torch.nn.functional as F
    from torch.optim import AdamW
    import altair as alt
    import polars as pl

    # prepare tokenizer
    tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")

    # Dataset wrapper: reads the "content" text column and the
    # "is_bot_comment" label column, and tokenizes each item on access.
    class YTBotDataset(Dataset):
        def __init__(self, ds, tokenizer, max_length=128):
            self.texts = ds["content"].to_list()
            self.labels = [int(x) for x in ds["is_bot_comment"].to_list()]
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            encoding = self.tokenizer(
                self.texts[idx],
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            # return_tensors="pt" adds a leading batch dim of size 1; drop it
            # so the DataLoader can stack items itself.
            item = {k: v.squeeze(0) for k, v in encoding.items()}
            item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
            return item

    # create datasets and loaders
    train_ds = YTBotDataset(train_dataset, tokenizer)
    valid_ds = YTBotDataset(valid_dataset, tokenizer)
    test_ds = YTBotDataset(test_dataset, tokenizer)
    batch_size = 128
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    # optimizer
    optimizer = AdamW(model.parameters(), lr=2e-5)

    def _nll_from_probs(probs, labels):
        # The model returns softmax probabilities, so the loss is NLL on their
        # log. Clamp to the smallest positive normal float first: a
        # probability that underflows to exactly 0 would make log() return
        # -inf and turn the loss (and every gradient) into NaN.
        log_probs = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))
        return F.nll_loss(log_probs, labels)

    def _train_one_epoch(epoch):
        # One optimisation pass over train_loader; returns the mean
        # per-batch training loss.
        model.train()
        running_loss = 0.0
        for batch in mo.status.progress_bar(
            train_loader,
            subtitle=f"Training Epoch {epoch}",
            show_eta=True,
            show_rate=True,
            remove_on_exit=True,
        ):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()
            probs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = _nll_from_probs(probs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        return running_loss / len(train_loader)

    def _evaluate(loader, **bar_kwargs):
        # Gradient-free pass over `loader` in eval mode; returns
        # (accuracy, mean per-batch loss). bar_kwargs (title/subtitle/
        # remove_on_exit) are forwarded to the progress bar.
        model.eval()
        correct, total, loss_sum = 0, 0, 0.0
        with torch.no_grad():
            for batch in mo.status.progress_bar(
                loader, show_eta=True, show_rate=True, **bar_kwargs
            ):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                probs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss_sum += _nll_from_probs(probs, labels).item()
                correct += (probs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        return correct / total, loss_sum / len(loader)

    # training setup
    num_epochs = 100
    patience = 5
    best_valid_acc = 0.0
    no_improve_epochs = 0
    # Weights from the best-validation epoch, copied to CPU so the snapshot
    # does not double GPU memory; restored before the final test.
    best_state_dict = None
    # Initialize training history
    training_history = {
        'epochs': [],
        'train_losses': [],
        'valid_losses': []
    }
    # create a top-level progress bar for all epochs
    for epoch in (progress_bar := mo.status.progress_bar(range(1, num_epochs + 1), show_eta=True, show_rate=True)):
        progress_bar.completion_title = f"epoch {epoch}"
        avg_train_loss = _train_one_epoch(epoch)
        valid_acc, avg_valid_loss = _evaluate(
            valid_loader,
            subtitle=f"Validating Epoch {epoch}",
            remove_on_exit=True,
        )
        # Store the measured per-epoch losses for the chart below.
        training_history['epochs'].append(epoch)
        training_history['train_losses'].append(avg_train_loss)
        training_history['valid_losses'].append(avg_valid_loss)
        # early stopping on validation accuracy
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            no_improve_epochs = 0
            best_state_dict = {
                k: v.detach().to("cpu", copy=True)
                for k, v in model.state_dict().items()
            }
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                break

    # final test evaluation on the best checkpoint, not the last epoch's
    # weights (which trained `patience` epochs past the best one).
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    test_acc, _test_loss = _evaluate(test_loader, title="Testing")
    print(test_acc)

    # Create the final training chart (None if no epoch was recorded, so the
    # return below never hits an unbound name).
    final_chart = None
    if training_history['epochs']:
        epochs = training_history['epochs']
        train_losses = training_history['train_losses']
        valid_losses = training_history['valid_losses']

        # Long format: one row per (epoch, loss type) so altair colours by type.
        _df = pl.DataFrame({
            'epoch': epochs * 2,
            'loss': train_losses + valid_losses,
            'type': ['Train Loss'] * len(train_losses) + ['Validation Loss'] * len(valid_losses)
        })

        final_chart = alt.Chart(_df).mark_line(point=True).encode(
            x=alt.X('epoch:Q', title='Epoch'),
            y=alt.Y('loss:Q', title='Loss'),
            color=alt.Color('type:N',
                           scale=alt.Scale(domain=['Train Loss', 'Validation Loss'],
                                           range=['firebrick', 'royalblue'])),
            tooltip=['epoch:Q', 'loss:Q', 'type:N']
        ).properties(
            title='Training and Validation Loss Over Epochs',
            width=700,
            height=400
        ).interactive()
    return (final_chart,)
@app.cell
def _(final_chart):
    # Show the loss chart produced by the training cell. It is a bare
    # expression, presumably so marimo renders it as this cell's output.
    final_chart
    return
if __name__ == "__main__":
    # Run the notebook's cells when the file is executed directly as a script.
    app.run()
Filename: $ cat train.py. Size: 8 KB. View raw, hex, or download this file.

This paste expires on 2025-07-03 05:51:25.828184. Pasted through v1-api.