Generating scatter plot from a CSV file


























Assume you have the following data in a CSV file. The content looks something like this:





,Action,Comedy,Horror
1,650,819,
,76,63,
2,,462,19
,,18,96
3,652,457,18
,75,36,89


which can be interpreted as a table of the form:



     Action   Comedy   Horror
1    650      819
     76       63
2             462      19
              18       96
3    652      457      18
     75       36       89


The goal was to write a function that takes a list of genre names (as strings) and produces a scatter plot of the data, where only the second row of every index is plotted (the rows containing 76 and 63, 18 and 96, and 75, 36 and 89). The function should draw a two-dimensional or a three-dimensional scatter plot depending on how many genres are passed.





from pandas import DataFrame
from csv import reader
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def genre_scatter(lst):
    """
    Creates a scatter plot using the data from genre_scores.csv.
    :param lst: a list with names of the genres considered
    :return: saves a pdf-file to the folder Fig with the name gen_1_gen_2.pdf
             (gen_1_gen_2_gen_3.pdf for three genres)
    """
    # First we need to determine the right columns of genre_scores.
    first_row = [row for row in reader(open('genre_scores.csv', 'r'))][0]
    index = [first_row.index(x) for x in lst]

    # Get the relevant data in the form of a DataFrame.
    # Please note that the first row of data for every index is not necessary for this task.
    data = DataFrame.from_csv('genre_scores.csv')
    gen_scores = [data.dropna().iloc[1::2, ind - 1].transpose() for ind in index]

    # rewrite the values in a flattened array for plotting
    coordinates = [gen.as_matrix().flatten() for gen in gen_scores]

    # Plot the results
    fig = plt.figure()
    if len(coordinates) == 2:
        plt.scatter(*coordinates)
        # pearson_coeff is defined elsewhere (not shown here)
        plt.text(70, 110, "pearson={}".format(round(pearson_coeff(coordinates[0], coordinates[1]), 3)))
        plt.xlabel(lst[0])
        plt.ylabel(lst[1])
        plt.savefig("Fig/{}_{}.pdf".format(*lst))
    else:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(*coordinates)
        ax.update({'xlabel': lst[0], 'ylabel': lst[1], 'zlabel': lst[2]})
        plt.savefig("Fig/{}_{}_{}.pdf".format(*lst))
    plt.show()
    plt.close("all")


if __name__ == "__main__":
    genre_scatter(['Action', 'Horror', 'Comedy'])
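A side note on the pandas calls in this code: DataFrame.from_csv and DataFrame.as_matrix were deprecated and then removed in pandas 1.0, so on current pandas the loading step needs pd.read_csv(..., index_col=0) and .to_numpy() instead. Because the DataFrame already knows its column names, the genres can also be selected by name, which makes the separate csv.reader pass unnecessary. The sketch below is only an illustration under the file layout shown above; load_coordinates and SAMPLE are hypothetical names. It also swaps the order of the two row operations: it takes every second row first and only then drops incomplete points. In the original, dropna() runs on the whole frame before iloc[1::2]; on the sample data above that removes every pair except the last, so only (75, 36, 89) survives, and in general a gap in a first row could shift which rows count as "second".

```python
from io import StringIO

import pandas as pd


def load_coordinates(source, genres):
    """Hypothetical helper: one NumPy array per genre, from every second row.

    Assumes the layout shown above: a header row, then pairs of rows per
    index, where only the second row of each pair is plotted.
    """
    # read_csv replaces DataFrame.from_csv (removed in pandas 1.0);
    # index_col=0 treats the unnamed first column as the index.
    data = pd.read_csv(source, index_col=0)
    # Select genres by name (KeyError if one is missing), then take
    # rows 1, 3, 5, ... before dropping anything, so parity is preserved.
    second_rows = data[genres].iloc[1::2]
    # Drop points that lack a value for any of the requested genres.
    complete = second_rows.dropna()
    # to_numpy replaces as_matrix (also removed in pandas 1.0).
    return [complete[genre].to_numpy() for genre in genres]


# Hypothetical in-memory copy of the sample file shown above.
SAMPLE = """,Action,Comedy,Horror
1,650,819,
,76,63,
2,,462,19
,,18,96
3,652,457,18
,75,36,89
"""

x, y = load_coordinates(StringIO(SAMPLE), ['Comedy', 'Horror'])
# x holds the Comedy values 18 and 36, y the Horror values 96 and 89.
```

With a real file, pass the path ('genre_scores.csv') instead of the StringIO object. Two genres give the x and y arrays for plt.scatter(*coordinates); three give x, y and z for the 3-D axes.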


The code works and I'm happy with the output, but there are a few things that bug me, and I'm not sure I used them correctly.




  1. I'm not very familiar with list comprehensions (I think that's the name for expressions of the form [x for x in list]; please correct me if I'm wrong) and haven't used them much, so I'm not sure this was the right approach. My biggest concern is the first one, where I only need the first row of the CSV file but build a list of all rows just to use the first. Is there a smarter way to do this?

  2. Is there a better way to label the axes? Ideally, some function to which I could just pass *lst?


Please ignore the pearson_coeff() part of the code; it's not relevant here.


































  • Are you sure that this code runs? You appear to use csv.reader but don't import it.
    – Reinderien
    1 hour ago










  • @Reinderien Thanks for pointing that out; I copied this out of a larger program and forgot to check every import.
    – Sito
    1 hour ago
















python beginner python-3.x














edited 24 mins ago by Jamal
asked 4 hours ago by Sito






















1 Answer
































This really isn't bad, in terms of base Python. The only thing that stands out to me is this:



first_row = [row for row in reader(open('genre_scores.csv', 'r'))][0]


Firstly, you never close the file. Always close a file when you're done with it; a with block does that for you.



'r' is the default mode, so you don't need to pass it to open.



Also, you're building an entire list in memory from the CSV file, only to throw away everything but the first row. Instead, you should use something like:



with open('genre_scores.csv') as f:
    csv_reader = reader(f)
    first_row = next(csv_reader)
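If the separate header read is kept, the with/next() pattern can be wrapped in a small helper. This is an illustrative sketch, not code from the question: genre_columns is a hypothetical name, the file is opened with newline='' as the csv module documentation recommends, and an unknown genre raises a ValueError that lists the available columns instead of list.index's bare "'X' is not in list".

```python
import csv
import os
import tempfile


def genre_columns(path, genres):
    """Hypothetical helper: return the header position of each genre."""
    with open(path, newline='') as f:  # 'r' is the default mode
        header = next(csv.reader(f))  # reads only the first row
    # The with block has closed the file by this point.
    missing = [g for g in genres if g not in header]
    if missing:
        raise ValueError(f'Unknown genre(s) {missing}; header is {header}')
    return [header.index(g) for g in genres]


# Usage on a throwaway copy of the first lines of the sample file.
with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False, newline='') as tmp:
    tmp.write(',Action,Comedy,Horror\n1,650,819,\n,76,63,\n')
columns = genre_columns(tmp.name, ['Action', 'Horror'])  # [1, 3]
os.remove(tmp.name)
```

Position 0 is the unnamed index column, which is why the question's code subtracts 1 (ind - 1) before using the positions with iloc.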


You also ask:




I'd like to implement something that makes sure that lst isn't longer than three elements (since four dimensional plots aren't really a thing). The only way I know to do this is assert len(lst) <=3, which gets the job done but it would be nice if it also could raise a useful error message.




Fairly straightforward, and I'll also assume that the minimum is 2:



if not (2 <= len(lst) <= 3):
    raise ValueError(f'Invalid lst length of {len(lst)}')
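The question's second point, labelling the axes straight from the list, can be handled in one call as well. One possible sketch: Axes.set accepts properties such as xlabel and ylabel as keyword arguments (and zlabel on 3-D axes), so zipping those names with the genre list labels either kind of axes. label_axes is a hypothetical helper, the length check reuses the idea above, and the Agg backend is only there so the example runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older matplotlib)


def label_axes(ax, genres):
    """Hypothetical helper: label x/y (and z on 3-D axes) from genre names."""
    if not 2 <= len(genres) <= 3:
        raise ValueError(f'Expected 2 or 3 genres, got {len(genres)}')
    # zip stops at the shorter sequence, so two genres set only xlabel and ylabel.
    ax.set(**dict(zip(('xlabel', 'ylabel', 'zlabel'), genres)))


fig = plt.figure()
ax2d = fig.add_subplot(1, 2, 1)
label_axes(ax2d, ['Action', 'Horror'])
ax3d = fig.add_subplot(1, 2, 2, projection='3d')
label_axes(ax3d, ['Action', 'Horror', 'Comedy'])
plt.close(fig)
```

Passing three genres to ordinary 2-D axes still raises an error, since they have no zlabel property. The two savefig format strings could be merged in the same spirit, e.g. "Fig/{}.pdf".format('_'.join(lst)).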





edited 1 hour ago
answered 1 hour ago by Reinderien





























