Spaces:
Runtime error
Runtime error
| from datasets import load_dataset | |
| from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator | |
| from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig | |
| import matplotlib | |
| matplotlib.use('TKAgg') | |
| import joblib | |
| import os | |
# Location of the joblib pickle that stores previously generated results.
cache_file = "cached_data.pkl"

# Start from an empty cache and, when the cache file is present, replace it
# with the file's contents.
cache_dict = {}
if os.path.exists(cache_file):
    # Load through ``cache_file`` rather than a repeated string literal so the
    # existence check and the load can never point at different paths.
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code; only load a cache file that this script wrote itself.
    cache_dict = joblib.load(cache_file)
class MeSHAgeLabels(AgeLabels):
    """Age-group labels handed to ``AgeConfig(labels=...)`` for the ``age`` module.

    Each value matches the suffix of an ``age.<value>`` key in the
    ``feature_names`` mapping of the ``medmcqa`` entry in ``configs``.
    """

    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"
# Age disaggregation module configured with the MeSH label set and bound to
# the "question" column.
age = Age(
    config=AgeConfig(
        labels=MeSHAgeLabels,
        # NOTE(review): this is a one-element list wrapping the full list of
        # labels -- confirm that AgeConfig expects this nesting.
        ages=[list(MeSHAgeLabels)],
        # Eight breakpoints, the same count as the labels above (presumably
        # the lower age bound of each group -- TODO confirm against AgeConfig).
        breakpoints=[0, 2, 5, 12, 18, 44, 64, 79],
    ),
    column="question",
)
class TabsSpacesLabels(DisaggregationModuleLabels):
    """Labels emitted by the ``TabsSpaces`` disaggregator."""

    TABS = "tabs"
    SPACES = "spaces"
class TabsSpaces(CustomDisaggregator):
    """Custom disaggregator that flags whether a row's text contains a tab."""

    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        """Label ``row`` as TABS or SPACES.

        TABS is true when ``row[self.column]`` contains at least one tab
        character; SPACES is always the negation of TABS, so exactly one of
        the two labels is true for every row.
        """
        has_tab = "\t" in row[self.column]
        return {self.labels.TABS: has_tab, self.labels.SPACES: not has_tab}
class ReactComponentLabels(DisaggregationModuleLabels):
    """Labels emitted by the ``ReactComponent`` disaggregator."""

    CLASS = "class"
    FUNCTION = "function"
class ReactComponent(CustomDisaggregator):
    """Custom disaggregator that looks for class-based React component syntax."""

    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        """Label ``row`` as CLASS or FUNCTION.

        CLASS is true when ``row[self.column]`` contains either class-component
        marker below. FUNCTION is the negation of CLASS, so any text without a
        marker is labelled FUNCTION.
        """
        source_text = row[self.column]
        class_markers = ("extends React.Component", "extends Component")
        is_class = any(marker in source_text for marker in class_markers)
        return {self.labels.CLASS: is_class, self.labels.FUNCTION: not is_class}
# Per-dataset settings. The keys of every entry mirror the parameter names of
# generate_cached_data(), so an entry can be passed to it with ``**``.
# ``feature_names`` maps "<module>.<label>" keys (and parent-level module keys)
# to display names.
configs = {
    # One string-named module ("continent") applied to the "TEXT" column.
    "laion": {
        "disaggregation_modules": ["continent"],
        "dataset_name": "society-ethics/laion2B-en_continents",
        "column": "TEXT",
        "feature_names": {
            "continent.africa": "Africa",
            "continent.americas": "Americas",
            "continent.asia": "Asia",
            "continent.europe": "Europe",
            "continent.oceania": "Oceania",
            # Parent level
            "continent": "Continent",
        },
    },
    # The configured ``age`` module instance plus the string-named "gender"
    # module, applied to the "question" column.
    "medmcqa": {
        "disaggregation_modules": [age, "gender"],
        "dataset_name": "society-ethics/medmcqa_age_gender_custom",
        "column": "question",
        "feature_names": {
            "age.infant": "Infant",
            "age.child_preschool": "Preschool",
            "age.child": "Child",
            "age.adolescent": "Adolescent",
            "age.adult": "Adult",
            "age.middle_aged": "Middle Aged",
            "age.aged": "Aged",
            "age.aged_80_over": "Aged 80+",
            "gender.male": "Male",
            "gender.female": "Female",
            # Parent level
            "gender": "Gender",
            "age": "Age",
            "Both": "Age + Gender",
        },
    },
    # Two custom disaggregator classes (passed as classes, not instances),
    # applied to the "content" column.
    "stack": {
        "disaggregation_modules": [TabsSpaces, ReactComponent],
        "dataset_name": "society-ethics/the-stack-tabs_spaces",
        "column": "content",
        "feature_names": {
            "react_component.class": "Class",
            "react_component.function": "Function",
            "tabs_spaces.tabs": "Tabs",
            "tabs_spaces.spaces": "Spaces",
            # Parent level
            "react_component": "React Component Syntax",
            "tabs_spaces": "Tabs vs. Spaces",
            "Both": "React Component Syntax + Tabs vs. Spaces",
        },
    },
}
def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names):
    """Build the cacheable summary for one dataset.

    Args:
        disaggregation_modules: Modules handed to ``Disaggregator``.
        dataset_name: Dataset identifier handed to ``load_dataset``; the
            "train" split is loaded.
        column: Column name passed to ``Disaggregator`` and echoed in the
            result.
        feature_names: Mapping echoed unchanged in the result.

    Returns:
        A dict with the keys:
        ``fields`` (the disaggregator's fields plus the string ``"None"``),
        ``data_fields`` (the disaggregator's fields as-is),
        ``distributions`` (``value_counts()`` over the field columns, taken
        in sorted order),
        ``disaggregators`` (the ``name`` of every module),
        ``column`` and ``feature_names``.
    """
    disaggregator = Disaggregator(disaggregation_modules, column=column)
    frame = load_dataset(dataset_name, split="train").to_pandas()

    data_fields = disaggregator.fields
    # Sort the column names so the layout of the counts is deterministic.
    ordered_columns = sorted(data_fields)
    module_names = [mod.name for mod in disaggregator.modules]

    return {
        "fields": {*data_fields, "None"},
        "data_fields": data_fields,
        "distributions": frame[ordered_columns].value_counts(),
        "disaggregators": module_names,
        "column": column,
        "feature_names": feature_names,
    }
# Regenerate the summary for each of the three datasets (in this order),
# overwrite the matching cache entries in one step, then persist the merged
# cache to disk.
cache_dict.update(
    {
        dataset_key: generate_cached_data(**configs[dataset_key])
        for dataset_key in ("laion", "medmcqa", "stack")
    }
)
joblib.dump(cache_dict, cache_file)