
Commit

seed once
micmelesse committed Aug 27, 2024
1 parent 9d7398d commit 57ea314
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions tests/test_flash_attn.py
@@ -21,6 +21,25 @@
 
 DEBUG = False
 
+# this ensures that the same config will always get the same result
+REPRODUCIBLE=True
+if REPRODUCIBLE:
+    skip_seed = 42
+else:
+    skip_seed = time.time()
+random.seed(skip_seed)
+
+if DEBUG:
+    print("skip_seed:", skip_seed)
+
+def skip_config(*args, skip_pct = 0.95):
+    return random.random() >= (1.0 - skip_pct)
+
+def is_amd():
+    if torch.version.hip is not None:
+        return True
+    return False
+
 
 MAX_HEADDIM_SM8x = 192
 

@@ -29,27 +48,6 @@
 is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
 is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
 
-def is_amd():
-    if torch.version.hip is not None:
-        return True
-    return False
-
-def skip_config(*args, reproducible=True, skip_pct = 0.95):
-    config_str = '_'.join(map(str, args))
-
-    if reproducible:
-        # this ensures that the same config will always get the same result
-        skip_seed = config_str
-    else:
-        skip_seed = time.time()
-
-    if DEBUG:
-        print("skip_seed:", skip_seed)
-    random.seed(config_str)
-
-
-    return random.random() >= (1.0 - skip_pct)
-
 
 def attn_bias_from_alibi_slopes(
     slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
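
For context: the removed skip_config re-seeded Python's global random module with the joined config string on every call, so each configuration's skip decision was a deterministic function of its parameters alone. The new version seeds once at import time (42 when REPRODUCIBLE is True, otherwise time.time()) and draws every decision from that single stream. Below is a minimal sketch of how the helper might be used in a parametrized test; the test name, parameter values, and body are hypothetical illustrations, not part of this commit.

import random
import time

import pytest
import torch

# Seed once at import time, mirroring this commit: with REPRODUCIBLE=True,
# the same ~95% of configs are skipped on every run.
REPRODUCIBLE = True
skip_seed = 42 if REPRODUCIBLE else time.time()
random.seed(skip_seed)


def skip_config(*args, skip_pct=0.95):
    # The args no longer influence the outcome; each call just consumes
    # the next value from the seeded stream.
    return random.random() >= (1.0 - skip_pct)


# Hypothetical parametrized test in the style of tests/test_flash_attn.py.
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("seqlen", [128, 256, 512])
def test_attention_config(seqlen, dtype):
    if skip_config(seqlen, dtype):
        pytest.skip(f"randomly skipped: seqlen={seqlen}, dtype={dtype}")
    # ...actual attention test body would go here...
    assert seqlen > 0

A trade-off worth noting with the seed-once approach: because decisions come from one shared stream, adding or reordering tests shifts which configurations get skipped, whereas the per-config seeding it replaces was order-independent but clobbered the global random state on every call.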
