# Hugging Face Space status (scraped page header): Running
| import polars as pl | |
| import os | |
| from tqdm.auto import tqdm | |
| import pykakasi | |
| from huggingface_hub import snapshot_download | |
| import numpy as np | |
| from string import ascii_letters | |
| from convert import ( | |
| aux_global_id_to_code, presult, | |
| team_name_short, | |
| ball_kind, ball_kind_code, general_ball_kind, general_ball_kind_code, lr, | |
| game_kind | |
| ) | |
# Files fetched from the dataset repo: the per-season tables ('*/...') and the
# player name tables that sit at the repo root.
_DATASET_PATTERNS = [
    '*/pbp_data.parquet',
    '*/pbp_text.parquet',
    '*/pbp_aux.parquet',
    '*/schedule.parquet',
    '*/aux_schedule.parquet',
    'players.parquet',
    'players_translated.parquet',
    'players_translated_manual.parquet',
]

# Download (or reuse) a local snapshot of the dataset under ./files
DATA_PATH = snapshot_download(
    repo_id='Ramos-Ramos/npb_data_app',
    repo_type='dataset',
    local_dir='./files',
    cache_dir='./.cache',
    allow_patterns=_DATASET_PATTERNS,
)
SEASONS = [2021, 2022, 2023, 2024, 2025]

# Per-season parquet files, read in this order for every season.
_SEASON_FILES = (
    'pbp_data.parquet',
    'pbp_text.parquet',
    'pbp_aux.parquet',
    'schedule.parquet',
    'aux_schedule.parquet',
)
_frames_by_file = {filename: [] for filename in _SEASON_FILES}
for _season in tqdm(SEASONS):
    for _filename, _frames in _frames_by_file.items():
        _frames.append(pl.read_parquet(os.path.join(DATA_PATH, str(_season), _filename)))

# Stack every season of each table into one frame.
data_df = pl.concat(_frames_by_file['pbp_data.parquet'])
text_df = pl.concat(_frames_by_file['pbp_text.parquet'])
# diagonal_relaxed aligns frames by column name (missing columns become null)
# and casts mismatched dtypes to a common supertype.
aux_df = pl.concat(_frames_by_file['pbp_aux.parquet'], how='diagonal_relaxed')
sched_df = pl.concat(_frames_by_file['schedule.parquet'])
aux_sched_df = pl.concat(_frames_by_file['aux_schedule.parquet'])
# Window that defines one half inning in the aux (pbp_aux) data.
_AUX_HALF_INNING = ['gameGlobalId', 'inning', 'tob']

# Join keys, laid out as "<YYYYMMDD>_<visitor>_<home>_<inning:02><tob_code>[<PA:02>]".
# 'tob_code' is '1' for tob == 'Top' and '2' otherwise.
_aux_inning_key = (
    pl.col('gameDate') + '_'
    + pl.col('visitor') + '_'
    + pl.col('home') + '_'
    + pl.col('inning').str.zfill(2) + pl.col('tob_code')
)
_aux_pa_key = _aux_inning_key + pl.col('new_pa_count').cast(pl.String).str.zfill(2)

aux_df = (
    aux_df
    .filter(pl.col('type') != 'RUNNER')
    .join(aux_sched_df[['gameGlobalId', 'gameDate']], on='gameGlobalId')
    .with_columns(
        pl.col('gameDate').str.to_date().dt.strftime('%Y%m%d'),
        pl.col('home').struct.field('globalId').replace_strict(aux_global_id_to_code).alias('home'),
        pl.col('visitor').struct.field('globalId').replace_strict(aux_global_id_to_code).alias('visitor'),
        pl.when(pl.col('tob') == 'Top').then(pl.lit('1')).otherwise(pl.lit('2')).alias('tob_code'),
    )
    # Drop records with code 98 and id 1 (presumably equivalent to keeping only
    # rows with pitch.count > 0 -- confirm).
    .filter(~((pl.col('code') == 98) & (pl.col('id') == 1)))
    # Running plate-appearance number: a new PA starts at each pitch count of 1.
    .with_columns(
        (pl.col('pitch').struct.field('count') == 1).cum_sum().over(_AUX_HALF_INNING).alias('pa_count')
    )
    # Flag whole PAs that contain one of these codes as intentional walks ('ibb').
    .with_columns(
        pl.col('code').is_in([6402, 6404, 6406, 6405]).any()
        .over([*_AUX_HALF_INNING, 'pa_count'])
        .alias('ibb')
    )
    # Renumber PAs while skipping IBB rows (those get a null PA number).
    .with_columns(
        pl.when(~pl.col('ibb')).then(pl.col('pitch').struct.field('count') == 1)
        .cum_sum()
        .over(_AUX_HALF_INNING)
        .alias('new_pa_count')
    )
    .with_columns(
        pl.len().over([*_AUX_HALF_INNING, 'new_pa_count']).alias('pa_pitches'),
        pl.max('new_pa_count').over(_AUX_HALF_INNING).alias('inning_pas'),
    )
    .with_columns(
        (_aux_pa_key + '_' + pl.col('pitch').struct.field('count').cast(pl.String)).alias('universal_code'),
        _aux_inning_key.alias('inning_code'),
        _aux_pa_key.alias('pa_code'),
    )
)
# Game-level key prefix "<YYYYMMDD>_<visitor>_<home>_" shared by the universal/inning/PA
# codes built below; 'universal_code' is the key used to join the aux_df records.
# NOTE(review): 'date' is derived from the UpdatedAt timestamp, not a scheduled game date,
# so a record updated on a different day than the game would not match -- confirm.
_PBP_GAME_PREFIX = (
    pl.col('date') + '_' + pl.col('VisitorTeamNameES') + '_' + pl.col('HomeTeamNameES') + '_'
)

data_df = (
    data_df
    .with_columns(
        *[
            pl.col(col).cast(pl.Int32)
            for col
            in ['gameId', 'ballKind', 'ballSpeed', 'x', 'y', 'presult', 'bresult', 'battedX', 'battedY']
        ],
        pl.col('UpdatedAt').str.to_datetime(),
        # fiveDigitSerialNumber = 3-character half-inning id + 2-character batter number
        pl.col('fiveDigitSerialNumber').str.slice(offset=0, length=3).alias('half_inning'),
        pl.col('fiveDigitSerialNumber').str.slice(offset=3, length=2).alias('batter'),
    )
    .with_columns(
        # pitches per PA, not counting rows whose presult is 0
        (~pl.col('presult').is_in([0])).sum().over(['gameId', 'fiveDigitSerialNumber']).alias('pa_pitches'),
        # PAs containing presult 139 are flagged as intentional walks
        pl.col('presult').is_in([139]).any().over(['gameId', 'fiveDigitSerialNumber']).alias('ibb'),
    )
    .filter(pl.col('pa_pitches') > 0)
    # Null out the batter number of IBB PAs so the dense rank below skips them.
    .with_columns(pl.when(~pl.col('ibb')).then(pl.col('batter')).alias('batter'))
    .with_columns(
        pl.when(~pl.col('ibb')).then(pl.col('batter').rank('dense'))
        .over(['gameId', 'half_inning'])
        .cast(pl.String).str.zfill(2)
        .alias('new_batter')
    )
    .with_columns(
        (pl.col('half_inning') + pl.col('new_batter')).alias('newFiveDigitSerialNumber')
    )
    .with_columns(
        pl.max('new_batter').cast(pl.Int32)
        .over(['gameId', pl.col('newFiveDigitSerialNumber').str.slice(offset=0, length=3)])
        .alias('inning_pas')
    )
    .join(
        sched_df[['GameID', 'HomeTeamNameES', 'VisitorTeamNameES']]
        .rename({'GameID': 'gameId'})
        .with_columns(
            pl.col('HomeTeamNameES').replace_strict(team_name_short).alias('home_team_name_short'),
            pl.col('VisitorTeamNameES').replace_strict(team_name_short).alias('visitor_team_name_short'),
        ),
        on='gameId',
    )
    .with_columns(pl.col('UpdatedAt').dt.strftime('%Y%m%d').alias('date'))
    .with_columns(
        # alias applied to the complete expression (previously it sat before the
        # "_<atBatBallCount>" suffix and only worked via polars' left-operand naming)
        (_PBP_GAME_PREFIX + pl.col('newFiveDigitSerialNumber') + '_' + pl.col('atBatBallCount')).alias('universal_code'),
        (_PBP_GAME_PREFIX + pl.col('newFiveDigitSerialNumber').str.slice(offset=0, length=3)).alias('inning_code'),
        (_PBP_GAME_PREFIX + pl.col('newFiveDigitSerialNumber')).alias('pa_code'),
    )
    .join(
        aux_df.filter(~pl.col('ibb'))[['universal_code', 'battingResult', 'inning_pas', 'pa_pitches', 'beforeBso', 'bso']]
        .rename({'battingResult': 'aux_bresult', 'inning_pas': 'aux_inning_pas', 'pa_pitches': 'aux_pa_pitches'}),
        on='universal_code',
        how='left',
    )
    .join(
        text_df[['GameID', 'GameKindID']].with_columns(
            pl.col('GameID').cast(pl.Int32),
            pl.col('GameKindID').cast(pl.Int32),
        ).unique(),
        how='left',
        left_on='gameId',
        right_on='GameID',
    )
    .with_columns(pl.col('GameKindID').replace_strict(game_kind).alias('GameKindName'))
    .with_columns(
        # keep the aux batting result only when both sources agree on the PA structure
        pl.when((pl.col('inning_pas') == pl.col('aux_inning_pas')) & (pl.col('pa_pitches') == pl.col('aux_pa_pitches')))
        .then('aux_bresult')
        .alias('aux_bresult'),
        # pitch location remap: x -> 100 - x, y -> 250 - y
        pl.col('x').add(-100).mul(-1),
        pl.col('y').neg().add(250),
        pl.col('presult').alias('presult_id'),
        pl.col('ballKind').replace_strict(ball_kind),
        pl.col('ballKind').replace_strict(ball_kind_code).alias('ballKind_code'),
        pl.col('ballKind').replace_strict(general_ball_kind).alias('general_ballKind'),
        pl.col('ballKind').replace_strict(general_ball_kind_code).alias('general_ballKind_code'),
        pl.col('batLR').replace_strict(lr),
        pl.col('pitLR').replace_strict(lr),
        pl.col('date').str.to_date('%Y%m%d'),
        pl.when(pl.col('GameKindName').str.contains('Regular Season') | (pl.col('GameKindName') == 'Interleague'))
        .then(pl.lit('Regular Season'))
        .when(~pl.col('GameKindName').is_in(['Spring Training', 'All-Star Game']))
        .then(pl.lit('Postseason'))
        .otherwise('GameKindName')
        .alias('coarse_game_kind'),
        # half_inning ending in '1' -> home team pitching; ending in '2' -> home team batting.
        # String suffixes: str.ends_with expects a string, not an int literal.
        pl.when(pl.col('half_inning').str.ends_with('1')).then('HomeTeamNameES').otherwise('VisitorTeamNameES').alias('pitcher_team'),
        pl.when(pl.col('half_inning').str.ends_with('1')).then('home_team_name_short').otherwise('visitor_team_name_short').alias('pitcher_team_name_short'),
        pl.when(pl.col('half_inning').str.ends_with('2')).then('HomeTeamNameES').otherwise('VisitorTeamNameES').alias('batter_team'),
        pl.when(pl.col('half_inning').str.ends_with('2')).then('home_team_name_short').otherwise('visitor_team_name_short').alias('batter_team_name_short'),
    )
    .with_columns(
        pl.col('presult_id').replace_strict(presult).alias('presult')
    )
    .with_columns(
        pl.col('presult').is_in(['None', 'Balk', 'Batter interference', 'Catcher interference', 'Pitcher delay', 'Intentional walk', 'Unknown']).not_().alias('pitch'),
        pl.col('presult').is_in(['Swinging strike', 'Swinging strikeout']).alias('whiff'),
    )
    .with_columns(
        (pl.col('pitch') & pl.col('presult').is_in(['Hit by pitch', 'Sacrifice bunt', 'Sacrifice fly', 'Looking strike', 'Ball', 'Walk', 'Looking strikeout', 'Sacrifice hit error', 'Sacrifice fly error', "Sacrifice fielder's choice", 'Bunt strikeout']).not_()).alias('swing'),
        (pl.col('whiff') | pl.col('presult').is_in(['Looking strike', 'Uncaught third strike', 'Looking strikeout'])).alias('csw'),
    )
    # location buckets on the remapped x/y coordinates
    .with_columns((pl.col('x').is_between(-60, 60) & pl.col('y').is_between(50, 50+150)).alias('zone'))
    .with_columns((pl.col('x').is_between(-40, 40) & pl.col('y').is_between(75, 75+100)).alias('heart'))
    .with_columns((pl.col('x').is_between(-80, 80) & pl.col('y').is_between(25, 25+200) & ~pl.col('heart')).alias('shadow'))
    .with_columns((pl.col('x').is_between(-100, 101) & pl.col('y').is_between(0, 0+251) & ~pl.col('heart') & ~pl.col('shadow')).alias('chase'))
    .filter(pl.col('ballKind_code') != '-')
    .unique()
)
_ASCII_LETTERS = frozenset(ascii_letters)


def select_name(names):
    '''
    Pick one display name out of several candidate names for the same player.

    Prioritizes the first name containing ASCII letters (ex. R. マルティネス > マルティネス),
    otherwise the shortest name, keeping the earliest one on ties (ex. 大勢 > 翁田 大勢).
    Names with ASCII letters help differentiate between foreign players,
    while shorter names are more accurate for players who go by shorter names.

    Args:
        names: sequence of name strings (a list, or the per-group Series that
            ``map_elements`` passes in a group_by aggregation).

    Returns:
        The selected name.

    Raises:
        ValueError: if ``names`` is empty.
    '''
    for name in names:
        if not _ASCII_LETTERS.isdisjoint(name):
            return name
    # min() returns the first minimal element, matching the previous argmin behaviour
    return min(names, key=len)
# load player dfs
# All player tables are read from the same downloaded snapshot as the season data
# (DATA_PATH) instead of a hard-coded relative 'files/' path.

# players.parquet may hold several names per playerId; collapse them to one.
players_df = (
    pl.read_parquet(os.path.join(DATA_PATH, 'players.parquet'))
    .with_columns(pl.col('playerName').str.normalize('NFKC').str.replace_all('・', ' '))
    # one name per player id, chosen by select_name's preference order
    .group_by('playerId').agg(pl.col('playerName').map_elements(select_name, return_dtype=pl.String))
)
# columns used later in this file: name_jp, name_kana, name_en, team, season
translated_df = (
    pl.read_parquet(os.path.join(DATA_PATH, 'players_translated.parquet'))
    .with_columns(pl.col('name_jp').str.normalize('NFKC').str.replace_all('・', ' '))
)
# manual translations (per the file name); only playerId and name_en are used later
manual_translated_df = pl.read_parquet(os.path.join(DATA_PATH, 'players_translated_manual.parquet'))
def _player_team_seasons(id_col, team_col):
    '''
    Distinct (playerId, team, season) rows for one role in data_df.

    id_col / team_col select the role, e.g. ('batId', 'batter_team');
    season is the calendar year of each row's date.
    '''
    return (
        data_df
        .with_columns(pl.col('date').dt.year().alias('season'))
        .unique([id_col, team_col, 'season'])
        .select(id_col, team_col, 'season')
        .rename({id_col: 'playerId', team_col: 'team'})
    )


# get seasons and teams per player id, for both batters and pitchers
batter_df = _player_team_seasons('batId', 'batter_team')
pitcher_df = _player_team_seasons('pitId', 'pitcher_team')
# inner join: only players that appear in data_df are kept
players_df = players_df.join(pl.concat((pitcher_df, batter_df)).unique(), on='playerId')
kks = pykakasi.kakasi()


def _kana_to_romaji(name):
    '''Romanize a kana string with pykakasi, capitalizing and concatenating each Hepburn segment.'''
    return ''.join([word['hepburn'].capitalize() for word in kks.convert(name)])


translated_df = (
    translated_df
    # Drop any parenthesized suffix from the Japanese name: keep the text before the
    # last '(' with trailing spaces / '(' stripped, then replace '・' with a space.
    .with_columns(
        pl.when(pl.col('name_jp').str.contains(r'\('))
        .then(pl.col('name_jp').str.extract(r'.*\(', 0).str.strip_chars_end(' ('))
        .otherwise(pl.col('name_jp'))
        .str.replace_all('・', ' ')
        .alias('name_jp')
    )
    .with_columns(pl.col('name_kana').str.normalize('NFKC').str.replace_all('・', ' '))
    # split the kana reading into its parenthesized part and the part before it
    .with_columns(pl.col('name_kana').str.extract(r'\(.*\)', 0).str.strip_chars('()').alias('in_parentheses'))
    .with_columns(pl.col('name_kana').str.extract(r'.*\(', 0).str.strip_chars_end('(').alias('before_parentheses'))
    # Fill in missing English names.
    .with_columns(
        pl.when(pl.col('name_en').is_null())
        .then(
            pl.when(pl.col('in_parentheses').is_not_null() | pl.col('before_parentheses').is_not_null())
            .then(
                # take the name in parentheses when it contains an ASCII letter
                # (native regex check instead of a per-row Python lambda)
                pl.when(pl.col('in_parentheses').str.contains(r'[A-Za-z]'))
                .then(pl.col('in_parentheses'))
                # NOTE(review): this fallback is taken from name_kana and is not
                # romanized here -- confirm it is meant to be used as name_en
                .otherwise(pl.col('before_parentheses'))
            )
            # names with no romanization are approximated by romanizing the kana reading
            .otherwise(pl.col('name_kana').map_elements(_kana_to_romaji, return_dtype=pl.String))
        )
        .otherwise(pl.col('name_en'))
        .alias('name_en')
    )
    # title-case names given entirely in upper case, and remove commas
    .with_columns(
        pl.when(pl.col('name_en') == pl.col('name_en').str.to_uppercase())
        .then(pl.col('name_en').str.to_titlecase())
        .otherwise('name_en')
        .str.replace_all(',', '')
    )
)
# Handle inconsistent kanji between sources: for player names that have no exact
# match in translated_df, swap a character for the variant form listed here.
# NOTE(review): swaps are cumulative across pairs, and str.replace (not
# replace_all) changes only the first occurrence -- confirm this is intended.
_KANJI_VARIANTS = (
    ('崎', '﨑'),
    ('高', '髙'),
    ('徳', '德'),
    ('濱', '濵'),
    ('瀬', '瀨'),
)
_translated_jp_names = translated_df['name_jp']
for old_char, new_char in _KANJI_VARIANTS:
    _has_no_match = ~pl.col('playerName').is_in(_translated_jp_names)
    players_df = players_df.with_columns(
        pl.when(_has_no_match)
        .then(pl.col('playerName').str.replace(old_char, new_char))
        .otherwise('playerName')
    )
# merge player dfs: attach English names by (Japanese name, season, team)
players_df = (
    players_df
    .join(
        translated_df
        .with_columns(
            # a dotted name (e.g. "R.マルティネス") unknown to players_df is matched by its
            # Japanese part: leading/trailing ASCII letters and dots are stripped
            pl.when(pl.col('name_jp').str.contains(r'\.') & ~pl.col('name_jp').is_in(players_df['playerName'].implode()))
            .then(pl.col('name_jp').str.strip_chars(ascii_letters + '.'))
            .otherwise('name_jp')
        )
        [['name_jp', 'name_en', 'team', 'season']],
        left_on=['playerName', 'season', 'team'],
        right_on=['name_jp', 'season', 'team'],
    )
)
# diagnostic: (player, team, season) combinations that matched several translations
print(players_df.filter(pl.len().over('playerId', 'team', 'season') > 1))

# One English name per player. Manual translations are placed first so that they take
# precedence over automatically matched names. Previously both rows were kept when
# they disagreed, and the duplicate playerIds multiplied rows in the later joins.
players_df = pl.concat((
    manual_translated_df[['playerId', 'name_en']],
    players_df.group_by('playerId').agg(pl.first('name_en')),
))
# diagnostic: players with conflicting English names (manual one wins below)
print(players_df.unique().filter(pl.len().over('playerId') > 1).sort('playerId'))
players_df = players_df.unique('playerId', keep='first', maintain_order=True)
def _name_lookup(name_col):
    '''Two-column (playerId, <name_col>) table of English player names.'''
    return players_df.select('playerId', pl.col('name_en').alias(name_col))


# join players to data: English names for the pitcher (pitId) and the batter (batId)
data_df = (
    data_df
    .join(_name_lookup('pitcher_name'), left_on='pitId', right_on='playerId', how='left')
    .join(_name_lookup('batter_name'), left_on='batId', right_on='playerId', how='left')
)
# When run as a script, pause in the debugger after all module-level loading above
# has finished, so the dataframes can be inspected interactively.
if __name__ == '__main__':
    breakpoint()