Generator এবং Discriminator মডেল তৈরি

Caffe2 তে GAN (Generative Adversarial Networks) মডেল তৈরি - ক্যাফে২ (Caffe2) - Machine Learning

275

Generator এবং Discriminator মডেলগুলি সাধারণত Generative Adversarial Networks (GANs)-এ ব্যবহৃত হয়। GANs হল একটি ধরনের ডিপ লার্নিং আর্কিটেকচার যেখানে দুটি মডেল একে অপরের বিরুদ্ধে প্রতিদ্বন্দ্বিতা করে। Generator একটি সৃষ্টিকারী মডেল যা নতুন ডেটা তৈরি করে, এবং Discriminator একটি পার্সেপশন মডেল যা সৃষ্ট ডেটা এবং আসল ডেটা মধ্যে পার্থক্য নির্ধারণ করে। এই দুইটি মডেল একে অপরের বিরুদ্ধে প্রশিক্ষণ নেয়, যার ফলে জেনারেটর সময়ের সাথে আরও বাস্তবসম্মত ডেটা তৈরি করতে শিখে এবং ডিসক্রিমিনেটর আরও দক্ষ হয়ে ওঠে।

1. Generator মডেল তৈরি

Generator মডেলটি কৃত্রিম ডেটা তৈরি করে যা আসল ডেটার মতো দেখতে হতে হবে। সাধারণত, এটি একটি নিউরাল নেটওয়ার্ক যা একটি র্যান্ডম নোইজ (noise) ভেক্টর ইনপুট হিসেবে নেয় এবং তার মাধ্যমে নতুন ডেটা (যেমন ইমেজ) তৈরি করে।

PyTorch-এ Generator তৈরি:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 128),          # প্রথম লেয়ার
            nn.ReLU(True),                  # অ্যাক্টিভেশন ফাংশন
            nn.Linear(128, 256),            # দ্বিতীয় লেয়ার
            nn.ReLU(True),                  # অ্যাক্টিভেশন ফাংশন
            nn.Linear(256, 512),            # তৃতীয় লেয়ার
            nn.ReLU(True),                  # অ্যাক্টিভেশন ফাংশন
            nn.Linear(512, 1024),           # চতুর্থ লেয়ার
            nn.ReLU(True),                  # অ্যাক্টিভেশন ফাংশন
            nn.Linear(1024, img_dim),       # আউটপুট লেয়ার
            nn.Tanh()                        # আউটপুট পরিসীমা -1 থেকে 1 পর্যন্ত করতে Tanh ফাংশন
        )
        
    def forward(self, z):
        return self.gen(z)

মডেল বর্ণনা:

  • z_dim: ইনপুট noise ভেক্টরের ডাইমেনশন (এটি একটি র্যান্ডম ভেক্টর)
  • img_dim: আউটপুট ইমেজের ডাইমেনশন (যেমন 28x28x1 বা 64x64x3)

এটি Fully Connected (FC) লেয়ারগুলির একটি সিকোয়েন্সের মাধ্যমে একটি fake image তৈরি করবে যা বাস্তব ইমেজের মতো দেখাবে।


2. Discriminator মডেল তৈরি

Discriminator মডেলটি নির্ধারণ করে যে একটি ইমেজ আসল নাকি জেনারেটেড। এটি আসল এবং জেনারেটেড ইমেজের মধ্যে পার্থক্য শিখতে চেষ্টা করে এবং তাদের মধ্যে আসল/ভুয়া চিহ্নিত করে।

PyTorch-এ Discriminator তৈরি:

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 1024),        # প্রথম লেয়ার
            nn.LeakyReLU(0.2),               # LeakyReLU অ্যাক্টিভেশন ফাংশন
            nn.Linear(1024, 512),            # দ্বিতীয় লেয়ার
            nn.LeakyReLU(0.2),               # LeakyReLU অ্যাক্টিভেশন ফাংশন
            nn.Linear(512, 256),             # তৃতীয় লেয়ার
            nn.LeakyReLU(0.2),               # LeakyReLU অ্যাক্টিভেশন ফাংশন
            nn.Linear(256, 1),               # আউটপুট লেয়ার (এখানে সিগময়েড ব্যবহৃত)
            nn.Sigmoid()                     # আউটপুট 0 বা 1 হওয়ার জন্য
        )
        
    def forward(self, img):
        return self.disc(img)

মডেল বর্ণনা:

  • img_dim: ইনপুট ইমেজের ডাইমেনশন
  • LeakyReLU: Leaky ReLU অ্যাক্টিভেশন ফাংশন যেটি কিছু পরিমাণ নেতিবাচক প্রবাহ বজায় রাখে। এটি ব্যাকপ্রপাগেশন সমস্যা (যেমন, গ্রেডিয়েন্ট ভ্যানিশিং) কমায়।

এই মডেলটি একটি ইমেজ গ্রহণ করে এবং এটি আসল (1) বা জেনারেটেড (0) কিনা তা শ্রেণীবদ্ধ করে।


3. GAN মডেল প্রশিক্ষণ

Generator এবং Discriminator এর মধ্যে প্রতিদ্বন্দ্বিতা সৃষ্টি করা হয়, যার ফলে তাদের দুটি মডেল একে অপরের বিরুদ্ধে প্রশিক্ষণ নেয়। Generator চেষ্টা করে ডিসক্রিমিনেটরকে বিভ্রান্ত করতে, এবং Discriminator চেষ্টা করে সেটি চিহ্নিত করতে।

PyTorch-এ GAN Training Loop:

import torch.optim as optim

# Hyperparameters
z_dim = 100         # Noise vector size
img_dim = 784       # MNIST image size (28x28 = 784)
lr = 0.0002         # Learning rate
batch_size = 64

# Models
generator = Generator(z_dim, img_dim)
discriminator = Discriminator(img_dim)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss function
criterion = nn.BCELoss()  # Binary Cross Entropy Loss

# Training loop
for epoch in range(epochs):
    for real_images, _ in dataloader:  # MNIST dataloader or any other dataset
        batch_size = real_images.size(0)
        
        # Create labels for real and fake images
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        
        # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        optimizer_D.zero_grad()
        
        # Real images
        outputs = discriminator(real_images.view(batch_size, -1))
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()
        
        # Fake images generated by the Generator
        noise = torch.randn(batch_size, z_dim)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        
        optimizer_D.step()

        # Train Generator: maximize log(D(G(z)))
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()
        
    print(f"Epoch [{epoch}/{epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item()}, g_loss: {g_loss.item()}")

4. GAN Training Dynamics

  • Discriminator শিখছে কিভাবে আসল এবং জেনারেটেড ইমেজের মধ্যে পার্থক্য চিহ্নিত করতে হয়।
  • Generator শিখছে কিভাবে আরও ভালো এবং বাস্তবসম্মত ইমেজ তৈরি করতে হয় যা Discriminator এর কাছে আসল মনে হবে।

এটি adversarial training নামে পরিচিত, যেখানে দুইটি মডেল একে অপরের বিরুদ্ধে প্রতিযোগিতা করে এবং ট্রেনিংয়ের মাধ্যমে শক্তিশালী হয়।


সারাংশ:

  • Generator এবং Discriminator মডেলগুলি Generative Adversarial Networks (GANs) এর মূল অংশ।
  • Generator মডেলটি নতুন ডেটা তৈরি করে, এবং Discriminator মডেলটি আসল এবং জেনারেটেড ডেটার মধ্যে পার্থক্য নির্ধারণ করে।
  • দুটি মডেল একে অপরের বিরুদ্ধে প্রশিক্ষণ নিলেও, পুরো সিস্টেম মিলে আরও বাস্তবসম্মত ডেটা তৈরি করতে সক্ষম হয়।
Content added By
Promotion

Are you sure to start over?

Loading...