from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField
from pyspark.sql.types import DoubleType, IntegerType, StringType
from pyspark.sql.functions import lit, avg, sum, col, when, collect_list
from pyspark import StorageLevel
from pyspark.sql.functions import monotonically_increasing_id, row_number
from pyspark.sql.window import Window
import sys

spark = SparkSession.builder.appName("Data Skew").getOrCreate()
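# Optional (a sketch, not part of the original demo): this job deliberately creates skew,
# so adaptive skew handling is left at the cluster defaults. If you instead wanted
# Spark 3.x AQE to mitigate skewed joins automatically, these settings exist:
#spark.conf.set("spark.sql.adaptive.enabled", "true")
#spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")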

# Create schemas 
lineSchema = StructType([
  StructField("l_orderkey", StringType()),
  StructField("l_partkey", IntegerType()),
  StructField("l_suppkey", IntegerType()),
  StructField("l_linenumber", IntegerType()),
  StructField("l_quantity", IntegerType()),
  StructField("l_extendedprice", IntegerType()),
  StructField("l_discount", DoubleType()),
  StructField("l_tax", DoubleType()),
  StructField("l_returnflag", StringType()),
  StructField("l_linestatus", StringType()),
  StructField("l_shipdate", StringType()),
  StructField("l_commitdate", StringType()),
  StructField("l_receiptdate", StringType()),
  StructField("l_shipinstruct", StringType()),
  StructField("l_shipmode", StringType()),
  StructField("l_comment", StringType())
])

partSchema = StructType([
  StructField("p_partkey", IntegerType()),
  StructField("p_name", StringType()),
  StructField("p_mfgr", StringType()),
  StructField("p_brand", StringType()),
  StructField("p_type", StringType()),
  StructField("p_size", IntegerType()),
  StructField("p_container", StringType()),
  StructField("p_retailprice", DoubleType()),
  StructField("p_comment", DoubleType())
])

# Create Dataframes
lineitem_df = spark.read.schema(lineSchema).option("delimiter","|").csv("s3://redshift-downloads/TPC-H/2.18/3TB/lineitem/")
part_df = spark.read.schema(partSchema).option("delimiter","|").csv("s3://redshift-downloads/TPC-H/2.18/3TB/part/")

# Cache a small dataframe (60 GB)
part_df.cache()
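# Note: cache() is lazy; the data is only materialized when an action runs.
# Uncomment the following optional action to populate the cache up front:
#part_df.count()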

# Add an ID column (monotonically_increasing_id yields unique but non-consecutive IDs)
window_df = lineitem_df.withColumn("l_lineid", monotonically_increasing_id())
# Alternative: consecutive IDs via row_number(), at the cost of moving all rows into a single partition for the unpartitioned window sort
#lineid_df = lineitem_df.withColumn("id", monotonically_increasing_id())
#windowSpec = Window.orderBy("id")
#window_df = lineid_df.withColumn("l_lineid", row_number().over(windowSpec)).drop("id")

# Introduce data skew based on the ID column: every row with l_lineid <= 2080000000000 gets the same part key (22)
no_df = window_df.withColumn("l_partkey", when(col("l_lineid") <= 2080000000000, lit(22)).otherwise(col("l_partkey")))

# Cache a large dataframe (2.2 TB)
no_df.cache()
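# Note: DataFrame.cache() uses the MEMORY_AND_DISK storage level, so partitions of
# this ~2.2 TB dataframe that do not fit in executor memory spill to local disk.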

# Verify the skew in data
#no_df.createOrReplaceTempView("no_df")
#spark.sql("select count(*),l_partkey from no_df group by l_partkey order by 1 desc").show()
#spark.sql("select count(*) from no_df where l_partkey <> 22").show()

# Join the tables on skewed keys
#join_df = no_df.join(part_df, no_df.l_partkey == part_df.p_partkey)
#print("DataFrame size: ", join_df.rdd.map(lambda x: sys.getsizeof(x)).reduce(lambda x, y: x + y) / (1024**3), "GB")

# Collect list aggregate (collect_list of l_comment per l_orderkey) on the skewed dataframe
coll_df = no_df.groupBy("l_orderkey").agg(collect_list("l_comment")).orderBy("l_orderkey")
print("DataFrame size: ", coll_df.rdd.map(lambda x: sys.getsizeof(x)).reduce(lambda x, y: x + y) / (1024**3), "GB")
