Swarm_Observer package

Submodules

Swarm_Observer.BirdParticle module

class Swarm_Observer.BirdParticle.BirdParticle(position, w=1.0, c1=0.8, c2=0.2, name=None)

Bases: object

evaluate(costFunc: callable, model: Module)

Evaluates the current fitness of the particle.

Parameters:
  • costFunc (callable) – The cost function to be maximized.

  • model (torch.nn.Module) – The model to be used in the cost function.
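A minimal sketch of what evaluate might do, assuming it scores the current position with costFunc and keeps a personal best under maximization. ParticleSketch and the call convention costFunc(model, position) are hypothetical (the argument order is not documented here), and plain lists stand in for torch.Tensor so the sketch runs without PyTorch.

```python
from typing import Callable, List, Optional, Sequence


class ParticleSketch:
    """Hypothetical stand-in for BirdParticle; plain lists replace torch.Tensor."""

    def __init__(self, position: Sequence[float]):
        self.position: List[float] = list(position)
        self.pos_best: List[float] = list(position)   # personal best position
        self.score_best: Optional[float] = None        # no evaluation yet

    def evaluate(self, cost_func: Callable, model: Callable) -> float:
        # Assumed call convention: cost_func(model, position) -> float.
        score = cost_func(model, self.position)
        # The cost is maximized, so a strictly higher score replaces the personal best.
        if self.score_best is None or score > self.score_best:
            self.score_best = score
            self.pos_best = list(self.position)  # copy, so later moves don't alias it
        return score
```

With a toy model such as model = lambda p: sum(p), a cost like lambda m, p: -abs(m(p) - 1.0) is highest when the model output equals 1.0, so the personal best settles on positions whose coordinates sum to 1.0.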

get_history()

Returns the history of the particle’s positions.

update_position()

Updates the particle position based on its velocity.
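get_history and update_position can be sketched together, assuming the particle appends a copy of its position to a history list on every move and that the history also contains the starting position. PositionTracker is a hypothetical name; with torch tensors the update would simply be position + velocity, written here element-wise on lists.

```python
from typing import List, Sequence


class PositionTracker:
    """Hypothetical sketch of BirdParticle's position bookkeeping (lists, not tensors)."""

    def __init__(self, position: Sequence[float], velocity: Sequence[float]):
        self.position = list(position)
        self.velocity = list(velocity)
        self.history: List[List[float]] = [list(self.position)]  # assumed: includes the start

    def update_position(self) -> None:
        # x <- x + v, element-wise; with torch this would be self.position + self.velocity.
        self.position = [x + v for x, v in zip(self.position, self.velocity)]
        self.history.append(list(self.position))  # store a copy, not a reference

    def get_history(self) -> List[List[float]]:
        # Return copies so callers cannot mutate the stored trajectory.
        return [list(p) for p in self.history]
```

Storing copies matters when positions are mutable objects: appending the same list (or tensor) each epoch would leave every history entry pointing at the final position.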

update_velocity(pos_best_g: Tensor)

Updates the particle velocity based on its own position and the global best position.

Parameters:

pos_best_g (torch.Tensor) – The global best position.
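The docstring does not spell out the update rule. Because the constructor takes w, c1 and c2, a plausible reading is the standard global-best PSO rule, v <- w*v + c1*r1*(p_i - x) + c2*r2*(g - x), where p_i is the particle's own best position, g is pos_best_g, and r1, r2 are uniform random numbers in [0, 1). The sketch below implements that rule as an assumption, not a copy of the package's code; its extra arguments (position, pos_best_i, rng) are hypothetical, and the defaults mirror the documented w=1.0, c1=0.8, c2=0.2.

```python
import random
from typing import List, Optional, Sequence


def update_velocity(velocity: Sequence[float], position: Sequence[float],
                    pos_best_i: Sequence[float], pos_best_g: Sequence[float],
                    w: float = 1.0, c1: float = 0.8, c2: float = 0.2,
                    rng: Optional[random.Random] = None) -> List[float]:
    """Standard global-best PSO velocity update (hypothetical standalone form)."""
    rng = rng if rng is not None else random.Random()
    new_velocity = []
    for v, x, p_i, g in zip(velocity, position, pos_best_i, pos_best_g):
        r1, r2 = rng.random(), rng.random()  # fresh U[0, 1) draws per dimension
        inertia = w * v
        cognitive = c1 * r1 * (p_i - x)  # pull toward the particle's own best
        social = c2 * r2 * (g - x)       # pull toward the swarm's global best
        new_velocity.append(inertia + cognitive + social)
    return new_velocity
```

When the particle already sits on both best positions, the two pull terms vanish and only the inertia term w*v remains.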

Swarm_Observer.Swarm module

class Swarm_Observer.Swarm.PSO(starting_positions: Tensor, cost_func: callable, model: Module, w: float = 1.0, c1: float = 0.8, c2: float = 0.2)

Bases: object

get_best()

get_history() → DataFrame

Returns the history of the swarm’s positions for each epoch.

get_points()

run(epochs: int)

Runs the Adversarial Particle Swarm Optimization algorithm for the specified number of epochs.

save_history(filename)

Saves the history of the swarm’s positions for each epoch to filename.
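The DataFrame layout returned by get_history and the on-disk format of save_history are not documented here. As a hypothetical illustration, the sketch below flattens a per-epoch list of particle positions into long-format rows (one row per epoch and particle, one column per coordinate), the kind of table such a DataFrame could hold, and writes it with the standard csv module. history_rows, save_history_csv and the column names are assumptions; with pandas, pandas.DataFrame(rows).to_csv(filename, index=False) would produce a comparable file.

```python
import csv
from typing import Dict, List, Sequence, Union

Row = Dict[str, Union[int, float]]


def history_rows(history: Sequence[Sequence[Sequence[float]]]) -> List[Row]:
    """Flatten history[epoch][particle] -> position into long-format rows."""
    rows: List[Row] = []
    for epoch, positions in enumerate(history):
        for particle, pos in enumerate(positions):
            row: Row = {"epoch": epoch, "particle": particle}
            for d, value in enumerate(pos):
                row[f"x{d}"] = value  # one column per coordinate: x0, x1, ...
            rows.append(row)
    return rows


def save_history_csv(history: Sequence[Sequence[Sequence[float]]], filename: str) -> None:
    """Hypothetical save_history: write the long-format rows as CSV with a header."""
    rows = history_rows(history)
    fieldnames = list(rows[0].keys()) if rows else ["epoch", "particle"]
    with open(filename, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
```

A long format keeps one schema regardless of swarm size, which makes it easy to filter a single particle's trajectory or group by epoch when plotting.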

step() → tuple

Performs one iteration of the Adversarial Particle Swarm Optimization algorithm.
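A compact, dependency-free sketch of how step and run could fit together, assuming the usual global-best PSO loop with the cost maximized (our reading of "adversarial" here, matching the "to be maximized" wording of BirdParticle.evaluate). PSOSketch, its seed argument, the cost_func(model, position) call convention and the (best_position, best_score) return value of step are hypothetical; the documented step only promises a tuple. Plain lists stand in for torch.Tensor, and model is any callable.

```python
import random
from typing import Callable, List, Sequence, Tuple


class PSOSketch:
    """Hypothetical, dependency-free stand-in for Swarm_Observer.Swarm.PSO."""

    def __init__(self, starting_positions: Sequence[Sequence[float]],
                 cost_func: Callable, model: Callable,
                 w: float = 1.0, c1: float = 0.8, c2: float = 0.2, seed: int = 0):
        self.positions = [list(p) for p in starting_positions]
        self.velocities = [[0.0] * len(p) for p in self.positions]  # start at rest
        self.cost_func, self.model = cost_func, model
        self.w, self.c1, self.c2 = w, c1, c2
        self.rng = random.Random(seed)  # seeded for reproducible runs
        # Personal bests start at the initial positions.
        self.pbest = [list(p) for p in self.positions]
        self.pbest_score = [cost_func(model, p) for p in self.positions]
        best = max(range(len(self.positions)), key=self.pbest_score.__getitem__)
        self.gbest, self.gbest_score = list(self.pbest[best]), self.pbest_score[best]
        self.history = [[list(p) for p in self.positions]]  # entry 0 = start

    def step(self) -> Tuple[List[float], float]:
        """One iteration: move each particle in turn and refresh the bests after each move."""
        for i, x in enumerate(self.positions):
            v = self.velocities[i]
            for d in range(len(x)):
                r1, r2 = self.rng.random(), self.rng.random()
                v[d] = (self.w * v[d]
                        + self.c1 * r1 * (self.pbest[i][d] - x[d])
                        + self.c2 * r2 * (self.gbest[d] - x[d]))
                x[d] += v[d]
            score = self.cost_func(self.model, x)
            if score > self.pbest_score[i]:  # maximization
                self.pbest[i], self.pbest_score[i] = list(x), score
                if score > self.gbest_score:  # asynchronous global-best update
                    self.gbest, self.gbest_score = list(x), score
        self.history.append([list(p) for p in self.positions])
        return list(self.gbest), self.gbest_score

    def run(self, epochs: int) -> Tuple[List[float], float]:
        """Call step() `epochs` times and return the final global best."""
        result = (list(self.gbest), self.gbest_score)
        for _ in range(epochs):
            result = self.step()
        return result
```

This sketch updates the global best as soon as any particle improves on it (an asynchronous variant) and applies no velocity clamping; with w = 1.0 the swarm can keep oscillating, so inertia values below 1 are common in practice.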

Module contents