# Source code for Swarm_Observer.Swarm

import torch
import numpy as np
from .BirdParticle import BirdParticle
import pandas as pd
import logging

class PSO:
    """Adversarial Particle Swarm Optimization over a swarm of ``BirdParticle`` objects.

    Each particle is scored with ``cost_func`` (evaluated against ``model``).
    The swarm keeps the position with the highest cost seen so far as the
    global best, i.e. the cost is *maximized*.
    """

    def __init__(self, starting_positions: torch.Tensor, cost_func: callable, model: torch.nn.Module,
                 w: float = 1.0, c1: float = 0.8, c2: float = 0.2):
        """
        Initializes the Adversarial Particle Swarm Optimization algorithm.

        Args:
            starting_positions (torch.Tensor): The starting positions of the swarm.
                One particle is created per entry along the first dimension.
            cost_func (callable): The cost function to be maximized.
            model (torch.nn.Module): The model to be used in the cost function.
            w (float): The inertia weight.
            c1 (float): The cognitive weight.
            c2 (float): The social weight.

        Raises:
            IndexError: If ``starting_positions`` is empty (there is no particle
                to seed the global best from).
        """
        self.swarm = [BirdParticle(pos, w=w, c1=c1, c2=c2) for pos in starting_positions]
        self.cost_func = cost_func
        self.model = model
        # Seed the global best from the first particle. Take a snapshot so that
        # later updates to that particle cannot silently change the stored best.
        # NOTE(review): the seed is not evaluated here. On the first step,
        # velocities are therefore steered toward particle 0 regardless of its cost.
        self.pos_best_g = self._snapshot(self.swarm[0].position_i)
        self.cos_best_g = self.swarm[0].cost_i
        self.epoch = 0
        # NOTE(review): not read by any method of this class; kept for external callers.
        self.history = []

    @staticmethod
    def _snapshot(position):
        """Return a copy of ``position`` that does not share storage with the particle."""
        # position_i is stacked with torch.vstack in get_points, so it is expected to
        # be a tensor. Anything else is passed through unchanged.
        return position.clone() if isinstance(position, torch.Tensor) else position

    def step(self) -> tuple:
        """
        Performs one iteration of the Adversarial Particle Swarm Optimization algorithm.

        For each particle, the step does the following:

        1. Evaluate the particle at its current position.
        2. Update its velocity toward the current global best.
        3. Move it.
        4. Evaluate it again at the new position.

        Afterwards the global best is updated from the new costs, and each
        particle's new position is appended to its ``history``.

        Returns:
            tuple: ``(pos_best_g, cos_best_g)``, the global best position and its
            cost after this iteration.
        """
        self.epoch += 1
        for p in self.swarm:
            # NOTE(review): the pre-move evaluation is only strictly required on the
            # first epoch. Afterwards cost_i is already fresh from the previous
            # post-move evaluation. It is kept because BirdParticle.evaluate may have
            # side effects (e.g. personal-best bookkeeping) not visible here.
            p.evaluate(self.cost_func, self.model)
            p.update_velocity(self.pos_best_g)
            p.update_position()
            # Re-evaluate so cost_i matches the new position used below.
            p.evaluate(self.cost_func, self.model)

        for particle in self.swarm:
            # The cost is maximized. A None seed cost (in case BirdParticle starts
            # with no cost before the first evaluation) is always replaced.
            if self.cos_best_g is None or particle.cost_i > self.cos_best_g:
                self.pos_best_g = self._snapshot(particle.position_i)
                self.cos_best_g = particle.cost_i
            # Exactly one entry per particle per epoch; get_history relies on this.
            # The snapshot keeps past entries from tracking later in-place moves.
            particle.history.append(self._snapshot(particle.position_i))

        return self.pos_best_g, self.cos_best_g

    def get_points(self):
        """Return the current positions of all particles stacked with ``torch.vstack``."""
        return torch.vstack([particle.position_i for particle in self.swarm])

    def get_best(self):
        """Return the best (highest-cost) position recorded so far."""
        return self.pos_best_g

    def run(self, epochs: int):
        """Runs the Adversarial Particle Swarm Optimization algorithm for the specified number of epochs."""
        for _ in range(epochs):
            self.step()

    def get_history(self) -> pd.DataFrame:
        """
        Returns the history of the swarm's positions for each epoch.

        The frame has one row per particle (in swarm order) and one column
        ``epoch_<i>`` per completed epoch. Each cell holds the position that the
        particle recorded in that epoch.

        Raises:
            IndexError: If a particle's ``history`` has fewer than ``self.epoch``
                entries.
        """
        history = {
            f"epoch_{i}": [particle.history[i] for particle in self.swarm]
            for i in range(self.epoch)
        }
        return pd.DataFrame(history)

    def save_history(self, filename):
        """
        Saves the history of the swarm's positions for each epoch.

        The output is a CSV (without index) of ``get_history()``. Tensor cells are
        written using their text representation.
        """
        history = self.get_history()
        history.to_csv(filename, index=False)