"""Attention (time mixing) modules for RWKV block.
Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
""" # noqa
import math
from importlib.util import find_spec
from pathlib import Path
from typing import List, Optional, Tuple
import torch

wkv_kernel = None  # Module-level kernel handle, populated lazily by load_wkv_kernel().


class WKVLinearAttention(torch.autograd.Function):
"""WKVLinearAttention function definition."""

    @staticmethod
def forward(
ctx,
time_decay: torch.Tensor,
time_first: torch.Tensor,
key: torch.Tensor,
        value: torch.Tensor,
) -> torch.Tensor:
"""WKVLinearAttention function forward pass.
Args:
time_decay: Channel-wise time decay vector. (D_att)
time_first: Channel-wise time first vector. (D_att)
key: Key tensor. (B, U, D_att)
value: Value tensor. (B, U, D_att)
Returns:
out: Weighted Key-Value tensor. (B, U, D_att)
"""
batch, length, dim = key.size()
        assert length <= wkv_kernel.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is {wkv_kernel.context_size}. The limit should be increased."
        )
        assert batch * dim % min(dim, 32) == 0, (
            f"The product of batch size ({batch}) and dimension ({dim}) "
            f"should be a multiple of {min(dim, 32)}."
        )
ctx.input_dtype = key.dtype
time_decay = -torch.exp(time_decay.contiguous())
time_first = time_first.contiguous()
key = key.contiguous()
value = value.contiguous()
out = torch.empty_like(key, memory_format=torch.contiguous_format)
wkv_kernel.forward(time_decay, time_first, key, value, out)
ctx.save_for_backward(time_decay, time_first, key, value, out)
return out

    @staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""WKVLinearAttention function backward pass.
Args:
grad_output: Output gradient. (B, U, D_att)
Returns:
grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
grad_time_first: Gradient for channel-wise time first vector. (D_att)
grad_key: Gradient for key tensor. (B, U, D_att)
grad_value: Gradient for value tensor. (B, U, D_att)
"""
time_decay, time_first, key, value, output = ctx.saved_tensors
grad_dtype = ctx.input_dtype
batch, _, dim = key.size()
grad_time_decay = torch.empty(
(batch, dim),
memory_format=torch.contiguous_format,
dtype=time_decay.dtype,
device=time_decay.device,
)
grad_time_first = torch.empty(
(batch, dim),
memory_format=torch.contiguous_format,
dtype=time_decay.dtype,
device=time_decay.device,
)
grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
wkv_kernel.backward(
time_decay,
time_first,
key,
value,
output,
grad_output.contiguous(),
grad_time_decay,
grad_time_first,
grad_key,
grad_value,
)
grad_time_decay = torch.sum(grad_time_decay, dim=0)
grad_time_first = torch.sum(grad_time_first, dim=0)
        # Cast gradients back to the dtype recorded in the forward pass.
        return (
            grad_time_decay,
            grad_time_first,
            grad_key.to(grad_dtype),
            grad_value.to(grad_dtype),
        )
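

# Usage sketch for WKVLinearAttention (illustrative, not part of the module
# API). It assumes a CUDA device and that load_wkv_kernel() below was called
# with a context size of at least `length`; shapes follow the docstrings above.
#
#   load_wkv_kernel(context_size=64)
#   batch, length, d_att = 2, 16, 32
#   time_decay = torch.zeros(d_att, device="cuda", requires_grad=True)
#   time_first = torch.zeros(d_att, device="cuda", requires_grad=True)
#   key = torch.randn(batch, length, d_att, device="cuda", requires_grad=True)
#   value = torch.randn(batch, length, d_att, device="cuda", requires_grad=True)
#   out = WKVLinearAttention.apply(time_decay, time_first, key, value)
#   out.sum().backward()  # exercises the custom backward pass above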


def load_wkv_kernel(context_size: int) -> None:
"""Load WKV CUDA kernel.
Args:
context_size: Context size.
"""
from torch.utils.cpp_extension import load
global wkv_kernel
if wkv_kernel is not None and wkv_kernel.context_size == context_size:
return
if find_spec("ninja") is None:
raise ImportError(
"Ninja package was not found. WKV kernel module can't be loaded "
"for training. Please, 'pip install ninja' in your environment."
)
if not torch.cuda.is_available():
raise ImportError(
"CUDA is currently a requirement for WKV kernel loading. "
"Please set your devices properly and launch again."
)
kernel_folder = Path(__file__).resolve().parent / "cuda"
kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
kernel_cflags = [
"-t 4",
"-std=c++17",
"-res-usage",
"--maxrregcount 60",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-DTmax={context_size}",
]
wkv_kernel = load(
name=f"wkv_{context_size}",
sources=kernel_files,
verbose=False,
extra_cuda_cflags=kernel_cflags,
)
wkv_kernel.context_size = context_size
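
# Compilation behavior sketch: the context size is baked into the kernel
# through -DTmax at compile time, so a call with a new context size triggers a
# rebuild, while repeated calls with the same size return immediately.
#
#   load_wkv_kernel(512)   # compiles and loads wkv_512 (needs ninja + CUDA)
#   load_wkv_kernel(512)   # no-op, kernel already loaded for this size
#   load_wkv_kernel(1024)  # recompiles with -DTmax=1024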


class SelfAttention(torch.nn.Module):
"""SelfAttention module definition.
Args:
size: Input/Output size.
attention_size: Attention hidden size.
context_size: Context size for WKV kernel.
block_id: Block index.
num_blocks: Number of blocks in the architecture.
"""
def __init__(
self,
size: int,
attention_size: int,
context_size: int,
block_id: int,
num_blocks: int,
) -> None:
"""Construct a SelfAttention object."""
super().__init__()
load_wkv_kernel(context_size)
        # Pad one zero step at the front of the time axis and drop the last
        # step: time_shift(x)[:, t] == x[:, t - 1], with zeros at t == 0.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
self.time_first = torch.nn.Parameter(torch.empty(attention_size))
self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
self.proj_output = torch.nn.Linear(attention_size, size, bias=True)
self.block_id = block_id
self.reset_parameters(size, attention_size, block_id, num_blocks)

    def reset_parameters(
self, size: int, attention_size: int, block_id: int, num_blocks: int
) -> None:
"""Reset module parameters.
Args:
size: Block size.
attention_size: Attention hidden size.
block_id: Block index.
num_blocks: Number of blocks in the architecture.
"""
        ratio_0_to_1 = block_id / max(num_blocks - 1, 1)  # 0 (first block) to 1 (last)
ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
time_weight = torch.ones(1, 1, size)
for i in range(size):
time_weight[0, 0, i] = i / size
decay_speed = [
-5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
for h in range(attention_size)
]
decay_speed = torch.tensor(
decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
)
zigzag = (
torch.tensor(
[(i + 1) % 3 - 1 for i in range(attention_size)],
dtype=self.time_first.dtype,
device=self.time_first.device,
)
* 0.5
)
with torch.no_grad():
self.time_decay.data = decay_speed
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )
self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
self.time_mix_value.data = (
torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
)
self.time_mix_receptance.data = torch.pow(
time_weight, 0.5 * ratio_1_to_almost0
)
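
    # Initialization sketch (plain Python, for illustration): decay_speed
    # always ramps from -5 (first channel) to 3 (last channel); ratio_0_to_1
    # only changes the curvature of the ramp between those endpoints.
    #
    #   r = 0.0  # ratio_0_to_1 for the first block
    #   speeds = [-5 + 8 * (h / 511) ** (0.7 + 1.3 * r) for h in range(512)]
    #   speeds[0], speeds[-1]  # -> (-5.0, 3.0)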

    @torch.no_grad()
def wkv_linear_attention(
self,
time_decay: torch.Tensor,
time_first: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Compute WKV with state (i.e.: for inference).
Args:
time_decay: Channel-wise time decay vector. (D_att)
time_first: Channel-wise time first vector. (D_att)
key: Key tensor. (B, 1, D_att)
value: Value tensor. (B, 1, D_att)
state: Decoder hidden states. [3 x (B, D_att)]
Returns:
output: Weighted Key-Value. (B, 1, D_att)
state: Decoder hidden states. [3 x (B, 1, D_att)]
"""
        num_state, den_state, max_state = state

        # Numerically stable softmax-style update: track the running maximum
        # exponent (max_state) so that every exp() sees a non-positive argument.
        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        # Decay the running state by one step before folding in the current token.
        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]
return wkv, state
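
    # Single-step usage sketch (shapes per the docstring; `att`, `key_t` and
    # `value_t` are assumed defined, and the seeding values are an assumption,
    # one common choice for an empty state): the triple carries the running
    # numerator, denominator, and maximum exponent.
    #
    #   batch, d_att = 2, 32
    #   state = (
    #       torch.zeros(batch, 1, d_att),          # numerator
    #       torch.zeros(batch, 1, d_att),          # denominator
    #       torch.full((batch, 1, d_att), -1e30),  # running max, ~ -inf
    #   )
    #   wkv, state = att.wkv_linear_attention(
    #       -torch.exp(att.time_decay), att.time_first, key_t, value_t, state
    #   )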

    def forward(
self,
x: torch.Tensor,
state: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
"""Compute time mixing.
Args:
x: SelfAttention input sequences. (B, U, size)
state: Decoder hidden states. [5 x (B, 1, D_att, N)]
Returns:
x: SelfAttention output sequences. (B, U, size)
"""
shifted_x = (
self.time_shift(x) if state is None else state[1][..., self.block_id]
)
key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
receptance = x * self.time_mix_receptance + shifted_x * (
1 - self.time_mix_receptance
)
key = self.proj_key(key)
value = self.proj_value(value)
receptance = torch.sigmoid(self.proj_receptance(receptance))
if state is not None:
state[1][..., self.block_id] = x
            # Match the kernel path: the stored parameter is a log-decay, so the
            # actual per-step decay applied to the state is -exp(time_decay).
            wkv, att_state = self.wkv_linear_attention(
                -torch.exp(self.time_decay),
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )
state[2][..., self.block_id] = att_state[0]
state[3][..., self.block_id] = att_state[1]
state[4][..., self.block_id] = att_state[2]
else:
wkv = WKVLinearAttention.apply(self.time_decay, self.time_first, key, value)
x = self.proj_output(receptance * wkv)
return x, state
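

# End-to-end usage sketch (assumes a CUDA device and ninja, since constructing
# the module compiles the WKV kernel). Parallel/training mode over a full
# sequence; stateful decoding additionally needs the 5-tensor state laid out
# as in the forward() docstring.
#
#   att = SelfAttention(
#       size=256, attention_size=256, context_size=512, block_id=0, num_blocks=6
#   ).cuda()
#   x = torch.randn(2, 128, 256, device="cuda")
#   y, _ = att(x)  # state=None -> batched CUDA kernel path
#   assert y.shape == x.shape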