GAT-DQN

Graph Attention Deep Q-Network (GDQN)

This directory contains the implementation of the Graph Attention Deep Q-Network (GDQN) agent and related utilities.

Overview

The GDQN module provides a robust agent that selects actions in a graph environment based on Q-values; it is a variant of the Deep Q-Network built on a Graph Attention Network (GAT). The architecture is shown below [1].

(Figure: GAT-based Q-network architecture; see [1])

GAT Layer

  • Input - A set of node features (pos, group)

    • $h = \{h_1, h_2, \dots, h_n\}$
  • Output - A new set of node features

    • $h' = \{h_1', h_2', \dots, h_n'\}$
  • Weight matrix $W$ is applied to every node

    • h = tf.matmul(input, W)
  • The index $j$ allows every node to attend to every other node

  • Compute the attention coefficient $e$

    • $e_{ij} = a(Wh_i, Wh_j)$
  • Inject graph structure via masked attention: compute $e_{ij}$ only for $j$ in the neighborhood $N_i$

# LeakyReLU on the raw attention logits
e = tf.nn.leaky_relu(e)
# Mask out non-neighbors so the softmax assigns them ~zero weight
adjacency_expanded = tf.cast(adjacency, tf.float32)
adjacency_expanded = tf.expand_dims(adjacency_expanded, axis=-1)
e = tf.where(adjacency_expanded > 0, e, -1e9 * tf.ones_like(e))
  • Normalize the coefficients for easy comparison across all $j$ via a softmax over each neighborhood, $\alpha_{ij} = \mathrm{softmax}_j(e_{ij})$
ai_j = tf.nn.softmax(e, axis=2)
  • $h'$ from the attention heads is either concatenated or averaged (a full layer sketch follows this list)
# Weighted sum of neighbor features: h'_i = sum_j alpha_ij * Wh_j
h_prime = tf.reduce_sum(alpha_expanded * h_j_tiled, axis=2)
if self.concat:
    # Concatenate heads
    h_prime = tf.reshape(h_prime, (batch_size, num_nodes, self.num_heads*self.out_dim))
else:
    # Average heads
    h_prime = tf.reduce_mean(h_prime, axis=2)
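
Putting these pieces together, the following is a minimal, self-contained sketch of a multi-head GAT layer in TensorFlow. It is illustrative only: names such as GATLayer, out_dim, and num_heads are assumptions, not necessarily the identifiers used in this repository.

import tensorflow as tf

class GATLayer(tf.keras.layers.Layer):
    # Minimal multi-head graph attention layer, following [1]
    def __init__(self, out_dim, num_heads=4, concat=True):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.concat = concat

    def build(self, input_shape):
        in_dim = int(input_shape[0][-1])
        # Shared weight matrix W, one per head
        self.W = self.add_weight(name="W", shape=(self.num_heads, in_dim, self.out_dim))
        # Attention vector a, applied to the concatenation [Wh_i || Wh_j]
        self.a = self.add_weight(name="a", shape=(self.num_heads, 2 * self.out_dim, 1))

    def call(self, inputs):
        x, adjacency = inputs                      # x: (B, N, F), adjacency: (B, N, N)
        num_nodes = tf.shape(x)[1]
        h = tf.einsum("bnf,hfo->bhno", x, self.W)  # Wh per head: (B, H, N, out_dim)
        # Pairwise features [Wh_i || Wh_j] for every (i, j)
        h_i = tf.tile(h[:, :, :, None, :], [1, 1, 1, num_nodes, 1])
        h_j = tf.tile(h[:, :, None, :, :], [1, 1, num_nodes, 1, 1])
        e = tf.einsum("bhijd,hdk->bhijk", tf.concat([h_i, h_j], axis=-1), self.a)
        e = tf.nn.leaky_relu(tf.squeeze(e, axis=-1))   # e_ij = a(Wh_i, Wh_j)
        # Masked attention: keep e_ij only for j in the neighborhood N_i
        mask = tf.expand_dims(tf.cast(adjacency, tf.float32), axis=1)
        e = tf.where(mask > 0, e, -1e9 * tf.ones_like(e))
        alpha = tf.nn.softmax(e, axis=-1)              # normalize over j
        h_prime = tf.matmul(alpha, h)                  # sum_j alpha_ij * Wh_j
        if self.concat:
            # Concatenate heads: (B, N, H * out_dim)
            h_prime = tf.transpose(h_prime, [0, 2, 1, 3])
            return tf.reshape(h_prime, (tf.shape(x)[0], num_nodes, self.num_heads * self.out_dim))
        # Average heads: (B, N, out_dim)
        return tf.reduce_mean(h_prime, axis=1)

As a toy forward pass, GATLayer(out_dim=8, num_heads=4)((tf.random.normal((2, 5, 16)), tf.ones((2, 5, 5)))) returns a (2, 5, 32) tensor, since concat=True stacks the four 8-dimensional heads.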

Action Selection

  1. Initialize the environment using
env.reset()

This will provide the initial state $S_0$

  2. Then $S_0$ is passed into
agent.select_action(state)
  3. select_action() converts the state (node) into an $(X, A)$ tuple

    • $X$ - tensor of node features
    • $A$ - adjacency tensor describing each node's neighbors
  4. $(X, A)$ is then passed into the Q-network, which returns a tensor of Q-values

  5. The tensor of Q-values is mapped to node_ids, and the node with the max Q-value is selected and returned to the agent (a sketch of this path follows the list)
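
A minimal sketch of this selection path, assuming the agent exposes preprocess_state(), q_network, epsilon, and a num_actions attribute (the epsilon-greedy branch reflects the epsilon decay described under Train Step):

import numpy as np

def select_action(self, state):
    # Convert the state (node) into the (X, A) tuple
    x, adjacency = self.preprocess_state(state)
    # Epsilon-greedy exploration; epsilon is decayed during training
    if np.random.rand() < self.epsilon:
        return np.random.randint(self.num_actions)  # hypothetical attribute: number of nodes
    # Forward pass through the Q-network; add a batch dimension first
    q_values = self.q_network((x[None, ...], adjacency[None, ...]))
    q_values = np.squeeze(q_values.numpy(), axis=0)  # one Q-value per node
    # Map Q-values to node_ids and return the node with the max Q-value
    return int(np.argmax(q_values))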

Train Step

After taking an action (step), the agent receives next_state, reward, done, info, which is added to the ReplayBuffer. The state is then advanced, $s_i = s_{i+1}$, and the reward is accumulated into the total reward, $R = R + r_i$.
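
As a sketch, that interaction loop might look like the following; env, agent, and the replay-buffer add() signature are assumptions about this repo's interfaces.

state = env.reset()                                # S_0
total_reward, done = 0.0, False
while not done:
    action = agent.select_action(state)
    next_state, reward, done, info = env.step(action)
    # Store the transition, then advance: s_i = s_{i+1}, R = R + r_i
    agent.replay_buffer.add(state, action, reward, next_state, done)
    state = next_state
    total_reward += reward
    agent.train_step(batch_size=32)                # the nine steps below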

  1. Check Replay Buffer Length

    • train_step(batch) first checks that the replay buffer holds at least one batch of experiences, then takes a random batch from the buffer
  2. Sample Experiences

    • For each state in the batch, preprocess_state() is called, which returns an $(X, A)$ tuple; the batch of tuples is converted into NumPy arrays
  3. Process States

    • The target network computes the Q-values for each next state. These values are later used to calculate the target for our Bellman update
  4. Target Q-values & Bellman Update

    • For each sample in the batch the target Q-value is computed: $Q_{target} = r + \gamma \max_{a'} Q_{next}(s', a') \times (1 - done)$
    np.max(next_q.numpy(), axis=1)

    takes the maximum Q-value across all possible actions in the next state. Multiplying by $(1 - done)$ ensures that if the episode terminated, the future reward is $0$

  5. Forward Pass under GradientTape

    • The current Q-values for the batch of states are computed by passing $X$ and $A$ through the Q-network inside a tf.GradientTape context, so gradients can be recorded
  6. Extracting Q-values for Taken Actions

    • The Q-values are extracted for the taken actions. Since the network outputs Q-values for all possible actions (for each state), we need to select the Q-value corresponding to the action that was actually taken.

    • tf.range(batch_size) creates a tensor with batch indices, which are stacked with the actions to form a tensor of indices

    • tf.gather_nd extracts the Q-value at each (batch, action) pair

  7. Loss Computation

    • The mean squared error (MSE) is computed between the target Q-values and the Q-values of the actions taken.
  8. Backpropagation and Optimization

    • Gradients of the loss with respect to the Q-network's weights are computed then applied via the Adam optimizer to update the network parameters
  9. Epsilon Decay

    • The exploration rate $\epsilon$ is decayed to reduce exploration in favor of exploitation as training progresses. A consolidated sketch of these nine steps follows.
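
Attribute and method names in this sketch (replay_buffer.sample, target_network, gamma, epsilon_min, epsilon_decay) are assumptions about this implementation rather than its confirmed API.

import numpy as np
import tensorflow as tf

def train_step(self, batch_size):
    # 1. Train only once the buffer holds at least one full batch
    if len(self.replay_buffer) < batch_size:
        return
    # 2. Sample a random batch of experiences
    states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
    # 3. Preprocess every state into (X, A) and stack into NumPy arrays
    x, adj = (np.stack(t) for t in zip(*[self.preprocess_state(s) for s in states]))
    next_x, next_adj = (np.stack(t) for t in zip(*[self.preprocess_state(s) for s in next_states]))
    # 4. Bellman targets: Q_target = r + gamma * max_a' Q_next(s', a') * (1 - done)
    next_q = self.target_network((next_x, next_adj)).numpy()
    targets = (rewards + self.gamma * np.max(next_q, axis=1) * (1.0 - dones)).astype(np.float32)
    with tf.GradientTape() as tape:
        # 5. Forward pass under GradientTape so gradients are recorded
        q_values = self.q_network((x, adj))
        # 6. Gather the Q-value of the action actually taken in each sample
        indices = tf.stack([tf.range(batch_size), tf.cast(actions, tf.int32)], axis=1)
        q_taken = tf.gather_nd(q_values, indices)
        # 7. MSE between the targets and the Q-values of the taken actions
        loss = tf.reduce_mean(tf.square(targets - q_taken))
    # 8. Backpropagate and update the Q-network weights via Adam
    grads = tape.gradient(loss, self.q_network.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.q_network.trainable_variables))
    # 9. Decay epsilon toward its floor to favor exploitation over time
    self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)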

Getting Started

  1. Install Dependencies
    Ensure you have installed the required dependencies for the project. Installing in editable mode via setup.py ensures all modules can be imported.

    pip install -e .

  2. Run Training
    From the base directory, use the provided training scripts to train your GDQN agent. Example:

    python RL/GDQN/agent.py 

References

@misc{veličković2018graphattentionnetworks,
      title={Graph Attention Networks}, 
      author={Petar Veličković and Guillem Cucurull and Arantxa Casanova and Adriana Romero and Pietro Liò and Yoshua Bengio},
      year={2018},
      eprint={1710.10903},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/1710.10903}, 
}